├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MODEL_LICENSE
├── README.md
├── example.txt
├── figs
│ ├── 22_0.jpg
│ ├── logo.png
│ ├── methodv3.png
│ └── pipeline.png
├── flashvideo
│ ├── arguments.py
│ ├── base_model.py
│ ├── base_transformer.py
│ ├── configs
│ │ ├── stage1.yaml
│ │ └── stage2.yaml
│ ├── demo.ipynb
│ ├── diffusion_video.py
│ ├── dist_inf_text_file.py
│ ├── dit_video_concat.py
│ ├── extra_models
│ │ └── dit_res_adapter.py
│ ├── flow_video.py
│ ├── sgm
│ │ ├── __init__.py
│ │ ├── lr_scheduler.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── autoencoder.py
│ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── autoencoding
│ │ │ │ ├── __init__.py
│ │ │ │ ├── losses
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── discriminator_loss.py
│ │ │ │ │ ├── lpips.py
│ │ │ │ │ └── video_loss.py
│ │ │ │ ├── lpips
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── loss
│ │ │ │ │ │ ├── .gitignore
│ │ │ │ │ │ ├── LICENSE
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── lpips.py
│ │ │ │ │ ├── model
│ │ │ │ │ │ ├── LICENSE
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── model.py
│ │ │ │ │ ├── util.py
│ │ │ │ │ └── vqperceptual.py
│ │ │ │ ├── magvit2_pytorch.py
│ │ │ │ ├── regularizers
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base.py
│ │ │ │ │ ├── finite_scalar_quantization.py
│ │ │ │ │ ├── lookup_free_quantization.py
│ │ │ │ │ └── quantize.py
│ │ │ │ ├── temporal_ae.py
│ │ │ │ └── vqvae
│ │ │ │   ├── movq_dec_3d.py
│ │ │ │   ├── movq_dec_3d_dev.py
│ │ │ │   ├── movq_enc_3d.py
│ │ │ │   ├── movq_modules.py
│ │ │ │   ├── quantize.py
│ │ │ │   └── vqvae_blocks.py
│ │ │ ├── cp_enc_dec.py
│ │ │ ├── diffusionmodules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── denoiser.py
│ │ │ │ ├── denoiser_scaling.py
│ │ │ │ ├── denoiser_weighting.py
│ │ │ │ ├── discretizer.py
│ │ │ │ ├── guiders.py
│ │ │ │ ├── lora.py
│ │ │ │ ├── loss.py
│ │ │ │ ├── model.py
│ │ │ │ ├── openaimodel.py
│ │ │ │ ├── sampling.py
│ │ │ │ ├── sampling_utils.py
│ │ │ │ ├── sigma_sampling.py
│ │ │ │ ├── util.py
│ │ │ │ └── wrappers.py
│ │ │ ├── distributions
│ │ │ │ ├── __init__.py
│ │ │ │ └── distributions.py
│ │ │ ├── ema.py
│ │ │ ├── encoders
│ │ │ │ ├── __init__.py
│ │ │ │ └── modules.py
│ │ │ └── video_attention.py
│ │ ├── util.py
│ │ └── webds.py
│ ├── utils.py
│ └── vae_modules
│   ├── attention.py
│   ├── autoencoder.py
│   ├── cp_enc_dec.py
│   ├── ema.py
│   ├── regularizers.py
│   └── utils.py
├── inf_270_1080p.sh
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__/
2 | # *.png
3 | *.pt
4 | 1080_rope_ana
5 | 1080_rope_ana/
6 | samples*/
7 | runs/
8 | checkpoints/
9 | master_ip
10 | logs/
11 | *.DS_Store
12 | .idea
13 | output*
14 | test*
15 | cogvideox-5b-sat
16 | cogvideox-5b-sat/
17 | *.log
18 | */mini_tools/*
19 | *.mp4
20 | *vis_*
21 | *.pkl
22 | *.mp4
23 | vis_*
24 | checkpoints
25 | checkpoints/
26 | __pycache__/*
27 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: ^tests/data/
2 | repos:
3 | - repo: https://github.com/PyCQA/flake8
4 | rev: 5.0.4
5 | hooks:
6 | - id: flake8
7 | - repo: https://github.com/PyCQA/isort
8 | rev: 5.11.5
9 | hooks:
10 | - id: isort
11 | - repo: https://github.com/pre-commit/mirrors-yapf
12 | rev: v0.32.0
13 | hooks:
14 | - id: yapf
15 | - repo: https://github.com/pre-commit/pre-commit-hooks
16 | rev: v4.3.0
17 | hooks:
18 | - id: trailing-whitespace
19 | - id: check-yaml
20 | - id: end-of-file-fixer
21 | - id: requirements-txt-fixer
22 | - id: double-quote-string-fixer
23 | - id: check-merge-conflict
24 | - id: fix-encoding-pragma
25 | args: ["--remove"]
26 | - id: mixed-line-ending
27 | args: ["--fix=lf"]
28 | - repo: https://github.com/codespell-project/codespell
29 | rev: v2.2.1
30 | hooks:
31 | - id: codespell
32 | - repo: https://github.com/executablebooks/mdformat
33 | rev: 0.7.9
34 | hooks:
35 | - id: mdformat
36 | args: ["--number"]
37 | additional_dependencies:
38 | - mdformat-openmmlab
39 | - mdformat_frontmatter
40 | - linkify-it-py
41 | - repo: https://github.com/asottile/pyupgrade
42 | rev: v3.0.0
43 | hooks:
44 | - id: pyupgrade
45 | args: ["--py36-plus"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 CogVideo Model Team @ Zhipu AI
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The CogVideoX License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means the CogVideoX Model Team that distributes its Software.
6 |
7 | “Software” means the CogVideoX model parameters made available under this license.
8 |
9 | 2. License Grant
10 |
11 | Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
12 | This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form .
13 | Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
14 | If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
15 | The above copyright statement and this license statement should be included in all copies or significant portions of this software.
16 |
17 | 3. Restriction
18 |
19 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
20 |
21 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
22 |
23 | 4. Disclaimer
24 |
25 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
26 |
27 | 5. Limitation of Liability
28 |
29 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
30 |
31 | 6. Dispute Resolution
32 |
33 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
34 |
35 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
36 |
37 | 1. 定义
38 |
39 | “许可方”是指分发其软件的 CogVideoX 模型团队。
40 |
41 | “软件”是指根据本许可提供的 CogVideoX 模型参数。
42 |
43 | 2. 许可授予
44 |
45 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。
46 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。
47 |
48 | 经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
49 | 在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。
50 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
51 |
52 | 3.限制
53 |
54 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
55 |
56 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
57 |
58 | 4.免责声明
59 |
60 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
61 | 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
62 |
63 | 5. 责任限制
64 |
65 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
66 |
67 | 6.争议解决
68 |
69 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
70 |
71 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
7 |
8 | # Flowing Fidelity to Detail for Efficient High-Resolution Video Generation
9 |
10 | [arXiv](https://arxiv.org/abs/2502.05179)
11 | [Project Page](https://jshilong.github.io/flashvideo-page/)
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | > [**FlashVideo: Flowing Fidelity to Detail for Efficient High-Resolution Video Generation**](https://arxiv.org/abs/2502.05179)
22 | > [Shilong Zhang](https://jshilong.github.io/), [Wenbo Li](https://scholar.google.com/citations?user=foGn_TIAAAAJ&hl=en), [Shoufa Chen](https://www.shoufachen.com/), [Chongjian Ge](https://chongjiange.github.io/), [Peize Sun](https://peizesun.github.io/),
[Yida Zhang](<>), [Yi Jiang](https://enjoyyi.github.io/), [Zehuan Yuan](https://shallowyuan.github.io/), [Bingyue Peng](<>), [Ping Luo](http://luoping.me/),
23 | >
HKU, CUHK, ByteDance
24 |
25 | ## 🤗 More video examples 👀 can be accessed at the [project page](https://jshilong.github.io/flashvideo-page/)
26 |
27 |
31 |
32 | #### ⚡⚡ User Prompt to 270p, NFE = 50, Takes ~30s ⚡⚡
33 | #### ⚡⚡ 270p to 1080p, NFE = 4, Takes ~72s ⚡⚡
34 |
35 | [Demo GIF](https://github.com/FoundationVision/flashvideo-page/blob/main/static/images/output.gif)
36 |
37 |
38 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
53 |
54 |
55 | ## 🔥 Update
56 |
57 | - \[2025.02.10\] 🔥 🔥 🔥 Inference code and model [weights](https://huggingface.co/FoundationVision/FlashVideo/tree/main) for both stages have been released.
58 |
59 | ## 🌿 Introduction
60 | In this repository, we provide:
61 |
62 | - [x] The stage-I weight for 270P video generation.
63 | - [x] The stage-II weight for enhancing 270P video to 1080P.
64 | - [x] Inference code of both stages.
65 | - [ ] Training code and related augmentation. Work in progress [PR#12](https://github.com/FoundationVision/FlashVideo/pull/12)
66 | - [x] Loss function
67 | - [ ] Dataset and augmentation
68 | - [ ] Configuration and training script
69 | - [ ] Implementation with diffusers.
70 | - [ ] Gradio.
71 |
72 |
73 | ## Install
74 |
75 | ### 1. Environment Setup
76 |
77 | This repository is tested with PyTorch 2.4.0+cu121 and Python 3.11.11. You can install the necessary dependencies using the following command:
78 |
79 | ```shell
80 | pip install -r requirements.txt
81 | ```
82 |
83 | ### 2. Preparing the Checkpoints
84 |
85 | To obtain the 3D VAE (identical to that of CogVideoX) along with the Stage-I and Stage-II weights, download and set them up as follows:
86 |
87 | ```shell
88 | cd FlashVideo
89 | mkdir -p ./checkpoints
90 | huggingface-cli download --local-dir ./checkpoints FoundationVision/FlashVideo
91 | ```
92 |
93 | The checkpoints should be organized as shown below:
94 |
95 | ```
96 | ├── 3d-vae.pt
97 | ├── stage1.pt
98 | └── stage2.pt
99 | ```
100 |
101 | ## 🚀 Text to Video Generation
102 |
103 | #### ⚠️ IMPORTANT NOTICE ⚠️ : Both stage-I and stage-II are trained with long prompts only. To achieve the best results, include comprehensive and detailed descriptions in your prompts, akin to the example provided in [example.txt](./example.txt).
104 |
105 | ### Jupyter Notebook
106 |
107 | You can conveniently provide user prompts in our Jupyter notebook. The default configuration for spatial and temporal slices in the VAE decoder is tailored for an 80G GPU. For GPUs with less memory, consider increasing the number of [spatial and temporal slices](https://github.com/FoundationVision/FlashVideo/blob/400a9c1ef905eab3a1cb6b9f5a5a4c331378e4b5/sat/utils.py#L110).
108 |
109 |
110 | ```
111 | flashvideo/demo.ipynb
112 | ```
113 |
114 | ### Inferring from a Text File Containing Prompts
115 |
116 | You can conveniently provide user prompts in a text file and generate videos with multiple GPUs.
117 |
118 | ```shell
119 | bash inf_270_1080p.sh
120 | ```
121 |
122 | ## License
123 |
124 | This project is developed based on [CogVideoX](https://github.com/THUDM/CogVideo). Please refer to their original [license](https://github.com/THUDM/CogVideo?tab=readme-ov-file#model-license) for usage details.
125 |
126 | ## BibTeX
127 |
128 | ```bibtex
129 | @article{zhang2025flashvideo,
130 | title={FlashVideo: Flowing Fidelity to Detail for Efficient High-Resolution Video Generation},
131 |   author={Zhang, Shilong and Li, Wenbo and Chen, Shoufa and Ge, Chongjian and Sun, Peize and Zhang, Yida and Jiang, Yi and Yuan, Zehuan and Peng, Bingyue and Luo, Ping},
132 | journal={arXiv preprint arXiv:2502.05179},
133 | year={2025}
134 | }
135 | ```
136 |
--------------------------------------------------------------------------------
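
A note on the VAE memory setting mentioned in the README's Jupyter-notebook section: decoding the full 1080p video latent in one pass is the peak-memory hot spot, and the spatial/temporal slice settings control how the latent is split before decoding. The sketch below only illustrates the temporal-slicing idea, with a hypothetical `decode_chunk` callable standing in for the VAE decode call; FlashVideo's actual decoder additionally carries cache state across chunks, as the `clear_fake_cp_cache` flag in `flashvideo/flow_video.py` suggests.

```python
import torch


def sliced_temporal_decode(decode_chunk, latents, num_slices=4):
    """Illustrative only: decode a video latent in temporal chunks.

    `latents` is assumed to be laid out as (batch, channels, frames, height,
    width). More slices lower peak GPU memory at the cost of speed. The real
    decoder also keeps cache state between chunks, which this sketch ignores.
    """
    chunks = torch.chunk(latents, num_slices, dim=2)  # split along the frame axis
    return torch.cat([decode_chunk(c) for c in chunks], dim=2)
```
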
/figs/22_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/22_0.jpg
--------------------------------------------------------------------------------
/figs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/logo.png
--------------------------------------------------------------------------------
/figs/methodv3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/methodv3.png
--------------------------------------------------------------------------------
/figs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/pipeline.png
--------------------------------------------------------------------------------
/flashvideo/configs/stage1.yaml:
--------------------------------------------------------------------------------
1 | share_cache_args:
2 | disable_ref : True
3 | num_vis_img: 4
4 | vis_ddpm: True
5 | eval_interval_list: [1, 50, 100, 1000]
6 | save_interval_list: [2000]
7 |
8 | args:
9 | checkpoint_activations: False # using gradient checkpointing
10 | model_parallel_size: 1
11 | experiment_name: lora-disney
12 | mode: finetune
13 | load: ""
14 | no_load_rng: True
15 |   train_iters: 10000000000 # Suggest more than 1000 for LoRA; for SFT, 500 is enough
16 | eval_iters: 100000000
17 | eval_interval: 1000000000000
18 | eval_batch_size: 1
19 | save: ./
20 | save_interval: 1000
21 | log_interval: 20
22 | train_data: [ "disney" ] # Train data path
23 | valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
24 | split: 1,0,0
25 | num_workers: 2
26 | force_train: True
27 | only_log_video_latents: True
28 |
29 |
30 | deepspeed:
31 | # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs
32 | train_micro_batch_size_per_gpu: 1
33 | gradient_accumulation_steps: 1
34 | steps_per_print: 50
35 | gradient_clipping: 0.1
36 | zero_optimization:
37 | stage: 2
38 | cpu_offload: false
39 | contiguous_gradients: false
40 | overlap_comm: true
41 | reduce_scatter: true
42 | reduce_bucket_size: 1000000000
43 | allgather_bucket_size: 1000000000
44 | load_from_fp32_weights: false
45 | zero_allow_untested_optimizer: true
46 | bf16:
47 | enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
48 | fp16:
49 | enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
50 | loss_scale: 0
51 | loss_scale_window: 400
52 | hysteresis: 2
53 | min_loss_scale: 1
54 |
55 |
56 |
57 | model:
58 | scale_factor: 0.7
59 | disable_first_stage_autocast: true
60 | log_keys:
61 | - txt
62 |
63 | denoiser_config:
64 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
65 | params:
66 | num_idx: 1000
67 | quantize_c_noise: False
68 |
69 | weighting_config:
70 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
71 | scaling_config:
72 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
73 | discretization_config:
74 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
75 | params:
76 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
77 |
78 | network_config:
79 | target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
80 | params:
81 | time_embed_dim: 512
82 | elementwise_affine: True
83 | num_frames: 64
84 | time_compressed_rate: 4
85 | # latent_width: 90
86 | # latent_height: 60
87 | latent_width: 512
88 | latent_height: 512
89 | num_layers: 42 # different from cogvideox_2b_infer.yaml
90 | patch_size: 2
91 | in_channels: 16
92 | out_channels: 16
93 | hidden_size: 3072 # different from cogvideox_2b_infer.yaml
94 | adm_in_channels: 256
95 | num_attention_heads: 48 # different from cogvideox_2b_infer.yaml
96 |
97 | transformer_args:
98 | checkpoint_activations: False
99 | vocab_size: 1
100 | max_sequence_length: 64
101 | layernorm_order: pre
102 | skip_init: false
103 | model_parallel_size: 1
104 | is_decoder: false
105 |
106 | modules:
107 | pos_embed_config:
108 | target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
109 | params:
110 | hidden_size_head: 64
111 | text_length: 226
112 |
113 | lora_config:
114 | target: extra_models.dit_res_adapter.ResLoraMixin
115 | params:
116 | r: 128
117 |
118 | patch_embed_config:
119 | target: dit_video_concat.ImagePatchEmbeddingMixin
120 | params:
121 | text_hidden_size: 4096
122 |
123 | adaln_layer_config:
124 | target: dit_video_concat.AdaLNMixin
125 | params:
126 | qk_ln: True
127 |
128 | final_layer_config:
129 | target: dit_video_concat.FinalLayerMixin
130 |
131 | conditioner_config:
132 | target: sgm.modules.GeneralConditioner
133 | params:
134 | emb_models:
135 | - is_trainable: false
136 | input_key: txt
137 | ucg_rate: 0.1
138 | target: sgm.modules.encoders.modules.FrozenT5Embedder
139 | params:
140 | model_dir: "google/t5-v1_1-xxl"
141 | max_length: 226
142 |
143 | first_stage_config:
144 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
145 | params:
146 | cp_size: 1
147 | ckpt_path: "checkpoints/3d-vae.pt"
148 | ignore_keys: [ 'loss' ]
149 |
150 | loss_config:
151 | target: torch.nn.Identity
152 |
153 | regularizer_config:
154 | target: vae_modules.regularizers.DiagonalGaussianRegularizer
155 |
156 | encoder_config:
157 | target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
158 | params:
159 | double_z: true
160 | z_channels: 16
161 | resolution: 256
162 | in_channels: 3
163 | out_ch: 3
164 | ch: 128
165 | ch_mult: [ 1, 2, 2, 4 ]
166 | attn_resolutions: [ ]
167 | num_res_blocks: 3
168 | dropout: 0.0
169 | gather_norm: True
170 |
171 | decoder_config:
172 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
173 | params:
174 | double_z: True
175 | z_channels: 16
176 | resolution: 256
177 | in_channels: 3
178 | out_ch: 3
179 | ch: 128
180 | ch_mult: [ 1, 2, 2, 4 ]
181 | attn_resolutions: [ ]
182 | num_res_blocks: 3
183 | dropout: 0.0
184 | gather_norm: False
185 |
186 | loss_fn_config:
187 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
188 | params:
189 | offset_noise_level: 0
190 | sigma_sampler_config:
191 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
192 | params:
193 | uniform_sampling: True
194 | num_idx: 1000
195 | discretization_config:
196 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
197 | params:
198 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
199 |
200 | sampler_config:
201 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
202 | params:
203 | num_steps: 51
204 | verbose: True
205 |
206 | discretization_config:
207 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
208 | params:
209 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
210 |
211 | guider_config:
212 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG
213 | params:
214 | # TODO check this cfg
215 | scale: 8
216 | exp: 5
217 | num_steps: 51
218 |
--------------------------------------------------------------------------------
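
Every `target:` / `params:` pair in the config above is resolved at runtime by the `instantiate_from_config` and `get_obj_from_str` helpers that the rest of the code imports from `sgm.util` (see the imports in `flashvideo/flow_video.py`). A minimal sketch of that pattern, not necessarily the exact implementation in `flashvideo/sgm/util.py`:

```python
import importlib


def get_obj_from_str(string: str):
    # 'sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser' splits into the
    # module path and the attribute name to import.
    module, cls = string.rsplit('.', 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config: dict):
    # Build the object named by `target`, passing `params` as keyword arguments.
    if 'target' not in config:
        raise KeyError('Expected key `target` to instantiate.')
    return get_obj_from_str(config['target'])(**config.get('params', dict()))


# For example, the denoiser_config block above would be constructed roughly as:
# denoiser = instantiate_from_config({
#     'target': 'sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser',
#     'params': {'num_idx': 1000, 'quantize_c_noise': False},
# })
```
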
/flashvideo/configs/stage2.yaml:
--------------------------------------------------------------------------------
1 | custom_args:
2 | reload: ""
3 |
4 | share_cache_args:
5 | sample_ref_noise_step: 675
6 | time_size_embedding: True
7 |
8 | args:
9 | checkpoint_activations: True # using gradient checkpointing
10 | model_parallel_size: 1
11 | experiment_name: lora-disney
12 | mode: finetune
13 |   load: "" # This is for the full model without the LoRA adapter
14 | no_load_rng: True
15 |   train_iters: 100000 # Suggest more than 1000 for LoRA; for SFT, 500 is enough
16 | eval_iters: 100000000
17 | eval_interval: [1, 200]
18 | eval_batch_size: 1
19 | save:
20 | # for debug
21 | save_interval: 250
22 | log_interval: 5
23 | train_data: [ "disney" ] # Train data path
24 | valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
25 | split: 1,0,0
26 | num_workers: 1
27 | force_train: True
28 | only_log_video_latents: True
29 |
30 |
31 | deepspeed:
32 | # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs
33 | train_micro_batch_size_per_gpu: 1
34 | gradient_accumulation_steps: 1
35 | steps_per_print: 50
36 | gradient_clipping: 0.1
37 | zero_optimization:
38 | stage: 2
39 | cpu_offload: false
40 | contiguous_gradients: false
41 | overlap_comm: true
42 | reduce_scatter: true
43 | reduce_bucket_size: 1000000000
44 | allgather_bucket_size: 1000000000
45 | load_from_fp32_weights: false
46 | zero_allow_untested_optimizer: true
47 | bf16:
48 | enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
49 | fp16:
50 | enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
51 | loss_scale: 0
52 | loss_scale_window: 400
53 | hysteresis: 2
54 | min_loss_scale: 1
55 |
56 |
57 | model:
58 | scale_factor: 1.15258426
59 | disable_first_stage_autocast: true
60 | log_keys:
61 | - txt
62 |
63 | denoiser_config:
64 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
65 | params:
66 | num_idx: 1000
67 | quantize_c_noise: False
68 |
69 | weighting_config:
70 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
71 | scaling_config:
72 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
73 | discretization_config:
74 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
75 | params:
76 | shift_scale: 3.0 # different from cogvideox_2b_infer.yaml
77 |
78 | network_config:
79 | target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
80 | params:
81 | time_embed_dim: 512
82 | elementwise_affine: True
83 | num_frames: 64
84 | time_compressed_rate: 4
85 | # latent_width: 90
86 | # latent_height: 60
87 | latent_width: 512
88 | latent_height: 512
89 | num_layers: 30 # different from cogvideox_2b_infer.yaml
90 | patch_size: 2
91 | in_channels: 16
92 | out_channels: 16
93 | hidden_size: 1920 # different from cogvideox_2b_infer.yaml
94 | adm_in_channels: 256
95 | num_attention_heads: 30 # different from cogvideox_2b_infer.yaml
96 |
97 | transformer_args:
98 | checkpoint_activations: True
99 | vocab_size: 1
100 | max_sequence_length: 64
101 | layernorm_order: pre
102 | skip_init: false
103 | model_parallel_size: 1
104 | is_decoder: false
105 |
106 | modules:
107 | pos_embed_config:
108 | target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
109 | params:
110 | hidden_size_head: 64
111 | text_length: 226
112 |
113 | patch_embed_config:
114 | target: dit_video_concat.ImagePatchEmbeddingMixin
115 | params:
116 | text_hidden_size: 4096
117 |
118 | adaln_layer_config:
119 | target: dit_video_concat.AdaLNMixin
120 | params:
121 | qk_ln: True
122 |
123 | final_layer_config:
124 | target: dit_video_concat.FinalLayerMixin
125 |
126 | conditioner_config:
127 | target: sgm.modules.GeneralConditioner
128 | params:
129 | emb_models:
130 | - is_trainable: false
131 | input_key: txt
132 | ucg_rate: 0.1
133 | target: sgm.modules.encoders.modules.FrozenT5Embedder
134 | params:
135 | model_dir: "google/t5-v1_1-xxl"
136 | max_length: 226
137 |
138 | first_stage_config:
139 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
140 | params:
141 | cp_size: 1
142 | ckpt_path: "checkpoints/3d-vae.pt"
143 | ignore_keys: [ 'loss' ]
144 |
145 | loss_config:
146 | target: torch.nn.Identity
147 |
148 | regularizer_config:
149 | target: vae_modules.regularizers.DiagonalGaussianRegularizer
150 |
151 | encoder_config:
152 | target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
153 | params:
154 | double_z: true
155 | z_channels: 16
156 | resolution: 256
157 | in_channels: 3
158 | out_ch: 3
159 | ch: 128
160 | ch_mult: [ 1, 2, 2, 4 ]
161 | attn_resolutions: [ ]
162 | num_res_blocks: 3
163 | dropout: 0.0
164 | gather_norm: True
165 |
166 | decoder_config:
167 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
168 | params:
169 | double_z: True
170 | z_channels: 16
171 | resolution: 256
172 | in_channels: 3
173 | out_ch: 3
174 | ch: 128
175 | ch_mult: [ 1, 2, 2, 4 ]
176 | attn_resolutions: [ ]
177 | num_res_blocks: 3
178 | dropout: 0.0
179 | gather_norm: False
180 |
181 | loss_fn_config:
182 | target: flow_video.FlowVideoDiffusionLoss
183 | params:
184 | offset_noise_level: 0
185 | sigma_sampler_config:
186 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
187 | params:
188 | uniform_sampling: False
189 | num_idx: 1000
190 | discretization_config:
191 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
192 | params:
193 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
194 |
195 | sampler_config:
196 | target: sgm.modules.diffusionmodules.sampling.CascadeVPSDEDPMPP2MSampler
197 | params:
198 | num_steps: 50
199 | verbose: True
200 |
201 | discretization_config:
202 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
203 | params:
204 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
205 |
206 | guider_config:
207 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG
208 | params:
209 | scale: 6
210 | exp: 5
211 | num_steps: 50
212 |
--------------------------------------------------------------------------------
/flashvideo/flow_video.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | from sgm.modules import UNCONDITIONAL_CONFIG
8 | from sgm.modules.autoencoding.temporal_ae import VideoDecoder
9 | from sgm.modules.diffusionmodules.loss import StandardDiffusionLoss
10 | from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
11 | from sgm.util import (append_dims, default, disabled_train, get_obj_from_str,
12 | instantiate_from_config)
13 | from torch import nn
14 | from torchdiffeq import odeint
15 |
16 |
17 | class FlowEngine(nn.Module):
18 |
19 | def __init__(self, args, **kwargs):
20 | super().__init__()
21 | model_config = args.model_config
22 | log_keys = model_config.get('log_keys', None)
23 | input_key = model_config.get('input_key', 'mp4')
24 | network_config = model_config.get('network_config', None)
25 | network_wrapper = model_config.get('network_wrapper', None)
26 | denoiser_config = model_config.get('denoiser_config', None)
27 | sampler_config = model_config.get('sampler_config', None)
28 | conditioner_config = model_config.get('conditioner_config', None)
29 | first_stage_config = model_config.get('first_stage_config', None)
30 | loss_fn_config = model_config.get('loss_fn_config', None)
31 | scale_factor = model_config.get('scale_factor', 1.0)
32 | latent_input = model_config.get('latent_input', False)
33 | disable_first_stage_autocast = model_config.get(
34 | 'disable_first_stage_autocast', False)
35 |         no_cond_log = model_config.get('no_cond_log', False)
36 | not_trainable_prefixes = model_config.get(
37 | 'not_trainable_prefixes', ['first_stage_model', 'conditioner'])
38 | compile_model = model_config.get('compile_model', False)
39 | en_and_decode_n_samples_a_time = model_config.get(
40 | 'en_and_decode_n_samples_a_time', None)
41 | lr_scale = model_config.get('lr_scale', None)
42 | lora_train = model_config.get('lora_train', False)
43 | self.use_pd = model_config.get('use_pd', False)
44 |
45 | self.log_keys = log_keys
46 | self.input_key = input_key
47 | self.not_trainable_prefixes = not_trainable_prefixes
48 | self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
49 | self.lr_scale = lr_scale
50 | self.lora_train = lora_train
51 | self.noised_image_input = model_config.get('noised_image_input', False)
52 | self.noised_image_all_concat = model_config.get(
53 | 'noised_image_all_concat', False)
54 | self.noised_image_dropout = model_config.get('noised_image_dropout',
55 | 0.0)
56 | if args.fp16:
57 | dtype = torch.float16
58 | dtype_str = 'fp16'
59 | elif args.bf16:
60 | dtype = torch.bfloat16
61 | dtype_str = 'bf16'
62 | else:
63 | dtype = torch.float32
64 | dtype_str = 'fp32'
65 | self.dtype = dtype
66 | self.dtype_str = dtype_str
67 |
68 | network_config['params']['dtype'] = dtype_str
69 | model = instantiate_from_config(network_config)
70 | self.model = get_obj_from_str(
71 | default(network_wrapper,
72 | OPENAIUNETWRAPPER))(model,
73 | compile_model=compile_model,
74 | dtype=dtype)
75 |
76 | self.denoiser = instantiate_from_config(denoiser_config)
77 | self.sampler = instantiate_from_config(
78 | sampler_config) if sampler_config is not None else None
79 | self.conditioner = instantiate_from_config(
80 | default(conditioner_config, UNCONDITIONAL_CONFIG))
81 |
82 | self._init_first_stage(first_stage_config)
83 |
84 | self.loss_fn = instantiate_from_config(
85 | loss_fn_config) if loss_fn_config is not None else None
86 |
87 | self.latent_input = latent_input
88 | self.scale_factor = scale_factor
89 | self.disable_first_stage_autocast = disable_first_stage_autocast
90 | self.no_cond_log = no_cond_log
91 | self.device = args.device
92 |
93 | def disable_untrainable_params(self):
94 | pass
95 |
96 | def reinit(self, parent_model=None):
97 | pass
98 |
99 | def _init_first_stage(self, config):
100 | model = instantiate_from_config(config).eval()
101 | model.train = disabled_train
102 | for param in model.parameters():
103 | param.requires_grad = False
104 | self.first_stage_model = model
105 |
106 | def get_input(self, batch):
107 | return batch[self.input_key].to(self.dtype)
108 |
109 | @torch.no_grad()
110 | def decode_first_stage(self, z):
111 | z = 1.0 / self.scale_factor * z
112 | n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
113 | n_rounds = math.ceil(z.shape[0] / n_samples)
114 | all_out = []
115 | with torch.autocast('cuda',
116 | enabled=not self.disable_first_stage_autocast):
117 | for n in range(n_rounds):
118 | if isinstance(self.first_stage_model.decoder, VideoDecoder):
119 | kwargs = {
120 | 'timesteps': len(z[n * n_samples:(n + 1) * n_samples])
121 | }
122 | else:
123 | kwargs = {}
124 | out = self.first_stage_model.decode(
125 | z[n * n_samples:(n + 1) * n_samples], **kwargs)
126 | all_out.append(out)
127 | out = torch.cat(all_out, dim=0)
128 | return out
129 |
130 | @torch.no_grad()
131 | def encode_first_stage(self, x, batch):
132 | frame = x.shape[2]
133 |
134 | if frame > 1 and self.latent_input:
135 | x = x.permute(0, 2, 1, 3, 4).contiguous()
136 | return x * self.scale_factor # already encoded
137 |
138 | n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
139 | n_rounds = math.ceil(x.shape[0] / n_samples)
140 | all_out = []
141 | with torch.autocast('cuda',
142 | enabled=not self.disable_first_stage_autocast):
143 | for n in range(n_rounds):
144 | out = self.first_stage_model.encode(x[n * n_samples:(n + 1) *
145 | n_samples])
146 | all_out.append(out)
147 | z = torch.cat(all_out, dim=0)
148 | z = self.scale_factor * z
149 | return z
150 |
151 | @torch.no_grad()
152 | def save_memory_encode_first_stage(self, x, batch):
153 | splits_x = torch.split(x, [13, 12, 12, 12], dim=2)
154 |
155 | all_out = []
156 |
157 | with torch.autocast('cuda', enabled=False):
158 | for idx, input_x in enumerate(splits_x):
159 | if idx == len(splits_x) - 1:
160 | clear_fake_cp_cache = True
161 | else:
162 | clear_fake_cp_cache = False
163 | out = self.first_stage_model.encode(
164 | input_x.contiguous(),
165 | clear_fake_cp_cache=clear_fake_cp_cache)
166 | all_out.append(out)
167 |
168 | z = torch.cat(all_out, dim=2)
169 | z = self.scale_factor * z
170 | return z
171 |
172 | def single_function_evaluation(self,
173 | t,
174 | x,
175 | cond=None,
176 | uc=None,
177 | cfg=1,
178 | **kwargs):
179 | start_time = time.time()
180 | # for CFG
181 | x = torch.cat([x] * 2)
182 | t = t.reshape(1).to(x.dtype).to(x.device)
183 | t = torch.cat([t] * 2)
184 | idx = 1000 - (t * 1000)
185 |
186 | real_cond = dict()
187 | for k, v in cond.items():
188 | uncond_v = uc[k]
189 | real_cond[k] = torch.cat([v, uncond_v])
190 |
191 | vt = self.model(x, t=idx, c=real_cond, idx=idx)
192 | vt, uc_vt = vt.chunk(2)
193 | vt = uc_vt + cfg * (vt - uc_vt)
194 | end_time = time.time()
195 | print(f'single_function_evaluation time at {t}', end_time - start_time)
196 | return vt
197 |
198 | @torch.no_grad()
199 | def sample(
200 | self,
201 | ref_x,
202 | cond,
203 | uc,
204 | **sample_kwargs,
205 | ):
206 | """Stage 2 Sampling, start from the first stage results `ref_x`
207 |
208 | Args:
209 | ref_x (_type_): Stage1 low resolution video
210 |             cond (dict): Dict containing condition embeddings
211 |             uc (dict): Dict containing unconditional embeddings
212 |
213 | Returns:
214 | Tensor: Secondary stage results
215 | """
216 |
217 | sample_kwargs = sample_kwargs or {}
218 | print('sample_kwargs', sample_kwargs)
219 | # timesteps
220 | num_steps = sample_kwargs.get('num_steps', 4)
221 | t = torch.linspace(0, 1, num_steps + 1,
222 | dtype=ref_x.dtype).to(ref_x.device)
223 | print(self.share_cache['shift_t'])
224 | shift_t = float(self.share_cache['shift_t'])
225 | t = 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t))
226 |
227 | print('sample:', t)
228 | t = t
229 | single_function_evaluation = partial(self.single_function_evaluation,
230 | cond=cond,
231 | uc=uc,
232 | cfg=sample_kwargs.get('cfg', 1))
233 |
234 | ref_noise_step = self.share_cache['sample_ref_noise_step']
235 | print(f'ref_noise_step : {ref_noise_step}')
236 |
237 | ref_alphas_cumprod_sqrt = self.loss_fn.sigma_sampler.idx_to_sigma(
238 | torch.zeros(ref_x.shape[0]).fill_(ref_noise_step).long())
239 | ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(ref_x.device)
240 | ori_dtype = ref_x.dtype
241 |
242 | ref_noise = torch.randn_like(ref_x)
243 | print('weight', ref_alphas_cumprod_sqrt, flush=True)
244 |
245 | ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
246 | + ref_noise * append_dims(
247 | (1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim
248 | )
249 | ref_x = ref_noised_input.to(ori_dtype)
250 | self.share_cache['ref_x'] = ref_x
251 |
252 | results = odeint(single_function_evaluation,
253 | ref_x,
254 | t,
255 | method=sample_kwargs.get('method', 'euler'),
256 | atol=1e-6,
257 | rtol=1e-3)[-1]
258 |
259 | return results
260 |
261 |
262 | class FlowVideoDiffusionLoss(StandardDiffusionLoss):
263 |
264 | def __init__(self,
265 | block_scale=None,
266 | block_size=None,
267 | min_snr_value=None,
268 | fixed_frames=0,
269 | **kwargs):
270 | self.fixed_frames = fixed_frames
271 | self.block_scale = block_scale
272 | self.block_size = block_size
273 | self.min_snr_value = min_snr_value
274 | self.schedule = None
275 | super().__init__(**kwargs)
276 |
277 | def __call__(self, network, denoiser, conditioner, input, batch):
278 | pass
279 |
--------------------------------------------------------------------------------
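
`FlowEngine.sample` above integrates a flow ODE over only a few steps (`num_steps` defaults to 4), starting from the stage-1 latent `ref_x` after it has been lightly re-noised at `sample_ref_noise_step`. The `shift_t` remapping warps the uniform time grid so that more of the small step budget is spent near t = 0, close to the re-noised reference. A standalone sketch of just that remapping; `shift_t` itself is read from `share_cache`, which is populated elsewhere in the repo:

```python
import torch


def shifted_timesteps(num_steps: int, shift_t: float) -> torch.Tensor:
    """Reproduce the timestep warping used in FlowEngine.sample()."""
    t = torch.linspace(0, 1, num_steps + 1)
    # shift_t = 1 leaves the grid uniform; larger values pull the intermediate
    # points toward t = 0, i.e. toward the noised low-resolution reference.
    return 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t))


print(shifted_timesteps(4, 1.0))  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(shifted_timesteps(4, 3.0))  # tensor([0.0000, 0.1000, 0.2500, 0.5000, 1.0000])
```
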
/flashvideo/sgm/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import AutoencodingEngine
2 | from .util import get_configs_path, instantiate_from_config
3 |
4 | __version__ = '0.1.0'
5 |
--------------------------------------------------------------------------------
/flashvideo/sgm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 |
9 | def __init__(
10 | self,
11 | warm_up_steps,
12 | lr_min,
13 | lr_max,
14 | lr_start,
15 | max_decay_steps,
16 | verbosity_interval=0,
17 | ):
18 | self.lr_warm_up_steps = warm_up_steps
19 | self.lr_start = lr_start
20 | self.lr_min = lr_min
21 | self.lr_max = lr_max
22 | self.lr_max_decay_steps = max_decay_steps
23 | self.last_lr = 0.0
24 | self.verbosity_interval = verbosity_interval
25 |
26 | def schedule(self, n, **kwargs):
27 | if self.verbosity_interval > 0:
28 | if n % self.verbosity_interval == 0:
29 | print(
30 | f'current step: {n}, recent lr-multiplier: {self.last_lr}')
31 | if n < self.lr_warm_up_steps:
32 | lr = (self.lr_max -
33 | self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
34 | self.last_lr = lr
35 | return lr
36 | else:
37 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps -
38 | self.lr_warm_up_steps)
39 | t = min(t, 1.0)
40 | lr = self.lr_min + 0.5 * (self.lr_max -
41 | self.lr_min) * (1 + np.cos(t * np.pi))
42 | self.last_lr = lr
43 | return lr
44 |
45 | def __call__(self, n, **kwargs):
46 | return self.schedule(n, **kwargs)
47 |
48 |
49 | class LambdaWarmUpCosineScheduler2:
50 | """
51 | supports repeated iterations, configurable via lists
52 | note: use with a base_lr of 1.0.
53 | """
54 |
55 | def __init__(self,
56 | warm_up_steps,
57 | f_min,
58 | f_max,
59 | f_start,
60 | cycle_lengths,
61 | verbosity_interval=0):
62 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(
63 | f_start) == len(cycle_lengths)
64 | self.lr_warm_up_steps = warm_up_steps
65 | self.f_start = f_start
66 | self.f_min = f_min
67 | self.f_max = f_max
68 | self.cycle_lengths = cycle_lengths
69 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
70 | self.last_f = 0.0
71 | self.verbosity_interval = verbosity_interval
72 |
73 | def find_in_interval(self, n):
74 | interval = 0
75 | for cl in self.cum_cycles[1:]:
76 | if n <= cl:
77 | return interval
78 | interval += 1
79 |
80 | def schedule(self, n, **kwargs):
81 | cycle = self.find_in_interval(n)
82 | n = n - self.cum_cycles[cycle]
83 | if self.verbosity_interval > 0:
84 | if n % self.verbosity_interval == 0:
85 | print(
86 | f'current step: {n}, recent lr-multiplier: {self.last_f}, '
87 | f'current cycle {cycle}')
88 | if n < self.lr_warm_up_steps[cycle]:
89 | f = (self.f_max[cycle] - self.f_start[cycle]
90 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
91 | self.last_f = f
92 | return f
93 | else:
94 | t = (n - self.lr_warm_up_steps[cycle]) / (
95 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
96 | t = min(t, 1.0)
97 | f = self.f_min[cycle] + 0.5 * (
98 | self.f_max[cycle] - self.f_min[cycle]) * (1 +
99 | np.cos(t * np.pi))
100 | self.last_f = f
101 | return f
102 |
103 | def __call__(self, n, **kwargs):
104 | return self.schedule(n, **kwargs)
105 |
106 |
107 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
108 |
109 | def schedule(self, n, **kwargs):
110 | cycle = self.find_in_interval(n)
111 | n = n - self.cum_cycles[cycle]
112 | if self.verbosity_interval > 0:
113 | if n % self.verbosity_interval == 0:
114 | print(
115 | f'current step: {n}, recent lr-multiplier: {self.last_f}, '
116 | f'current cycle {cycle}')
117 |
118 | if n < self.lr_warm_up_steps[cycle]:
119 | f = (self.f_max[cycle] - self.f_start[cycle]
120 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
121 | self.last_f = f
122 | return f
123 | else:
124 | f = (self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) *
125 | (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]))
126 | self.last_f = f
127 | return f
128 |
--------------------------------------------------------------------------------
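
These schedulers return a learning-rate multiplier for a given step index (hence the "use with a base_lr of 1.0" note in the docstrings), so they are typically wrapped in `torch.optim.lr_scheduler.LambdaLR`. A hedged usage sketch; the model, optimizer, and numbers are illustrative only, and the import assumes the `flashvideo/` directory is on `PYTHONPATH`, matching how the rest of the code imports `sgm`:

```python
import torch
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(16, 16)                             # stand-in module
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)   # base_lr of 1.0

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000,      # linear ramp from lr_start up to lr_max
    lr_min=1e-6,              # floor reached at the end of the cosine decay
    lr_max=1e-4,
    lr_start=1e-7,
    max_decay_steps=100_000,  # cosine decay finishes here
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(10):        # a real training loop would do forward/backward first
    optimizer.step()
    lr_scheduler.step()
```
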
/flashvideo/sgm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .autoencoder import AutoencodingEngine
2 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoders.modules import GeneralConditioner
2 |
3 | UNCONDITIONAL_CONFIG = {
4 | 'target': 'sgm.modules.GeneralConditioner',
5 | 'params': {
6 | 'emb_models': []
7 | },
8 | }
9 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/losses/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | 'GeneralLPIPSWithDiscriminator',
3 | 'LatentLPIPS',
4 | ]
5 |
6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator
7 | from .lpips import LatentLPIPS
8 | from .video_loss import VideoAutoencoderLoss
9 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/losses/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ....util import default, instantiate_from_config
5 | from ..lpips.loss.lpips import LPIPS
6 |
7 |
8 | class LatentLPIPS(nn.Module):
9 |
10 | def __init__(
11 | self,
12 | decoder_config,
13 | perceptual_weight=1.0,
14 | latent_weight=1.0,
15 | scale_input_to_tgt_size=False,
16 | scale_tgt_to_input_size=False,
17 | perceptual_weight_on_inputs=0.0,
18 | ):
19 | super().__init__()
20 | self.scale_input_to_tgt_size = scale_input_to_tgt_size
21 | self.scale_tgt_to_input_size = scale_tgt_to_input_size
22 | self.init_decoder(decoder_config)
23 | self.perceptual_loss = LPIPS().eval()
24 | self.perceptual_weight = perceptual_weight
25 | self.latent_weight = latent_weight
26 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
27 |
28 | def init_decoder(self, config):
29 | self.decoder = instantiate_from_config(config)
30 | if hasattr(self.decoder, 'encoder'):
31 | del self.decoder.encoder
32 |
33 | def forward(self,
34 | latent_inputs,
35 | latent_predictions,
36 | image_inputs,
37 | split='train'):
38 | log = dict()
39 | loss = (latent_inputs - latent_predictions)**2
40 | log[f'{split}/latent_l2_loss'] = loss.mean().detach()
41 | image_reconstructions = None
42 | if self.perceptual_weight > 0.0:
43 | image_reconstructions = self.decoder.decode(latent_predictions)
44 | image_targets = self.decoder.decode(latent_inputs)
45 | perceptual_loss = self.perceptual_loss(
46 | image_targets.contiguous(), image_reconstructions.contiguous())
47 | loss = self.latent_weight * loss.mean(
48 | ) + self.perceptual_weight * perceptual_loss.mean()
49 | log[f'{split}/perceptual_loss'] = perceptual_loss.mean().detach()
50 |
51 | if self.perceptual_weight_on_inputs > 0.0:
52 | image_reconstructions = default(
53 | image_reconstructions, self.decoder.decode(latent_predictions))
54 | if self.scale_input_to_tgt_size:
55 | image_inputs = torch.nn.functional.interpolate(
56 | image_inputs,
57 | image_reconstructions.shape[2:],
58 | mode='bicubic',
59 | antialias=True,
60 | )
61 | elif self.scale_tgt_to_input_size:
62 | image_reconstructions = torch.nn.functional.interpolate(
63 | image_reconstructions,
64 | image_inputs.shape[2:],
65 | mode='bicubic',
66 | antialias=True,
67 | )
68 |
69 | perceptual_loss2 = self.perceptual_loss(
70 | image_inputs.contiguous(), image_reconstructions.contiguous())
71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean(
72 | )
73 | log[f'{split}/perceptual_loss_on_inputs'] = perceptual_loss2.mean(
74 | ).detach()
75 | return loss, log
76 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/loss/.gitignore:
--------------------------------------------------------------------------------
1 | vgg.pth
2 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/loss/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/loss/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/loss/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from ..util import get_ckpt_path
10 |
11 |
12 | class LPIPS(nn.Module):
13 | # Learned perceptual metric
14 | def __init__(self, use_dropout=True):
15 | super().__init__()
16 | self.scaling_layer = ScalingLayer()
17 | self.chns = [64, 128, 256, 512, 512] # vgg16 features
18 | self.net = vgg16(pretrained=True, requires_grad=False)
19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24 | self.load_from_pretrained()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def load_from_pretrained(self, name='vgg_lpips'):
29 | ckpt = get_ckpt_path(name, 'sgm/modules/autoencoding/lpips/loss')
30 | self.load_state_dict(torch.load(ckpt,
31 | map_location=torch.device('cpu')),
32 | strict=False)
33 | print(f'loaded pretrained LPIPS loss from {ckpt}')
34 |
35 | @classmethod
36 | def from_pretrained(cls, name='vgg_lpips'):
37 | if name != 'vgg_lpips':
38 | raise NotImplementedError
39 | model = cls()
40 | ckpt = get_ckpt_path(name)
41 | model.load_state_dict(torch.load(ckpt,
42 | map_location=torch.device('cpu')),
43 | strict=False)
44 | return model
45 |
46 | def forward(self, input, target):
47 | in0_input, in1_input = (self.scaling_layer(input),
48 | self.scaling_layer(target))
49 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
50 | feats0, feats1, diffs = {}, {}, {}
51 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
52 | for kk in range(len(self.chns)):
53 | feats0[kk], feats1[kk] = normalize_tensor(
54 | outs0[kk]), normalize_tensor(outs1[kk])
55 | diffs[kk] = (feats0[kk] - feats1[kk])**2
56 |
57 | res = [
58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59 | for kk in range(len(self.chns))
60 | ]
61 | val = res[0]
62 | for l in range(1, len(self.chns)):
63 | val += res[l]
64 | return val
65 |
66 |
67 | class ScalingLayer(nn.Module):
68 |
69 | def __init__(self):
70 | super().__init__()
71 | self.register_buffer(
72 | 'shift',
73 | torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
74 | self.register_buffer(
75 | 'scale',
76 | torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
77 |
78 | def forward(self, inp):
79 | return (inp - self.shift) / self.scale
80 |
81 |
82 | class NetLinLayer(nn.Module):
83 | """A single linear layer which does a 1x1 conv"""
84 |
85 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
86 | super().__init__()
87 | layers = ([
88 | nn.Dropout(),
89 | ] if (use_dropout) else [])
90 | layers += [
91 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
92 | ]
93 | self.model = nn.Sequential(*layers)
94 |
95 |
96 | class vgg16(torch.nn.Module):
97 |
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super().__init__()
100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple(
133 | 'VggOutputs',
134 | ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
135 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3,
136 | h_relu5_3)
137 | return out
138 |
139 |
140 | def normalize_tensor(x, eps=1e-10):
141 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
142 | return x / (norm_factor + eps)
143 |
144 |
145 | def spatial_average(x, keepdim=True):
146 | return x.mean([2, 3], keepdim=keepdim)
147 |
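A usage sketch (not part of the repository). It assumes the flashvideo directory is on PYTHONPATH and is the working directory so the relative checkpoint path resolves; the constructor downloads the VGG backbone and the LPIPS linear weights (vgg.pth) on first use.

import torch

from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                   # downloads weights on first call
a = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs roughly in [-1, 1]
b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(a, b)                   # shape (1, 1, 1, 1); lower = more similar
print(dist.item())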
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/model/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
59 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/model/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/model/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 | from ..util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | try:
12 | nn.init.normal_(m.weight.data, 0.0, 0.02)
13 | except AttributeError:
14 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02)
15 | elif classname.find('BatchNorm') != -1:
16 | nn.init.normal_(m.weight.data, 1.0, 0.02)
17 | nn.init.constant_(m.bias.data, 0)
18 |
19 |
20 | class NLayerDiscriminator(nn.Module):
21 | """Defines a PatchGAN discriminator as in Pix2Pix
22 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
23 | """
24 |
25 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
26 | """Construct a PatchGAN discriminator
27 | Parameters:
28 | input_nc (int) -- the number of channels in input images
29 | ndf (int) -- the number of filters in the last conv layer
30 | n_layers (int) -- the number of conv layers in the discriminator
31 | norm_layer -- normalization layer
32 | """
33 | super().__init__()
34 | if not use_actnorm:
35 | norm_layer = nn.BatchNorm2d
36 | else:
37 | norm_layer = ActNorm
38 | if type(
39 | norm_layer
40 | ) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
41 | use_bias = norm_layer.func != nn.BatchNorm2d
42 | else:
43 | use_bias = norm_layer != nn.BatchNorm2d
44 |
45 | kw = 4
46 | padw = 1
47 | sequence = [
48 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
49 | nn.LeakyReLU(0.2, True),
50 | ]
51 | nf_mult = 1
52 | nf_mult_prev = 1
53 | for n in range(1,
54 | n_layers): # gradually increase the number of filters
55 | nf_mult_prev = nf_mult
56 | nf_mult = min(2**n, 8)
57 | sequence += [
58 | nn.Conv2d(
59 | ndf * nf_mult_prev,
60 | ndf * nf_mult,
61 | kernel_size=kw,
62 | stride=2,
63 | padding=padw,
64 | bias=use_bias,
65 | ),
66 | norm_layer(ndf * nf_mult),
67 | nn.LeakyReLU(0.2, True),
68 | ]
69 |
70 | nf_mult_prev = nf_mult
71 | nf_mult = min(2**n_layers, 8)
72 | sequence += [
73 | nn.Conv2d(
74 | ndf * nf_mult_prev,
75 | ndf * nf_mult,
76 | kernel_size=kw,
77 | stride=1,
78 | padding=padw,
79 | bias=use_bias,
80 | ),
81 | norm_layer(ndf * nf_mult),
82 | nn.LeakyReLU(0.2, True),
83 | ]
84 |
85 | sequence += [
86 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
87 | ] # output 1 channel prediction map
88 | self.main = nn.Sequential(*sequence)
89 |
90 | def forward(self, input):
91 | """Standard forward."""
92 | return self.main(input)
93 |
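A brief usage sketch (not from the repository), assuming sgm and its dependencies are importable: build the PatchGAN discriminator with its default configuration, apply the weight init above, and inspect the patch-wise logits.

import torch

from sgm.modules.autoencoding.lpips.model.model import (NLayerDiscriminator,
                                                         weights_init)

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
logits = disc(torch.randn(2, 3, 256, 256))
print(logits.shape)   # patch-wise real/fake logits, e.g. torch.Size([2, 1, 30, 30])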
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 |
4 | import requests
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | URL_MAP = {
10 | 'vgg_lpips':
11 | 'https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1'
12 | }
13 |
14 | CKPT_MAP = {'vgg_lpips': 'vgg.pth'}
15 |
16 | MD5_MAP = {'vgg_lpips': 'd507d7349b931f0638a25a48a722f98a'}
17 |
18 |
19 | def download(url, local_path, chunk_size=1024):
20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
21 | with requests.get(url, stream=True) as r:
22 | total_size = int(r.headers.get('content-length', 0))
23 | with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
24 | with open(local_path, 'wb') as f:
25 | for data in r.iter_content(chunk_size=chunk_size):
26 | if data:
27 | f.write(data)
28 | pbar.update(len(data))
29 |
30 |
31 | def md5_hash(path):
32 | with open(path, 'rb') as f:
33 | content = f.read()
34 | return hashlib.md5(content).hexdigest()
35 |
36 |
37 | def get_ckpt_path(name, root, check=False):
38 | assert name in URL_MAP
39 | path = os.path.join(root, CKPT_MAP[name])
40 | if not os.path.exists(path) or (check
41 | and not md5_hash(path) == MD5_MAP[name]):
42 | print('Downloading {} model from {} to {}'.format(
43 | name, URL_MAP[name], path))
44 | download(URL_MAP[name], path)
45 | md5 = md5_hash(path)
46 | assert md5 == MD5_MAP[name], md5
47 | return path
48 |
49 |
50 | class ActNorm(nn.Module):
51 |
52 | def __init__(self,
53 | num_features,
54 | logdet=False,
55 | affine=True,
56 | allow_reverse_init=False):
57 | assert affine
58 | super().__init__()
59 | self.logdet = logdet
60 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
61 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
62 | self.allow_reverse_init = allow_reverse_init
63 |
64 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
65 |
66 | def initialize(self, input):
67 | with torch.no_grad():
68 | flatten = input.permute(1, 0, 2,
69 | 3).contiguous().view(input.shape[1], -1)
70 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(
71 | 3).permute(1, 0, 2, 3)
72 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(
73 | 3).permute(1, 0, 2, 3)
74 |
75 | self.loc.data.copy_(-mean)
76 | self.scale.data.copy_(1 / (std + 1e-6))
77 |
78 | def forward(self, input, reverse=False):
79 | if reverse:
80 | return self.reverse(input)
81 | if len(input.shape) == 2:
82 | input = input[:, :, None, None]
83 | squeeze = True
84 | else:
85 | squeeze = False
86 |
87 | _, _, height, width = input.shape
88 |
89 | if self.training and self.initialized.item() == 0:
90 | self.initialize(input)
91 | self.initialized.fill_(1)
92 |
93 | h = self.scale * (input + self.loc)
94 |
95 | if squeeze:
96 | h = h.squeeze(-1).squeeze(-1)
97 |
98 | if self.logdet:
99 | log_abs = torch.log(torch.abs(self.scale))
100 | logdet = height * width * torch.sum(log_abs)
101 | logdet = logdet * torch.ones(input.shape[0]).to(input)
102 | return h, logdet
103 |
104 | return h
105 |
106 | def reverse(self, output):
107 | if self.training and self.initialized.item() == 0:
108 | if not self.allow_reverse_init:
109 | raise RuntimeError(
110 | 'Initializing ActNorm in reverse direction is '
111 | 'disabled by default. Use allow_reverse_init=True to enable.'
112 | )
113 | else:
114 | self.initialize(output)
115 | self.initialized.fill_(1)
116 |
117 | if len(output.shape) == 2:
118 | output = output[:, :, None, None]
119 | squeeze = True
120 | else:
121 | squeeze = False
122 |
123 | h = output / self.scale - self.loc
124 |
125 | if squeeze:
126 | h = h.squeeze(-1).squeeze(-1)
127 | return h
128 |
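A small sketch (not from the repository) of ActNorm's data-dependent initialization: the first forward pass in training mode sets scale and shift from the batch statistics, so the output is approximately zero-mean and unit-variance, and reverse=True undoes the transform.

import torch

from sgm.modules.autoencoding.lpips.util import ActNorm

norm = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3 + 1
y = norm(x)                               # data-dependent init on the first training call
print(y.mean().item(), y.std().item())    # roughly 0 and 1
x_rec = norm(y, reverse=True)             # inverse of the forward transform
print(torch.allclose(x, x_rec, atol=1e-4))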
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/lpips/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def hinge_d_loss(logits_real, logits_fake):
6 | loss_real = torch.mean(F.relu(1.0 - logits_real))
7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8 | d_loss = 0.5 * (loss_real + loss_fake)
9 | return d_loss
10 |
11 |
12 | def vanilla_d_loss(logits_real, logits_fake):
13 | d_loss = 0.5 * (torch.mean(torch.nn.functional.softplus(-logits_real)) +
14 | torch.mean(torch.nn.functional.softplus(logits_fake)))
15 | return d_loss
16 |
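A tiny sanity check (not from the repository) for the discriminator losses: confidently correct logits drive the hinge loss to zero, while reversed logits are penalized.

import torch

from sgm.modules.autoencoding.lpips.vqperceptual import (hinge_d_loss,
                                                          vanilla_d_loss)

good = hinge_d_loss(logits_real=torch.full((4,), 3.0),
                    logits_fake=torch.full((4,), -3.0))
bad = hinge_d_loss(logits_real=torch.full((4,), -3.0),
                   logits_fake=torch.full((4,), 3.0))
print(good.item(), bad.item())   # 0.0 and 4.0
print(vanilla_d_loss(torch.full((4,), 3.0), torch.full((4,), -3.0)).item())  # close to 0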
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/regularizers/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from ....modules.distributions.distributions import \
9 | DiagonalGaussianDistribution
10 | from .base import AbstractRegularizer
11 |
12 |
13 | class DiagonalGaussianRegularizer(AbstractRegularizer):
14 |
15 | def __init__(self, sample: bool = True):
16 | super().__init__()
17 | self.sample = sample
18 |
19 | def get_trainable_parameters(self) -> Any:
20 | yield from ()
21 |
22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23 | log = dict()
24 | posterior = DiagonalGaussianDistribution(z)
25 | if self.sample:
26 | z = posterior.sample()
27 | else:
28 | z = posterior.mode()
29 | kl_loss = posterior.kl()
30 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
31 | log['kl_loss'] = kl_loss
32 | return z, log
33 |
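A usage sketch (not from the repository), assuming sgm is importable: the regularizer expects the channel dimension to carry mean and log-variance stacked together (2*C channels), returns a sampled latent, and logs the KL term.

import torch

from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 2 * 4, 8, 8)   # mean and logvar for a 4-channel latent
z, log = reg(moments)
print(z.shape)                          # torch.Size([2, 4, 8, 8])
print(log['kl_loss'])                   # KL term averaged over the batch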
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/regularizers/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 |
9 | class AbstractRegularizer(nn.Module):
10 |
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
15 | raise NotImplementedError()
16 |
17 | @abstractmethod
18 | def get_trainable_parameters(self) -> Any:
19 | raise NotImplementedError()
20 |
21 |
22 | class IdentityRegularizer(AbstractRegularizer):
23 |
24 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
25 | return z, dict()
26 |
27 | def get_trainable_parameters(self) -> Any:
28 | yield from ()
29 |
30 |
31 | def measure_perplexity(
32 | predicted_indices: torch.Tensor,
33 | num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
34 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
35 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
36 | encodings = F.one_hot(predicted_indices,
37 | num_centroids).float().reshape(-1, num_centroids)
38 | avg_probs = encodings.mean(0)
39 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
40 | cluster_use = torch.sum(avg_probs > 0)
41 | return perplexity, cluster_use
42 |
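A quick illustration (not from the repository) of measure_perplexity: perplexity equals the number of centroids when codes are used uniformly and collapses toward 1 when a single code dominates.

import torch

from sgm.modules.autoencoding.regularizers.base import measure_perplexity

uniform = torch.arange(8).repeat(16)             # every centroid used equally often
print(measure_perplexity(uniform, num_centroids=8))    # (~8.0, 8)

collapsed = torch.zeros(128, dtype=torch.long)   # only centroid 0 is ever used
print(measure_perplexity(collapsed, num_centroids=8))  # (~1.0, 1)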
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py:
--------------------------------------------------------------------------------
1 | """
2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3 | Code adapted from Jax version in Appendix A.1
4 | """
5 |
6 | from typing import List, Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | from einops import pack, rearrange, unpack
11 | from torch import Tensor, int32
12 | from torch.cuda.amp import autocast
13 | from torch.nn import Module
14 |
15 | # helper functions
16 |
17 |
18 | def exists(v):
19 | return v is not None
20 |
21 |
22 | def default(*args):
23 | for arg in args:
24 | if exists(arg):
25 | return arg
26 | return None
27 |
28 |
29 | def pack_one(t, pattern):
30 | return pack([t], pattern)
31 |
32 |
33 | def unpack_one(t, ps, pattern):
34 | return unpack(t, ps, pattern)[0]
35 |
36 |
37 | # tensor helpers
38 |
39 |
40 | def round_ste(z: Tensor) -> Tensor:
41 | """Round with straight through gradients."""
42 | zhat = z.round()
43 | return z + (zhat - z).detach()
44 |
45 |
46 | # main class
47 |
48 |
49 | class FSQ(Module):
50 |
51 | def __init__(
52 | self,
53 | levels: List[int],
54 | dim: Optional[int] = None,
55 | num_codebooks=1,
56 | keep_num_codebooks_dim: Optional[bool] = None,
57 | scale: Optional[float] = None,
58 | ):
59 | super().__init__()
60 | _levels = torch.tensor(levels, dtype=int32)
61 | self.register_buffer('_levels', _levels, persistent=False)
62 |
63 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]),
64 | dim=0,
65 | dtype=int32)
66 | self.register_buffer('_basis', _basis, persistent=False)
67 |
68 | self.scale = scale
69 |
70 | codebook_dim = len(levels)
71 | self.codebook_dim = codebook_dim
72 |
73 | effective_codebook_dim = codebook_dim * num_codebooks
74 | self.num_codebooks = num_codebooks
75 | self.effective_codebook_dim = effective_codebook_dim
76 |
77 | keep_num_codebooks_dim = default(keep_num_codebooks_dim,
78 | num_codebooks > 1)
79 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
80 | self.keep_num_codebooks_dim = keep_num_codebooks_dim
81 |
82 | self.dim = default(dim, len(_levels) * num_codebooks)
83 |
84 | has_projections = self.dim != effective_codebook_dim
85 | self.project_in = nn.Linear(
86 | self.dim,
87 | effective_codebook_dim) if has_projections else nn.Identity()
88 | self.project_out = nn.Linear(
89 | effective_codebook_dim,
90 | self.dim) if has_projections else nn.Identity()
91 | self.has_projections = has_projections
92 |
93 | self.codebook_size = self._levels.prod().item()
94 |
95 | implicit_codebook = self.indices_to_codes(torch.arange(
96 | self.codebook_size),
97 | project_out=False)
98 | self.register_buffer('implicit_codebook',
99 | implicit_codebook,
100 | persistent=False)
101 |
102 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
103 | """Bound `z`, an array of shape (..., d)."""
104 | half_l = (self._levels - 1) * (1 + eps) / 2
105 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
106 | shift = (offset / half_l).atanh()
107 | return (z + shift).tanh() * half_l - offset
108 |
109 | def quantize(self, z: Tensor) -> Tensor:
110 | """Quantizes z, returns quantized zhat, same shape as z."""
111 | quantized = round_ste(self.bound(z))
112 | half_width = self._levels // 2 # Renormalize to [-1, 1].
113 | return quantized / half_width
114 |
115 | def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
116 | half_width = self._levels // 2
117 | return (zhat_normalized * half_width) + half_width
118 |
119 | def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
120 | half_width = self._levels // 2
121 | return (zhat - half_width) / half_width
122 |
123 | def codes_to_indices(self, zhat: Tensor) -> Tensor:
124 | """Converts a `code` to an index in the codebook."""
125 | assert zhat.shape[-1] == self.codebook_dim
126 | zhat = self._scale_and_shift(zhat)
127 | return (zhat * self._basis).sum(dim=-1).to(int32)
128 |
129 | def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
130 | """Inverse of `codes_to_indices`."""
131 |
132 | is_img_or_video = indices.ndim >= (3 +
133 | int(self.keep_num_codebooks_dim))
134 |
135 | indices = rearrange(indices, '... -> ... 1')
136 | codes_non_centered = (indices // self._basis) % self._levels
137 | codes = self._scale_and_shift_inverse(codes_non_centered)
138 |
139 | if self.keep_num_codebooks_dim:
140 | codes = rearrange(codes, '... c d -> ... (c d)')
141 |
142 | if project_out:
143 | codes = self.project_out(codes)
144 |
145 | if is_img_or_video:
146 | codes = rearrange(codes, 'b ... d -> b d ...')
147 |
148 | return codes
149 |
150 | @autocast(enabled=False)
151 | def forward(self, z: Tensor) -> Tensor:
152 | """
153 | einstein notation
154 | b - batch
155 | n - sequence (or flattened spatial dimensions)
156 | d - feature dimension
157 | c - number of codebook dim
158 | """
159 |
160 | is_img_or_video = z.ndim >= 4
161 |
162 | # standardize image or video into (batch, seq, dimension)
163 |
164 | if is_img_or_video:
165 | z = rearrange(z, 'b d ... -> b ... d')
166 | z, ps = pack_one(z, 'b * d')
167 |
168 | assert z.shape[
169 | -1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
170 |
171 | z = self.project_in(z)
172 |
173 | z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)
174 |
175 | codes = self.quantize(z)
176 | indices = self.codes_to_indices(codes)
177 |
178 | codes = rearrange(codes, 'b n c d -> b n (c d)')
179 |
180 | out = self.project_out(codes)
181 |
182 | # reconstitute image or video dimensions
183 |
184 | if is_img_or_video:
185 | out = unpack_one(out, ps, 'b * d')
186 | out = rearrange(out, 'b ... d -> b d ...')
187 |
188 | indices = unpack_one(indices, ps, 'b * c')
189 |
190 | if not self.keep_num_codebooks_dim:
191 | indices = rearrange(indices, '... 1 -> ...')
192 |
193 | return out, indices
194 |
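A usage sketch (not from the repository) with the [8, 5, 5, 5] level configuration suggested in the FSQ paper (an implicit codebook of 8*5*5*5 = 1000 codes); image-shaped inputs are flattened, quantized per dimension, and mapped to integer indices.

import torch

from sgm.modules.autoencoding.regularizers.finite_scalar_quantization import FSQ

quantizer = FSQ(levels=[8, 5, 5, 5])         # codebook_dim = 4, codebook_size = 1000
x = torch.randn(2, 4, 16, 16)                # (batch, dim, height, width)
codes, indices = quantizer(x)
print(codes.shape, indices.shape)            # (2, 4, 16, 16) and (2, 16, 16)
recon = quantizer.indices_to_codes(indices)  # recovers the same quantized codes
print(torch.allclose(codes, recon, atol=1e-5))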
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py:
--------------------------------------------------------------------------------
1 | """
2 | Lookup Free Quantization
3 | Proposed in https://arxiv.org/abs/2310.05737
4 |
5 | In the simplest setup, each dimension is quantized into {-1, 1}.
6 | An entropy penalty is used to encourage utilization.
7 | """
8 |
9 | from collections import namedtuple
10 | from math import ceil, log2
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from einops import pack, rearrange, reduce, unpack
15 | from torch import einsum, nn
16 | from torch.cuda.amp import autocast
17 | from torch.nn import Module
18 |
19 | # constants
20 |
21 | Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
22 |
23 | LossBreakdown = namedtuple(
24 | 'LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
25 |
26 | # helper functions
27 |
28 |
29 | def exists(v):
30 | return v is not None
31 |
32 |
33 | def default(*args):
34 | for arg in args:
35 | if exists(arg):
36 | return arg() if callable(arg) else arg
37 | return None
38 |
39 |
40 | def pack_one(t, pattern):
41 | return pack([t], pattern)
42 |
43 |
44 | def unpack_one(t, ps, pattern):
45 | return unpack(t, ps, pattern)[0]
46 |
47 |
48 | # entropy
49 |
50 |
51 | def log(t, eps=1e-5):
52 | return t.clamp(min=eps).log()
53 |
54 |
55 | def entropy(prob):
56 | return (-prob * log(prob)).sum(dim=-1)
57 |
58 |
59 | # class
60 |
61 |
62 | class LFQ(Module):
63 |
64 | def __init__(
65 | self,
66 | *,
67 | dim=None,
68 | codebook_size=None,
69 | entropy_loss_weight=0.1,
70 | commitment_loss_weight=0.25,
71 | diversity_gamma=1.0,
72 | straight_through_activation=nn.Identity(),
73 | num_codebooks=1,
74 | keep_num_codebooks_dim=None,
75 | codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
76 | frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy
77 | ):
78 | super().__init__()
79 |
80 | # some assert validations
81 |
82 | assert exists(dim) or exists(
83 | codebook_size
84 | ), 'either dim or codebook_size must be specified for LFQ'
85 | assert (
86 | not exists(codebook_size) or log2(codebook_size).is_integer()
87 | ), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
88 |
89 | codebook_size = default(codebook_size, lambda: 2**dim)
90 | codebook_dim = int(log2(codebook_size))
91 |
92 | codebook_dims = codebook_dim * num_codebooks
93 | dim = default(dim, codebook_dims)
94 |
95 | has_projections = dim != codebook_dims
96 | self.project_in = nn.Linear(
97 | dim, codebook_dims) if has_projections else nn.Identity()
98 | self.project_out = nn.Linear(
99 | codebook_dims, dim) if has_projections else nn.Identity()
100 | self.has_projections = has_projections
101 |
102 | self.dim = dim
103 | self.codebook_dim = codebook_dim
104 | self.num_codebooks = num_codebooks
105 |
106 | keep_num_codebooks_dim = default(keep_num_codebooks_dim,
107 | num_codebooks > 1)
108 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
109 | self.keep_num_codebooks_dim = keep_num_codebooks_dim
110 |
111 | # straight through activation
112 |
113 | self.activation = straight_through_activation
114 |
115 | # entropy aux loss related weights
116 |
117 | assert 0 < frac_per_sample_entropy <= 1.0
118 | self.frac_per_sample_entropy = frac_per_sample_entropy
119 |
120 | self.diversity_gamma = diversity_gamma
121 | self.entropy_loss_weight = entropy_loss_weight
122 |
123 | # codebook scale
124 |
125 | self.codebook_scale = codebook_scale
126 |
127 | # commitment loss
128 |
129 | self.commitment_loss_weight = commitment_loss_weight
130 |
131 | # for no auxiliary loss, during inference
132 |
133 | self.register_buffer('mask', 2**torch.arange(codebook_dim - 1, -1, -1))
134 | self.register_buffer('zero', torch.tensor(0.0), persistent=False)
135 |
136 | # codes
137 |
138 | all_codes = torch.arange(codebook_size)
139 | bits = ((all_codes[..., None].int() & self.mask) != 0).float()
140 | codebook = self.bits_to_codes(bits)
141 |
142 | self.register_buffer('codebook', codebook, persistent=False)
143 |
144 | def bits_to_codes(self, bits):
145 | return bits * self.codebook_scale * 2 - self.codebook_scale
146 |
147 | @property
148 | def dtype(self):
149 | return self.codebook.dtype
150 |
151 | def indices_to_codes(self, indices, project_out=True):
152 | is_img_or_video = indices.ndim >= (3 +
153 | int(self.keep_num_codebooks_dim))
154 |
155 | if not self.keep_num_codebooks_dim:
156 | indices = rearrange(indices, '... -> ... 1')
157 |
158 | # indices to codes, which are bits of either -1 or 1
159 |
160 | bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
161 |
162 | codes = self.bits_to_codes(bits)
163 |
164 | codes = rearrange(codes, '... c d -> ... (c d)')
165 |
166 | # whether to project codes out to original dimensions
167 | # if the input feature dimensions were not log2(codebook size)
168 |
169 | if project_out:
170 | codes = self.project_out(codes)
171 |
172 | # rearrange codes back to original shape
173 |
174 | if is_img_or_video:
175 | codes = rearrange(codes, 'b ... d -> b d ...')
176 |
177 | return codes
178 |
179 | @autocast(enabled=False)
180 | def forward(
181 | self,
182 | x,
183 | inv_temperature=100.0,
184 | return_loss_breakdown=False,
185 | mask=None,
186 | ):
187 | """
188 | einstein notation
189 | b - batch
190 | n - sequence (or flattened spatial dimensions)
191 | d - feature dimension, which is also log2(codebook size)
192 | c - number of codebook dim
193 | """
194 |
195 | x = x.float()
196 |
197 | is_img_or_video = x.ndim >= 4
198 |
199 | # standardize image or video into (batch, seq, dimension)
200 |
201 | if is_img_or_video:
202 | x = rearrange(x, 'b d ... -> b ... d')
203 | x, ps = pack_one(x, 'b * d')
204 |
205 | assert x.shape[
206 | -1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
207 |
208 | x = self.project_in(x)
209 |
210 | # split out number of codebooks
211 |
212 | x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)
213 |
214 | # quantize by eq 3.
215 |
216 | original_input = x
217 |
218 | codebook_value = torch.ones_like(x) * self.codebook_scale
219 | quantized = torch.where(x > 0, codebook_value, -codebook_value)
220 |
221 | # use straight-through gradients (optionally with custom activation fn) if training
222 |
223 | if self.training:
224 | x = self.activation(x)
225 | x = x + (quantized - x).detach()
226 | else:
227 | x = quantized
228 |
229 | # calculate indices
230 |
231 | indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c',
232 | 'sum')
233 |
234 | # entropy aux loss
235 |
236 | if self.training:
237 | # the same as euclidean distance up to a constant
238 | distance = -2 * einsum('... i d, j d -> ... i j', original_input,
239 | self.codebook)
240 |
241 | prob = (-distance * inv_temperature).softmax(dim=-1)
242 |
243 | # account for mask
244 |
245 | if exists(mask):
246 | prob = prob[mask]
247 | else:
248 | prob = rearrange(prob, 'b n ... -> (b n) ...')
249 |
250 | # whether to only use a fraction of probs, for reducing memory
251 |
252 | if self.frac_per_sample_entropy < 1.0:
253 | num_tokens = prob.shape[0]
254 | num_sampled_tokens = int(num_tokens *
255 | self.frac_per_sample_entropy)
256 | rand_mask = torch.randn(num_tokens).argsort(
257 | dim=-1) < num_sampled_tokens
258 | per_sample_probs = prob[rand_mask]
259 | else:
260 | per_sample_probs = prob
261 |
262 | # calculate per sample entropy
263 |
264 | per_sample_entropy = entropy(per_sample_probs).mean()
265 |
266 | # distribution over all available tokens in the batch
267 |
268 | avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
269 | codebook_entropy = entropy(avg_prob).mean()
270 |
271 | # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
272 | # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
273 |
274 | entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
275 | else:
276 | # if not training, just return dummy 0
277 | entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
278 |
279 | # commit loss
280 |
281 | if self.training:
282 | commit_loss = F.mse_loss(original_input,
283 | quantized.detach(),
284 | reduction='none')
285 |
286 | if exists(mask):
287 | commit_loss = commit_loss[mask]
288 |
289 | commit_loss = commit_loss.mean()
290 | else:
291 | commit_loss = self.zero
292 |
293 | # merge back codebook dim
294 |
295 | x = rearrange(x, 'b n c d -> b n (c d)')
296 |
297 | # project out to feature dimension if needed
298 |
299 | x = self.project_out(x)
300 |
301 | # reconstitute image or video dimensions
302 |
303 | if is_img_or_video:
304 | x = unpack_one(x, ps, 'b * d')
305 | x = rearrange(x, 'b ... d -> b d ...')
306 |
307 | indices = unpack_one(indices, ps, 'b * c')
308 |
309 | # whether to remove single codebook dim
310 |
311 | if not self.keep_num_codebooks_dim:
312 | indices = rearrange(indices, '... 1 -> ...')
313 |
314 | # complete aux loss
315 |
316 | aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
317 |
318 | ret = Return(x, indices, aux_loss)
319 |
320 | if not return_loss_breakdown:
321 | return ret
322 |
323 | return ret, LossBreakdown(per_sample_entropy, codebook_entropy,
324 | commit_loss)
325 |
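A usage sketch (not from the repository): an LFQ layer with a 2^10 codebook, projecting a 16-dim feature to 10 bits and back. Each dimension is quantized to ±codebook_scale and the sign pattern gives the code index; the entropy and commitment terms are only non-zero in training mode.

import torch

from sgm.modules.autoencoding.regularizers.lookup_free_quantization import LFQ

quantizer = LFQ(codebook_size=1024, dim=16)   # 10 bits per code, projections 16 <-> 10
x = torch.randn(2, 16, 8, 8)                  # (batch, dim, height, width)
quantized, indices, aux_loss = quantizer(x)
print(quantized.shape, indices.shape)         # (2, 16, 8, 8) and (2, 8, 8)
print(aux_loss)                               # entropy + commitment auxiliary loss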
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/temporal_ae.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Iterable, Union
2 |
3 | import torch
4 | from einops import rearrange, repeat
5 | from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
6 | AttnBlock, Decoder,
7 | MemoryEfficientAttnBlock,
8 | ResnetBlock)
9 | from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
10 | timestep_embedding)
11 | from sgm.modules.video_attention import VideoTransformerBlock
12 | from sgm.util import partialclass
13 |
14 |
15 | class VideoResBlock(ResnetBlock):
16 |
17 | def __init__(
18 | self,
19 | out_channels,
20 | *args,
21 | dropout=0.0,
22 | video_kernel_size=3,
23 | alpha=0.0,
24 | merge_strategy='learned',
25 | **kwargs,
26 | ):
27 | super().__init__(out_channels=out_channels,
28 | dropout=dropout,
29 | *args,
30 | **kwargs)
31 | if video_kernel_size is None:
32 | video_kernel_size = [3, 1, 1]
33 | self.time_stack = ResBlock(
34 | channels=out_channels,
35 | emb_channels=0,
36 | dropout=dropout,
37 | dims=3,
38 | use_scale_shift_norm=False,
39 | use_conv=False,
40 | up=False,
41 | down=False,
42 | kernel_size=video_kernel_size,
43 | use_checkpoint=False,
44 | skip_t_emb=True,
45 | )
46 |
47 | self.merge_strategy = merge_strategy
48 | if self.merge_strategy == 'fixed':
49 | self.register_buffer('mix_factor', torch.Tensor([alpha]))
50 | elif self.merge_strategy == 'learned':
51 | self.register_parameter('mix_factor',
52 | torch.nn.Parameter(torch.Tensor([alpha])))
53 | else:
54 | raise ValueError(f'unknown merge strategy {self.merge_strategy}')
55 |
56 | def get_alpha(self, bs):
57 | if self.merge_strategy == 'fixed':
58 | return self.mix_factor
59 | elif self.merge_strategy == 'learned':
60 | return torch.sigmoid(self.mix_factor)
61 | else:
62 | raise NotImplementedError()
63 |
64 | def forward(self, x, temb, skip_video=False, timesteps=None):
65 | if timesteps is None:
66 | timesteps = self.timesteps
67 |
68 | b, c, h, w = x.shape
69 |
70 | x = super().forward(x, temb)
71 |
72 | if not skip_video:
73 | x_mix = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
74 |
75 | x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
76 |
77 | x = self.time_stack(x, temb)
78 |
79 | alpha = self.get_alpha(bs=b // timesteps)
80 | x = alpha * x + (1.0 - alpha) * x_mix
81 |
82 | x = rearrange(x, 'b c t h w -> (b t) c h w')
83 | return x
84 |
85 |
86 | class AE3DConv(torch.nn.Conv2d):
87 |
88 | def __init__(self,
89 | in_channels,
90 | out_channels,
91 | video_kernel_size=3,
92 | *args,
93 | **kwargs):
94 | super().__init__(in_channels, out_channels, *args, **kwargs)
95 | if isinstance(video_kernel_size, Iterable):
96 | padding = [int(k // 2) for k in video_kernel_size]
97 | else:
98 | padding = int(video_kernel_size // 2)
99 |
100 | self.time_mix_conv = torch.nn.Conv3d(
101 | in_channels=out_channels,
102 | out_channels=out_channels,
103 | kernel_size=video_kernel_size,
104 | padding=padding,
105 | )
106 |
107 | def forward(self, input, timesteps, skip_video=False):
108 | x = super().forward(input)
109 | if skip_video:
110 | return x
111 | x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps)
112 | x = self.time_mix_conv(x)
113 | return rearrange(x, 'b c t h w -> (b t) c h w')
114 |
115 |
116 | class VideoBlock(AttnBlock):
117 |
118 | def __init__(self,
119 | in_channels: int,
120 | alpha: float = 0,
121 | merge_strategy: str = 'learned'):
122 | super().__init__(in_channels)
123 | # no context, single headed, as in base class
124 | self.time_mix_block = VideoTransformerBlock(
125 | dim=in_channels,
126 | n_heads=1,
127 | d_head=in_channels,
128 | checkpoint=False,
129 | ff_in=True,
130 | attn_mode='softmax',
131 | )
132 |
133 | time_embed_dim = self.in_channels * 4
134 | self.video_time_embed = torch.nn.Sequential(
135 | torch.nn.Linear(self.in_channels, time_embed_dim),
136 | torch.nn.SiLU(),
137 | torch.nn.Linear(time_embed_dim, self.in_channels),
138 | )
139 |
140 | self.merge_strategy = merge_strategy
141 | if self.merge_strategy == 'fixed':
142 | self.register_buffer('mix_factor', torch.Tensor([alpha]))
143 | elif self.merge_strategy == 'learned':
144 | self.register_parameter('mix_factor',
145 | torch.nn.Parameter(torch.Tensor([alpha])))
146 | else:
147 | raise ValueError(f'unknown merge strategy {self.merge_strategy}')
148 |
149 | def forward(self, x, timesteps, skip_video=False):
150 | if skip_video:
151 | return super().forward(x)
152 |
153 | x_in = x
154 | x = self.attention(x)
155 | h, w = x.shape[2:]
156 | x = rearrange(x, 'b c h w -> b (h w) c')
157 |
158 | x_mix = x
159 | num_frames = torch.arange(timesteps, device=x.device)
160 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
161 | num_frames = rearrange(num_frames, 'b t -> (b t)')
162 | t_emb = timestep_embedding(num_frames,
163 | self.in_channels,
164 | repeat_only=False)
165 | emb = self.video_time_embed(t_emb) # b, n_channels
166 | emb = emb[:, None, :]
167 | x_mix = x_mix + emb
168 |
169 | alpha = self.get_alpha()
170 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
171 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
172 |
173 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
174 | x = self.proj_out(x)
175 |
176 | return x_in + x
177 |
178 | def get_alpha(self, ):
179 | if self.merge_strategy == 'fixed':
180 | return self.mix_factor
181 | elif self.merge_strategy == 'learned':
182 | return torch.sigmoid(self.mix_factor)
183 | else:
184 | raise NotImplementedError(
185 | f'unknown merge strategy {self.merge_strategy}')
186 |
187 |
188 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
189 |
190 | def __init__(self,
191 | in_channels: int,
192 | alpha: float = 0,
193 | merge_strategy: str = 'learned'):
194 | super().__init__(in_channels)
195 | # no context, single headed, as in base class
196 | self.time_mix_block = VideoTransformerBlock(
197 | dim=in_channels,
198 | n_heads=1,
199 | d_head=in_channels,
200 | checkpoint=False,
201 | ff_in=True,
202 | attn_mode='softmax-xformers',
203 | )
204 |
205 | time_embed_dim = self.in_channels * 4
206 | self.video_time_embed = torch.nn.Sequential(
207 | torch.nn.Linear(self.in_channels, time_embed_dim),
208 | torch.nn.SiLU(),
209 | torch.nn.Linear(time_embed_dim, self.in_channels),
210 | )
211 |
212 | self.merge_strategy = merge_strategy
213 | if self.merge_strategy == 'fixed':
214 | self.register_buffer('mix_factor', torch.Tensor([alpha]))
215 | elif self.merge_strategy == 'learned':
216 | self.register_parameter('mix_factor',
217 | torch.nn.Parameter(torch.Tensor([alpha])))
218 | else:
219 | raise ValueError(f'unknown merge strategy {self.merge_strategy}')
220 |
221 | def forward(self, x, timesteps, skip_time_block=False):
222 | if skip_time_block:
223 | return super().forward(x)
224 |
225 | x_in = x
226 | x = self.attention(x)
227 | h, w = x.shape[2:]
228 | x = rearrange(x, 'b c h w -> b (h w) c')
229 |
230 | x_mix = x
231 | num_frames = torch.arange(timesteps, device=x.device)
232 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
233 | num_frames = rearrange(num_frames, 'b t -> (b t)')
234 | t_emb = timestep_embedding(num_frames,
235 | self.in_channels,
236 | repeat_only=False)
237 | emb = self.video_time_embed(t_emb) # b, n_channels
238 | emb = emb[:, None, :]
239 | x_mix = x_mix + emb
240 |
241 | alpha = self.get_alpha()
242 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
243 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
244 |
245 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
246 | x = self.proj_out(x)
247 |
248 | return x_in + x
249 |
250 | def get_alpha(self, ):
251 | if self.merge_strategy == 'fixed':
252 | return self.mix_factor
253 | elif self.merge_strategy == 'learned':
254 | return torch.sigmoid(self.mix_factor)
255 | else:
256 | raise NotImplementedError(
257 | f'unknown merge strategy {self.merge_strategy}')
258 |
259 |
260 | def make_time_attn(
261 | in_channels,
262 | attn_type='vanilla',
263 | attn_kwargs=None,
264 | alpha: float = 0,
265 | merge_strategy: str = 'learned',
266 | ):
267 | assert attn_type in [
268 | 'vanilla',
269 | 'vanilla-xformers',
270 | ], f'attn_type {attn_type} not supported for spatio-temporal attention'
271 | print(
272 | f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
273 | )
274 | if not XFORMERS_IS_AVAILABLE and attn_type == 'vanilla-xformers':
275 | print(
276 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
277 | f'This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}'
278 | )
279 | attn_type = 'vanilla'
280 |
281 | if attn_type == 'vanilla':
282 | assert attn_kwargs is None
283 | return partialclass(VideoBlock,
284 | in_channels,
285 | alpha=alpha,
286 | merge_strategy=merge_strategy)
287 | elif attn_type == 'vanilla-xformers':
288 | print(
289 | f'building MemoryEfficientAttnBlock with {in_channels} in_channels...'
290 | )
291 | return partialclass(
292 | MemoryEfficientVideoBlock,
293 | in_channels,
294 | alpha=alpha,
295 | merge_strategy=merge_strategy,
296 | )
297 | else:
298 | raise NotImplementedError()
299 |
300 |
301 | class Conv2DWrapper(torch.nn.Conv2d):
302 |
303 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
304 | return super().forward(input)
305 |
306 |
307 | class VideoDecoder(Decoder):
308 | available_time_modes = ['all', 'conv-only', 'attn-only']
309 |
310 | def __init__(
311 | self,
312 | *args,
313 | video_kernel_size: Union[int, list] = 3,
314 | alpha: float = 0.0,
315 | merge_strategy: str = 'learned',
316 | time_mode: str = 'conv-only',
317 | **kwargs,
318 | ):
319 | self.video_kernel_size = video_kernel_size
320 | self.alpha = alpha
321 | self.merge_strategy = merge_strategy
322 | self.time_mode = time_mode
323 | assert (
324 | self.time_mode in self.available_time_modes
325 | ), f'time_mode parameter has to be in {self.available_time_modes}'
326 | super().__init__(*args, **kwargs)
327 |
328 | def get_last_layer(self, skip_time_mix=False, **kwargs):
329 | if self.time_mode == 'attn-only':
330 | raise NotImplementedError('TODO')
331 | else:
332 | return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
333 |
334 | def _make_attn(self) -> Callable:
335 | if self.time_mode not in ['conv-only', 'only-last-conv']:
336 | return partialclass(
337 | make_time_attn,
338 | alpha=self.alpha,
339 | merge_strategy=self.merge_strategy,
340 | )
341 | else:
342 | return super()._make_attn()
343 |
344 | def _make_conv(self) -> Callable:
345 | if self.time_mode != 'attn-only':
346 | return partialclass(AE3DConv,
347 | video_kernel_size=self.video_kernel_size)
348 | else:
349 | return Conv2DWrapper
350 |
351 | def _make_resblock(self) -> Callable:
352 | if self.time_mode not in ['attn-only', 'only-last-conv']:
353 | return partialclass(
354 | VideoResBlock,
355 | video_kernel_size=self.video_kernel_size,
356 | alpha=self.alpha,
357 | merge_strategy=self.merge_strategy,
358 | )
359 | else:
360 | return super()._make_resblock()
361 |
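The blocks above all share one mechanism: a scalar mix_factor (fixed, or learned and passed through a sigmoid) blends the per-frame spatial path with the temporal-mixing path. A stripped-down sketch of that merge (not from the repository; note that VideoResBlock weights the temporal branch by alpha, while the attention blocks weight the spatial branch by alpha):

import torch

mix_factor = torch.nn.Parameter(torch.Tensor([0.0]))   # 'learned' merge strategy
alpha = torch.sigmoid(mix_factor)                       # 0.5 when mix_factor == 0

x_spatial = torch.randn(4, 64, 32, 32)    # output of the per-frame (2D) path
x_temporal = torch.randn(4, 64, 32, 32)   # output of the temporal mixing path
merged = alpha * x_temporal + (1.0 - alpha) * x_spatial
print(merged.shape)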
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/autoencoding/vqvae/quantize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch import einsum
7 |
8 |
9 | class VectorQuantizer2(nn.Module):
10 | """
11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
13 | """
14 |
15 | # NOTE: due to a bug the beta term was applied to the wrong term. for
16 | # backwards compatibility we use the buggy version by default, but you can
17 | # specify legacy=False to fix it.
18 | def __init__(self,
19 | n_e,
20 | e_dim,
21 | beta,
22 | remap=None,
23 | unknown_index='random',
24 | sane_index_shape=False,
25 | legacy=True):
26 | super().__init__()
27 | self.n_e = n_e
28 | self.e_dim = e_dim
29 | self.beta = beta
30 | self.legacy = legacy
31 |
32 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
33 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
34 |
35 | self.remap = remap
36 | if self.remap is not None:
37 | self.register_buffer('used', torch.tensor(np.load(self.remap)))
38 | self.re_embed = self.used.shape[0]
39 | self.unknown_index = unknown_index # "random" or "extra" or integer
40 | if self.unknown_index == 'extra':
41 | self.unknown_index = self.re_embed
42 | self.re_embed = self.re_embed + 1
43 | print(f'Remapping {self.n_e} indices to {self.re_embed} indices. '
44 | f'Using {self.unknown_index} for unknown indices.')
45 | else:
46 | self.re_embed = n_e
47 |
48 | self.sane_index_shape = sane_index_shape
49 |
50 | def remap_to_used(self, inds):
51 | ishape = inds.shape
52 | assert len(ishape) > 1
53 | inds = inds.reshape(ishape[0], -1)
54 | used = self.used.to(inds)
55 | match = (inds[:, :, None] == used[None, None, ...]).long()
56 | new = match.argmax(-1)
57 | unknown = match.sum(2) < 1
58 | if self.unknown_index == 'random':
59 | new[unknown] = torch.randint(
60 | 0, self.re_embed,
61 | size=new[unknown].shape).to(device=new.device)
62 | else:
63 | new[unknown] = self.unknown_index
64 | return new.reshape(ishape)
65 |
66 | def unmap_to_all(self, inds):
67 | ishape = inds.shape
68 | assert len(ishape) > 1
69 | inds = inds.reshape(ishape[0], -1)
70 | used = self.used.to(inds)
71 | if self.re_embed > self.used.shape[0]: # extra token
72 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
73 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
74 | return back.reshape(ishape)
75 |
76 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
77 | assert temp is None or temp == 1.0, 'Only for interface compatible with Gumbel'
78 | assert not rescale_logits, 'Only for interface compatible with Gumbel'
79 | assert not return_logits, 'Only for interface compatible with Gumbel'
80 | # reshape z -> (batch, height, width, channel) and flatten
81 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
82 | z_flattened = z.view(-1, self.e_dim)
83 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
84 |
85 | d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
86 | torch.sum(self.embedding.weight**2, dim=1) -
87 | 2 * torch.einsum('bd,dn->bn', z_flattened,
88 | rearrange(self.embedding.weight, 'n d -> d n')))
89 |
90 | min_encoding_indices = torch.argmin(d, dim=1)
91 | z_q = self.embedding(min_encoding_indices).view(z.shape)
92 | perplexity = None
93 | min_encodings = None
94 |
95 | # compute loss for embedding
96 | if not self.legacy:
97 | loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean(
98 | (z_q - z.detach())**2)
99 | else:
100 | loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean(
101 | (z_q - z.detach())**2)
102 |
103 | # preserve gradients
104 | z_q = z + (z_q - z).detach()
105 |
106 | # reshape back to match original input shape
107 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
108 |
109 | if self.remap is not None:
110 | min_encoding_indices = min_encoding_indices.reshape(
111 | z.shape[0], -1) # add batch axis
112 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
113 | min_encoding_indices = min_encoding_indices.reshape(-1,
114 | 1) # flatten
115 |
116 | if self.sane_index_shape:
117 | min_encoding_indices = min_encoding_indices.reshape(
118 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
119 |
120 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
121 |
122 | def get_codebook_entry(self, indices, shape):
123 | # shape specifying (batch, height, width, channel)
124 | if self.remap is not None:
125 | indices = indices.reshape(shape[0], -1) # add batch axis
126 | indices = self.unmap_to_all(indices)
127 | indices = indices.reshape(-1) # flatten again
128 |
129 | # get quantized latent vectors
130 | z_q = self.embedding(indices)
131 |
132 | if shape is not None:
133 | z_q = z_q.view(shape)
134 | # reshape back to match original input shape
135 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
136 |
137 | return z_q
138 |
139 |
140 | class GumbelQuantize(nn.Module):
141 | """
142 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
143 | Gumbel Softmax trick quantizer
144 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
145 | https://arxiv.org/abs/1611.01144
146 | """
147 |
148 | def __init__(
149 | self,
150 | num_hiddens,
151 | embedding_dim,
152 | n_embed,
153 | straight_through=True,
154 | kl_weight=5e-4,
155 | temp_init=1.0,
156 | use_vqinterface=True,
157 | remap=None,
158 | unknown_index='random',
159 | ):
160 | super().__init__()
161 |
162 | self.embedding_dim = embedding_dim
163 | self.n_embed = n_embed
164 |
165 | self.straight_through = straight_through
166 | self.temperature = temp_init
167 | self.kl_weight = kl_weight
168 |
169 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
170 | self.embed = nn.Embedding(n_embed, embedding_dim)
171 |
172 | self.use_vqinterface = use_vqinterface
173 |
174 | self.remap = remap
175 | if self.remap is not None:
176 | self.register_buffer('used', torch.tensor(np.load(self.remap)))
177 | self.re_embed = self.used.shape[0]
178 | self.unknown_index = unknown_index # "random" or "extra" or integer
179 | if self.unknown_index == 'extra':
180 | self.unknown_index = self.re_embed
181 | self.re_embed = self.re_embed + 1
182 | print(
183 | f'Remapping {self.n_embed} indices to {self.re_embed} indices. '
184 | f'Using {self.unknown_index} for unknown indices.')
185 | else:
186 | self.re_embed = n_embed
187 |
188 | def remap_to_used(self, inds):
189 | ishape = inds.shape
190 | assert len(ishape) > 1
191 | inds = inds.reshape(ishape[0], -1)
192 | used = self.used.to(inds)
193 | match = (inds[:, :, None] == used[None, None, ...]).long()
194 | new = match.argmax(-1)
195 | unknown = match.sum(2) < 1
196 | if self.unknown_index == 'random':
197 | new[unknown] = torch.randint(
198 | 0, self.re_embed,
199 | size=new[unknown].shape).to(device=new.device)
200 | else:
201 | new[unknown] = self.unknown_index
202 | return new.reshape(ishape)
203 |
204 | def unmap_to_all(self, inds):
205 | ishape = inds.shape
206 | assert len(ishape) > 1
207 | inds = inds.reshape(ishape[0], -1)
208 | used = self.used.to(inds)
209 | if self.re_embed > self.used.shape[0]: # extra token
210 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
211 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
212 | return back.reshape(ishape)
213 |
214 | def forward(self, z, temp=None, return_logits=False):
215 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
216 | hard = self.straight_through if self.training else True
217 | temp = self.temperature if temp is None else temp
218 |
219 | logits = self.proj(z)
220 | if self.remap is not None:
221 | # continue only with used logits
222 | full_zeros = torch.zeros_like(logits)
223 | logits = logits[:, self.used, ...]
224 |
225 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
226 | if self.remap is not None:
227 | # go back to all entries but unused set to zero
228 | full_zeros[:, self.used, ...] = soft_one_hot
229 | soft_one_hot = full_zeros
230 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot,
231 | self.embed.weight)
232 |
233 | # + kl divergence to the prior loss
234 | qy = F.softmax(logits, dim=1)
235 | diff = self.kl_weight * torch.sum(
236 | qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
237 |
238 | ind = soft_one_hot.argmax(dim=1)
239 | if self.remap is not None:
240 | ind = self.remap_to_used(ind)
241 | if self.use_vqinterface:
242 | if return_logits:
243 | return z_q, diff, (None, None, ind), logits
244 | return z_q, diff, (None, None, ind)
245 | return z_q, diff, ind
246 |
247 | def get_codebook_entry(self, indices, shape):
248 | b, h, w, c = shape
249 | assert b * h * w == indices.shape[0]
250 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
251 | if self.remap is not None:
252 | indices = self.unmap_to_all(indices)
253 | one_hot = F.one_hot(indices,
254 | num_classes=self.n_embed).permute(0, 3, 1,
255 | 2).float()
256 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
257 | return z_q
258 |
--------------------------------------------------------------------------------
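A minimal, self-contained sketch of the quantization step performed in GumbelQuantize.forward above: project encoder features to codebook logits, draw a differentiable (Gumbel-Softmax) one-hot code, look the code up in the embedding table via einsum, and compute the KL-to-uniform-prior term. All shapes and hyperparameters below are illustrative assumptions, not values used by this repo.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum

num_hiddens, embedding_dim, n_embed = 32, 16, 128
proj = nn.Conv2d(num_hiddens, n_embed, 1)      # encoder features -> codebook logits
embed = nn.Embedding(n_embed, embedding_dim)   # learnable codebook

z = torch.randn(2, num_hiddens, 8, 8)          # dummy encoder output (B, C, H, W)
logits = proj(z)                               # (B, n_embed, H, W)

# soft, differentiable one-hot over the codebook; hard=True gives straight-through codes
soft_one_hot = F.gumbel_softmax(logits, tau=1.0, dim=1, hard=False)
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, embed.weight)

# KL of the code distribution to a uniform prior (the `diff` term in forward above)
qy = F.softmax(logits, dim=1)
kl = torch.sum(qy * torch.log(qy * n_embed + 1e-10), dim=1).mean()
print(z_q.shape, kl.item())
--------------------------------------------------------------------------------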
/flashvideo/sgm/modules/cp_enc_dec.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.distributed
5 | import torch.nn as nn
6 |
7 | from ..util import (get_context_parallel_group, get_context_parallel_rank,
8 | get_context_parallel_world_size)
9 |
10 | _USE_CP = True
11 |
12 |
13 | def cast_tuple(t, length=1):
14 | return t if isinstance(t, tuple) else ((t, ) * length)
15 |
16 |
17 | def divisible_by(num, den):
18 | return (num % den) == 0
19 |
20 |
21 | def is_odd(n):
22 | return not divisible_by(n, 2)
23 |
24 |
25 | def exists(v):
26 | return v is not None
27 |
28 |
29 | def pair(t):
30 | return t if isinstance(t, tuple) else (t, t)
31 |
32 |
33 | def get_timestep_embedding(timesteps, embedding_dim):
34 | """
35 |     Build sinusoidal timestep embeddings.
36 |     Originally adapted from Fairseq.
37 |     This matches the implementation in Denoising Diffusion Probabilistic Models
38 |     and in tensor2tensor, but differs slightly from the description in
39 |     Section 3.5 of "Attention Is All You Need".
40 | """
41 | assert len(timesteps.shape) == 1
42 |
43 | half_dim = embedding_dim // 2
44 | emb = math.log(10000) / (half_dim - 1)
45 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
46 | emb = emb.to(device=timesteps.device)
47 | emb = timesteps.float()[:, None] * emb[None, :]
48 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
49 | if embedding_dim % 2 == 1: # zero pad
50 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
51 | return emb
52 |
53 |
54 | def nonlinearity(x):
55 | # swish
56 | return x * torch.sigmoid(x)
57 |
58 |
59 | def leaky_relu(p=0.1):
60 | return nn.LeakyReLU(p)
61 |
62 |
63 | def _split(input_, dim):
64 | cp_world_size = get_context_parallel_world_size()
65 |
66 | if cp_world_size == 1:
67 | return input_
68 |
69 | cp_rank = get_context_parallel_rank()
70 |
71 | # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
72 |
73 |     input_first_frame_ = input_.transpose(0,
74 |                                           dim)[:1].transpose(0,
75 |                                                              dim).contiguous()
76 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
77 | dim_size = input_.size()[dim] // cp_world_size
78 |
79 | input_list = torch.split(input_, dim_size, dim=dim)
80 | output = input_list[cp_rank]
81 |
82 | if cp_rank == 0:
83 |         output = torch.cat([input_first_frame_, output], dim=dim)
84 | output = output.contiguous()
85 |
86 | # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
87 |
88 | return output
89 |
90 |
91 | def _gather(input_, dim):
92 | cp_world_size = get_context_parallel_world_size()
93 |
94 | # Bypass the function if context parallel is 1
95 | if cp_world_size == 1:
96 | return input_
97 |
98 | group = get_context_parallel_group()
99 | cp_rank = get_context_parallel_rank()
100 |
101 | # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
102 |
103 | input_first_frame_ = input_.transpose(0,
104 | dim)[:1].transpose(0,
105 | dim).contiguous()
106 | if cp_rank == 0:
107 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
108 |
109 | tensor_list = [
110 | torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))
111 | ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
112 |
113 | if cp_rank == 0:
114 | input_ = torch.cat([input_first_frame_, input_], dim=dim)
115 |
116 | tensor_list[cp_rank] = input_
117 | torch.distributed.all_gather(tensor_list, input_, group=group)
118 |
119 | output = torch.cat(tensor_list, dim=dim).contiguous()
120 |
121 | # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
122 |
123 | return output
124 |
125 |
126 | def _conv_split(input_, dim, kernel_size):
127 | cp_world_size = get_context_parallel_world_size()
128 |
129 | # Bypass the function if context parallel is 1
130 | if cp_world_size == 1:
131 | return input_
132 |
133 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
134 |
135 | cp_rank = get_context_parallel_rank()
136 |
137 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
138 |
139 | if cp_rank == 0:
140 | output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(
141 | dim, 0)
142 | else:
143 | output = input_.transpose(
144 | dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size +
145 | kernel_size].transpose(dim, 0)
146 | output = output.contiguous()
147 |
148 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
149 |
150 | return output
151 |
152 |
153 | def _conv_gather(input_, dim, kernel_size):
154 | cp_world_size = get_context_parallel_world_size()
155 |
156 | # Bypass the function if context parallel is 1
157 | if cp_world_size == 1:
158 | return input_
159 |
160 | group = get_context_parallel_group()
161 | cp_rank = get_context_parallel_rank()
162 |
163 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
164 |
165 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(
166 | 0, dim).contiguous()
167 | if cp_rank == 0:
168 | input_ = input_.transpose(0, dim)[kernel_size:].transpose(
169 | 0, dim).contiguous()
170 | else:
171 | input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(
172 | 0, dim).contiguous()
173 |
174 | tensor_list = [
175 | torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))
176 | ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
177 | if cp_rank == 0:
178 | input_ = torch.cat([input_first_kernel_, input_], dim=dim)
179 |
180 | tensor_list[cp_rank] = input_
181 | torch.distributed.all_gather(tensor_list, input_, group=group)
182 |
183 | # Note: torch.cat already creates a contiguous tensor.
184 | output = torch.cat(tensor_list, dim=dim).contiguous()
185 |
186 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
187 |
188 | return output
189 |
--------------------------------------------------------------------------------
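A short, self-contained check of get_timestep_embedding above: it concatenates sin and cos features at geometrically spaced frequencies, zero-padding when the embedding width is odd. The timestep values and width below are arbitrary illustrations.

import math
import torch

def timestep_embedding(timesteps, embedding_dim):
    # same construction as get_timestep_embedding above
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32)
                      * -(math.log(10000) / (half_dim - 1)))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:                    # zero pad odd widths
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

print(timestep_embedding(torch.tensor([0, 10, 500]), 8).shape)   # torch.Size([3, 8])
--------------------------------------------------------------------------------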
/flashvideo/sgm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 | from .denoiser import Denoiser
2 | from .discretizer import Discretization
3 | from .model import Decoder, Encoder, Model
4 | from .openaimodel import UNetModel
5 | from .sampling import BaseDiffusionSampler
6 | from .wrappers import OpenAIWrapper
7 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/diffusionmodules/denoiser.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ...util import append_dims, instantiate_from_config
7 |
8 |
9 | class Denoiser(nn.Module):
10 |
11 | def __init__(self, weighting_config, scaling_config):
12 | super().__init__()
13 |
14 | self.weighting = instantiate_from_config(weighting_config)
15 | self.scaling = instantiate_from_config(scaling_config)
16 |
17 | def possibly_quantize_sigma(self, sigma):
18 | return sigma
19 |
20 | def possibly_quantize_c_noise(self, c_noise):
21 | return c_noise
22 |
23 | def w(self, sigma):
24 | return self.weighting(sigma)
25 |
26 | def forward(
27 | self,
28 | network: nn.Module,
29 | input: torch.Tensor,
30 | sigma: torch.Tensor,
31 | cond: Dict,
32 | **additional_model_inputs,
33 | ) -> torch.Tensor:
34 | sigma = self.possibly_quantize_sigma(sigma)
35 | sigma_shape = sigma.shape
36 | sigma = append_dims(sigma, input.ndim)
37 | c_skip, c_out, c_in, c_noise = self.scaling(sigma,
38 | **additional_model_inputs)
39 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
40 | return network(input * c_in, c_noise, cond, **
41 | additional_model_inputs) * c_out + input * c_skip
42 |
43 |
44 | class DiscreteDenoiser(Denoiser):
45 |
46 | def __init__(
47 | self,
48 | weighting_config,
49 | scaling_config,
50 | num_idx,
51 | discretization_config,
52 | do_append_zero=False,
53 | quantize_c_noise=True,
54 | flip=True,
55 | ):
56 | super().__init__(weighting_config, scaling_config)
57 | sigmas = instantiate_from_config(discretization_config)(
58 | num_idx, do_append_zero=do_append_zero, flip=flip)
59 | self.sigmas = sigmas
60 | # self.register_buffer("sigmas", sigmas)
61 | self.quantize_c_noise = quantize_c_noise
62 |
63 | def sigma_to_idx(self, sigma):
64 | dists = sigma - self.sigmas.to(sigma.device)[:, None]
65 | return dists.abs().argmin(dim=0).view(sigma.shape)
66 |
67 | def idx_to_sigma(self, idx):
68 | return self.sigmas.to(idx.device)[idx]
69 |
70 | def possibly_quantize_sigma(self, sigma):
71 | return self.idx_to_sigma(self.sigma_to_idx(sigma))
72 |
73 | def possibly_quantize_c_noise(self, c_noise):
74 | if self.quantize_c_noise:
75 | return self.sigma_to_idx(c_noise)
76 | else:
77 | return c_noise
78 |
--------------------------------------------------------------------------------
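A minimal sketch of the preconditioning that Denoiser.forward wraps around the network, D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise). The toy network and the inlined EDM-style scaling are stand-ins for the config-driven scaling/weighting objects, chosen only for illustration.

import torch

def edm_scaling(sigma, sigma_data=0.5):
    # same formulas as EDMScaling in denoiser_scaling.py
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2)**0.5
    c_in = 1 / (sigma**2 + sigma_data**2)**0.5
    c_noise = 0.25 * sigma.log()
    return c_skip, c_out, c_in, c_noise

def toy_network(x, c_noise):
    # placeholder for the UNet / DiT backbone
    return torch.tanh(x + c_noise.view(-1, 1, 1, 1))

x = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([0.5, 2.0]).view(-1, 1, 1, 1)
c_skip, c_out, c_in, c_noise = edm_scaling(sigma)
denoised = toy_network(x * c_in, c_noise.reshape(-1)) * c_out + x * c_skip
print(denoised.shape)
--------------------------------------------------------------------------------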
/flashvideo/sgm/modules/diffusionmodules/denoiser_scaling.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 |
6 |
7 | class DenoiserScaling(ABC):
8 |
9 | @abstractmethod
10 | def __call__(
11 | self, sigma: torch.Tensor
12 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
13 | pass
14 |
15 |
16 | class EDMScaling:
17 |
18 | def __init__(self, sigma_data: float = 0.5):
19 | self.sigma_data = sigma_data
20 |
21 | def __call__(
22 | self, sigma: torch.Tensor
23 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
24 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
25 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2)**0.5
26 | c_in = 1 / (sigma**2 + self.sigma_data**2)**0.5
27 | c_noise = 0.25 * sigma.log()
28 | return c_skip, c_out, c_in, c_noise
29 |
30 |
31 | class EpsScaling:
32 |
33 | def __call__(
34 | self, sigma: torch.Tensor
35 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
36 | c_skip = torch.ones_like(sigma, device=sigma.device)
37 | c_out = -sigma
38 | c_in = 1 / (sigma**2 + 1.0)**0.5
39 | c_noise = sigma.clone()
40 | return c_skip, c_out, c_in, c_noise
41 |
42 |
43 | class VScaling:
44 |
45 | def __call__(
46 | self, sigma: torch.Tensor
47 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
48 | c_skip = 1.0 / (sigma**2 + 1.0)
49 | c_out = -sigma / (sigma**2 + 1.0)**0.5
50 | c_in = 1.0 / (sigma**2 + 1.0)**0.5
51 | c_noise = sigma.clone()
52 | return c_skip, c_out, c_in, c_noise
53 |
54 |
55 | class VScalingWithEDMcNoise(DenoiserScaling):
56 |
57 | def __call__(
58 | self, sigma: torch.Tensor
59 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
60 | c_skip = 1.0 / (sigma**2 + 1.0)
61 | c_out = -sigma / (sigma**2 + 1.0)**0.5
62 | c_in = 1.0 / (sigma**2 + 1.0)**0.5
63 | c_noise = 0.25 * sigma.log()
64 | return c_skip, c_out, c_in, c_noise
65 |
66 |
67 | class VideoScaling: # similar to VScaling
68 |
69 | def __call__(
70 | self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
71 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
72 | c_skip = alphas_cumprod_sqrt
73 | c_out = -((1 - alphas_cumprod_sqrt**2)**0.5)
74 | c_in = torch.ones_like(alphas_cumprod_sqrt,
75 | device=alphas_cumprod_sqrt.device)
76 | c_noise = additional_model_inputs['idx'].clone()
77 | return c_skip, c_out, c_in, c_noise
78 |
--------------------------------------------------------------------------------
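For intuition, a tiny numerical sketch of the VScaling coefficients defined above: as sigma approaches 0, c_skip approaches 1 and c_out approaches 0, so the denoiser output reduces to the (nearly clean) input. The sigma values are purely illustrative.

import torch

sigma = torch.tensor([0.1, 1.0, 10.0])
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0)**0.5
c_in = 1.0 / (sigma**2 + 1.0)**0.5
for s, a, b, c in zip(sigma.tolist(), c_skip.tolist(), c_out.tolist(), c_in.tolist()):
    print(f'sigma={s:5.1f}  c_skip={a:6.3f}  c_out={b:6.3f}  c_in={c:6.3f}')
--------------------------------------------------------------------------------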
/flashvideo/sgm/modules/diffusionmodules/denoiser_weighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class UnitWeighting:
5 |
6 | def __call__(self, sigma):
7 | return torch.ones_like(sigma, device=sigma.device)
8 |
9 |
10 | class EDMWeighting:
11 |
12 | def __init__(self, sigma_data=0.5):
13 | self.sigma_data = sigma_data
14 |
15 | def __call__(self, sigma):
16 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data)**2
17 |
18 |
19 | class VWeighting(EDMWeighting):
20 |
21 | def __init__(self):
22 | super().__init__(sigma_data=1.0)
23 |
24 |
25 | class EpsWeighting:
26 |
27 | def __call__(self, sigma):
28 | return sigma**-2.0
29 |
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/diffusionmodules/discretizer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from functools import partial
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ...modules.diffusionmodules.util import make_beta_schedule
8 | from ...util import append_zero
9 |
10 |
11 | def generate_roughly_equally_spaced_steps(num_substeps: int,
12 | max_step: int) -> np.ndarray:
13 | return np.linspace(max_step - 1, 0, num_substeps,
14 | endpoint=False).astype(int)[::-1]
15 |
16 |
17 | class Discretization:
18 |
19 | def __call__(self,
20 | n,
21 | do_append_zero=True,
22 | device='cpu',
23 | flip=False,
24 | return_idx=False):
25 | if return_idx:
26 | sigmas, idx = self.get_sigmas(n,
27 | device=device,
28 | return_idx=return_idx)
29 | else:
30 | sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
31 | sigmas = append_zero(sigmas) if do_append_zero else sigmas
32 | if return_idx:
33 | return sigmas if not flip else torch.flip(sigmas, (0, )), idx
34 | else:
35 | return sigmas if not flip else torch.flip(sigmas, (0, ))
36 |
37 | @abstractmethod
38 | def get_sigmas(self, n, device):
39 | pass
40 |
41 |
42 | class EDMDiscretization(Discretization):
43 |
44 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
45 | self.sigma_min = sigma_min
46 | self.sigma_max = sigma_max
47 | self.rho = rho
48 |
49 | def get_sigmas(self, n, device='cpu'):
50 | ramp = torch.linspace(0, 1, n, device=device)
51 | min_inv_rho = self.sigma_min**(1 / self.rho)
52 | max_inv_rho = self.sigma_max**(1 / self.rho)
53 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**self.rho
54 | return sigmas
55 |
56 |
57 | class LegacyDDPMDiscretization(Discretization):
58 |
59 | def __init__(
60 | self,
61 | linear_start=0.00085,
62 | linear_end=0.0120,
63 | num_timesteps=1000,
64 | ):
65 | super().__init__()
66 | self.num_timesteps = num_timesteps
67 | betas = make_beta_schedule('linear',
68 | num_timesteps,
69 | linear_start=linear_start,
70 | linear_end=linear_end)
71 | alphas = 1.0 - betas
72 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
73 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
74 |
75 | def get_sigmas(self, n, device='cpu'):
76 | if n < self.num_timesteps:
77 | timesteps = generate_roughly_equally_spaced_steps(
78 | n, self.num_timesteps)
79 | alphas_cumprod = self.alphas_cumprod[timesteps]
80 | elif n == self.num_timesteps:
81 | alphas_cumprod = self.alphas_cumprod
82 | else:
83 | raise ValueError
84 |
85 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
86 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod)**0.5
87 | return torch.flip(sigmas, (0, )) # sigma_t: 14.4 -> 0.029
88 |
89 |
90 | class ZeroSNRDDPMDiscretization(Discretization):
91 |
92 | def __init__(
93 | self,
94 | linear_start=0.00085,
95 | linear_end=0.0120,
96 | num_timesteps=1000,
97 | shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
98 | keep_start=False,
99 | post_shift=False,
100 | ):
101 | super().__init__()
102 | if keep_start and not post_shift:
103 | linear_start = linear_start / (shift_scale +
104 | (1 - shift_scale) * linear_start)
105 | self.num_timesteps = num_timesteps
106 | betas = make_beta_schedule('linear',
107 | num_timesteps,
108 | linear_start=linear_start,
109 | linear_end=linear_end)
110 | alphas = 1.0 - betas
111 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
112 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
113 |
114 | # SNR shift
115 | if not post_shift:
116 | self.alphas_cumprod = self.alphas_cumprod / (
117 | shift_scale + (1 - shift_scale) * self.alphas_cumprod)
118 |
119 | self.post_shift = post_shift
120 | self.shift_scale = shift_scale
121 |
122 | def get_sigmas(self, n, device='cpu', return_idx=False):
123 | if n < self.num_timesteps:
124 | timesteps = generate_roughly_equally_spaced_steps(
125 | n, self.num_timesteps)
126 | alphas_cumprod = self.alphas_cumprod[timesteps]
127 | elif n == self.num_timesteps:
128 |             alphas_cumprod, timesteps = self.alphas_cumprod, np.arange(n)  # keep timesteps defined for return_idx
129 | else:
130 | raise ValueError
131 |
132 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
133 | alphas_cumprod = to_torch(alphas_cumprod)
134 | alphas_cumprod_sqrt = alphas_cumprod.sqrt()
135 | alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
136 | alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
137 |
138 | alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
139 | alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 -
140 | alphas_cumprod_sqrt_T)
141 |
142 | if self.post_shift:
143 | alphas_cumprod_sqrt = (
144 | alphas_cumprod_sqrt**2 /
145 | (self.shift_scale +
146 | (1 - self.shift_scale) * alphas_cumprod_sqrt**2))**0.5
147 |
148 | if return_idx:
149 | return torch.flip(alphas_cumprod_sqrt, (0, )), timesteps
150 | else:
151 | return torch.flip(alphas_cumprod_sqrt,
152 | (0, )) # sqrt(alpha_t): 0 -> 0.99
153 |
--------------------------------------------------------------------------------
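A compact sketch of the zero-terminal-SNR rescaling in ZeroSNRDDPMDiscretization.get_sigmas: sqrt(alpha_bar) is shifted and rescaled so that the last step has exactly zero SNR while the first step is unchanged. The beta schedule below assumes the usual LDM-style 'linear' schedule (linear in sqrt(beta)); treat it as an approximation of make_beta_schedule rather than a verified copy.

import numpy as np

num_timesteps = 1000
linear_start, linear_end = 0.00085, 0.0120
betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps) ** 2
alphas_cumprod_sqrt = np.sqrt(np.cumprod(1.0 - betas))

a0, aT = alphas_cumprod_sqrt[0], alphas_cumprod_sqrt[-1]
rescaled = (alphas_cumprod_sqrt - aT) * a0 / (a0 - aT)

print(alphas_cumprod_sqrt[-1])    # small but nonzero before rescaling
print(rescaled[0], rescaled[-1])  # first step preserved, last step exactly 0 (zero SNR)
--------------------------------------------------------------------------------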
/flashvideo/sgm/modules/diffusionmodules/guiders.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from abc import ABC, abstractmethod
4 | from functools import partial
5 | from typing import Dict, List, Optional, Tuple, Union
6 |
7 | import torch
8 | from einops import rearrange, repeat
9 |
10 | from ...util import append_dims, default, instantiate_from_config
11 |
12 |
13 | class Guider(ABC):
14 |
15 | @abstractmethod
16 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
17 | pass
18 |
19 | def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict,
20 | uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
21 | pass
22 |
23 |
24 | class VanillaCFG:
25 | """
26 |     Implements classifier-free guidance (CFG) with the conditional and unconditional branches batched together.
27 | """
28 |
29 | def __init__(self, scale, dyn_thresh_config=None):
30 | self.scale = scale
31 | scale_schedule = lambda scale, sigma: scale # independent of step
32 | self.scale_schedule = partial(scale_schedule, scale)
33 | self.dyn_thresh = instantiate_from_config(
34 | default(
35 | dyn_thresh_config,
36 | {
37 | 'target':
38 | 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
39 | },
40 | ))
41 |
42 | def __call__(self, x, sigma, scale=None):
43 | x_u, x_c = x.chunk(2)
44 | scale_value = default(scale, self.scale_schedule(sigma))
45 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
46 | return x_pred
47 |
48 | def prepare_inputs(self, x, s, c, uc):
49 | c_out = dict()
50 |
51 | for k in c:
52 | if k in ['vector', 'crossattn', 'concat']:
53 | c_out[k] = torch.cat((uc[k], c[k]), 0)
54 | else:
55 | assert c[k] == uc[k]
56 | c_out[k] = c[k]
57 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
58 |
59 |
60 | # class DynamicCFG(VanillaCFG):
61 |
62 | # def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
63 | # super().__init__(scale, dyn_thresh_config)
64 | # scale_schedule = (lambda scale, sigma, step_index: 1 + scale *
65 | # (1 - math.cos(math.pi *
66 | # (step_index / num_steps)**exp)) / 2)
67 | # self.scale_schedule = partial(scale_schedule, scale)
68 | # self.dyn_thresh = instantiate_from_config(
69 | # default(
70 | # dyn_thresh_config,
71 | # {
72 | # 'target':
73 | # 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
74 | # },
75 | # ))
76 |
77 | # def __call__(self, x, sigma, step_index, scale=None):
78 | # x_u, x_c = x.chunk(2)
79 | # scale_value = self.scale_schedule(sigma, step_index.item())
80 | # x_pred = self.dyn_thresh(x_u, x_c, scale_value)
81 | # return x_pred
82 |
83 |
84 | class DynamicCFG(VanillaCFG):
85 |
86 | def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
87 | super().__init__(scale, dyn_thresh_config)
88 |
89 | self.scale = scale
90 | self.num_steps = num_steps
91 | self.exp = exp
92 | scale_schedule = (lambda scale, sigma, step_index: 1 + scale *
93 | (1 - math.cos(math.pi *
94 | (step_index / num_steps)**exp)) / 2)
95 |
96 | #self.scale_schedule = partial(scale_schedule, scale)
97 | self.dyn_thresh = instantiate_from_config(
98 | default(
99 | dyn_thresh_config,
100 | {
101 | 'target':
102 | 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding'
103 | },
104 | ))
105 |
106 | def scale_schedule_dy(self, sigma, step_index):
107 | # print(self.scale)
108 | return 1 + self.scale * (
109 | 1 - math.cos(math.pi *
110 | (step_index / self.num_steps)**self.exp)) / 2
111 |
112 | def __call__(self, x, sigma, step_index, scale=None):
113 | x_u, x_c = x.chunk(2)
114 | scale_value = self.scale_schedule_dy(sigma, step_index.item())
115 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
116 | return x_pred
117 |
118 |
119 | class IdentityGuider:
120 |
121 | def __call__(self, x, sigma):
122 | return x
123 |
124 | def prepare_inputs(self, x, s, c, uc):
125 | c_out = dict()
126 |
127 | for k in c:
128 | c_out[k] = c[k]
129 |
130 | return x, s, c_out
131 |
--------------------------------------------------------------------------------
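A small sketch of the guidance combine used above: VanillaCFG splits the doubled batch into unconditional and conditional halves and applies x_pred = x_u + s * (x_c - x_u), while DynamicCFG replaces the constant s with a cosine ramp over the sampling steps (scale_schedule_dy). All numbers below are illustrative.

import math
import torch

def cfg_combine(x, scale):
    x_u, x_c = x.chunk(2)            # uncond / cond halves of the doubled batch
    return x_u + scale * (x_c - x_u)

def dynamic_scale(scale, step_index, num_steps, exp=1.0):
    # cosine ramp used by DynamicCFG.scale_schedule_dy
    return 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2

x = torch.randn(4, 3)                # doubled batch holding 2 samples
print(cfg_combine(x, 6.0).shape)     # torch.Size([2, 3])
print([round(dynamic_scale(6.0, i, 50), 2) for i in (0, 25, 50)])   # [1.0, 4.0, 7.0]
--------------------------------------------------------------------------------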
/flashvideo/sgm/modules/diffusionmodules/loss.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from omegaconf import ListConfig
6 |
7 | from sat import mpu
8 |
9 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS
10 | from ...util import append_dims, instantiate_from_config
11 |
12 |
13 | class StandardDiffusionLoss(nn.Module):
14 |
15 | def __init__(
16 | self,
17 | sigma_sampler_config,
18 | type='l2',
19 | offset_noise_level=0.0,
20 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
21 | ):
22 | super().__init__()
23 |
24 | assert type in ['l2', 'l1', 'lpips']
25 |
26 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
27 |
28 | self.type = type
29 | self.offset_noise_level = offset_noise_level
30 |
31 | if type == 'lpips':
32 | self.lpips = LPIPS().eval()
33 |
34 | if not batch2model_keys:
35 | batch2model_keys = []
36 |
37 | if isinstance(batch2model_keys, str):
38 | batch2model_keys = [batch2model_keys]
39 |
40 | self.batch2model_keys = set(batch2model_keys)
41 |
42 | def __call__(self, network, denoiser, conditioner, input, batch):
43 | cond = conditioner(batch)
44 | additional_model_inputs = {
45 | key: batch[key]
46 | for key in self.batch2model_keys.intersection(batch)
47 | }
48 |
49 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
50 | noise = torch.randn_like(input)
51 | if self.offset_noise_level > 0.0:
52 | noise = (noise + append_dims(
53 | torch.randn(input.shape[0]).to(input.device), input.ndim) *
54 | self.offset_noise_level)
55 | noise = noise.to(input.dtype)
56 | noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
57 | model_output = denoiser(network, noised_input, sigmas, cond,
58 | **additional_model_inputs)
59 | w = append_dims(denoiser.w(sigmas), input.ndim)
60 | return self.get_loss(model_output, input, w)
61 |
62 | def get_loss(self, model_output, target, w):
63 | if self.type == 'l2':
64 | return torch.mean(
65 | (w * (model_output - target)**2).reshape(target.shape[0],
66 | -1), 1)
67 | elif self.type == 'l1':
68 | return torch.mean((w * (model_output - target).abs()).reshape(
69 | target.shape[0], -1), 1)
70 | elif self.type == 'lpips':
71 | loss = self.lpips(model_output, target).reshape(-1)
72 | return loss
73 |
74 |
75 | class VideoDiffusionLoss(StandardDiffusionLoss):
76 |
77 | def __init__(self,
78 | block_scale=None,
79 | block_size=None,
80 | min_snr_value=None,
81 | fixed_frames=0,
82 | **kwargs):
83 | self.fixed_frames = fixed_frames
84 | self.block_scale = block_scale
85 | self.block_size = block_size
86 | self.min_snr_value = min_snr_value
87 | super().__init__(**kwargs)
88 |
89 | def __call__(self, network, denoiser, conditioner, input, batch):
90 | cond = conditioner(batch)
91 | additional_model_inputs = {
92 | key: batch[key]
93 | for key in self.batch2model_keys.intersection(batch)
94 | }
95 |
96 | alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0],
97 | return_idx=True)
98 |         # e.g. alphas_cumprod_sqrt = tensor([0.8585]) for a single sample
99 |
100 | if 'ref_noise_step' in self.share_cache:
101 |
102 | print(self.share_cache['ref_noise_step'])
103 | ref_noise_step = self.share_cache['ref_noise_step']
104 | ref_alphas_cumprod_sqrt = self.sigma_sampler.idx_to_sigma(
105 | torch.zeros(input.shape[0]).fill_(ref_noise_step).long())
106 | ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(input.device)
107 | ref_x = self.share_cache['ref_x']
108 | ref_noise = torch.randn_like(ref_x)
109 |
110 |             # ref_x * sqrt(alpha_bar) + ref_noise * sqrt(1 - alpha_bar)
111 | ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
112 | + ref_noise * append_dims(
113 | (1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim
114 | )
115 | self.share_cache['ref_x'] = ref_noised_input
116 |
117 | alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
118 | idx = idx.to(input.device)
119 |
120 | noise = torch.randn_like(input)
121 |
122 | # broadcast noise
123 | mp_size = mpu.get_model_parallel_world_size()
124 | global_rank = torch.distributed.get_rank() // mp_size
125 | src = global_rank * mp_size
126 | torch.distributed.broadcast(idx,
127 | src=src,
128 | group=mpu.get_model_parallel_group())
129 | torch.distributed.broadcast(noise,
130 | src=src,
131 | group=mpu.get_model_parallel_group())
132 | torch.distributed.broadcast(alphas_cumprod_sqrt,
133 | src=src,
134 | group=mpu.get_model_parallel_group())
135 |
136 | additional_model_inputs['idx'] = idx
137 |
138 | if self.offset_noise_level > 0.0:
139 | noise = (noise + append_dims(
140 | torch.randn(input.shape[0]).to(input.device), input.ndim) *
141 | self.offset_noise_level)
142 |
143 | noised_input = input.float() * append_dims(
144 | alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
145 | (1 - alphas_cumprod_sqrt**2)**0.5, input.ndim)
146 |
147 | if 'concat_images' in batch.keys():
148 | cond['concat'] = batch['concat_images']
149 |
150 |         # e.g. noised_input [B=2, T=13, C=16, H=60, W=90], sigmas [2]; cond keys: 'crossattn', 'concat'; extra: 'idx'
151 | model_output = denoiser(network, noised_input, alphas_cumprod_sqrt,
152 | cond, **additional_model_inputs)
153 | w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
154 |
155 | if self.min_snr_value is not None:
156 |             w = torch.clamp(w, max=self.min_snr_value)  # element-wise; built-in min() fails on tensors
157 | return self.get_loss(model_output, input, w)
158 |
159 | def get_loss(self, model_output, target, w):
160 | if self.type == 'l2':
161 | # model_output.shape
162 | # torch.Size([1, 2, 16, 60, 88])
163 | return torch.mean(
164 | (w * (model_output - target)**2).reshape(target.shape[0],
165 | -1), 1)
166 | elif self.type == 'l1':
167 | return torch.mean((w * (model_output - target).abs()).reshape(
168 | target.shape[0], -1), 1)
169 | elif self.type == 'lpips':
170 | loss = self.lpips(model_output, target).reshape(-1)
171 | return loss
172 |
--------------------------------------------------------------------------------
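A minimal sketch of the forward noising and loss weighting performed in VideoDiffusionLoss.__call__: the latent video is mixed with noise as x_t = x * sqrt(alpha_bar) + eps * sqrt(1 - alpha_bar), and the per-sample weight for v-prediction is 1 / (1 - alpha_bar). Shapes and the alpha_bar values are illustrative only.

import torch

def append_dims(x, target_ndim):
    # pad trailing singleton dims so x broadcasts against the video tensor
    return x[(...,) + (None,) * (target_ndim - x.ndim)]

x = torch.randn(2, 13, 16, 60, 90)                # (B, T, C, H, W) latent video
alphas_cumprod_sqrt = torch.tensor([0.85, 0.30])  # one noise level per sample
noise = torch.randn_like(x)

a = append_dims(alphas_cumprod_sqrt, x.ndim)
noised = x * a + noise * (1 - a**2) ** 0.5        # forward diffusion
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), x.ndim)   # v-pred weighting
print(noised.shape, w.reshape(-1))
--------------------------------------------------------------------------------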
/flashvideo/sgm/modules/diffusionmodules/sampling_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from scipy import integrate
4 |
5 | from ...util import append_dims
6 |
7 |
8 | class NoDynamicThresholding:
9 |
10 | def __call__(self, uncond, cond, scale):
11 | scale = append_dims(scale, cond.ndim) if isinstance(
12 | scale, torch.Tensor) else scale
13 | return uncond + scale * (cond - uncond)
14 |
15 |
16 | class StaticThresholding:
17 |
18 | def __call__(self, uncond, cond, scale):
19 | result = uncond + scale * (cond - uncond)
20 | result = torch.clamp(result, min=-1.0, max=1.0)
21 | return result
22 |
23 |
24 | def dynamic_threshold(x, p=0.95):
25 | N, T, C, H, W = x.shape
26 | x = rearrange(x, 'n t c h w -> n c (t h w)')
27 | l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device),
28 | dim=-1,
29 | keepdim=True)
30 | s = torch.maximum(-l, r)
31 | threshold_mask = (s > 1).expand(-1, -1, H * W * T)
32 | if threshold_mask.any():
33 | x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
34 | x = rearrange(x, 'n c (t h w) -> n t c h w', t=T, h=H, w=W)
35 | return x
36 |
37 |
38 | def dynamic_thresholding2(x0):
39 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
40 | origin_dtype = x0.dtype
41 | x0 = x0.to(torch.float32)
42 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
43 | s = append_dims(torch.maximum(s,
44 | torch.ones_like(s).to(s.device)), x0.dim())
45 | x0 = torch.clamp(x0, -s, s) # / s
46 | return x0.to(origin_dtype)
47 |
48 |
49 | def latent_dynamic_thresholding(x0):
50 | p = 0.9995
51 | origin_dtype = x0.dtype
52 | x0 = x0.to(torch.float32)
53 | s = torch.quantile(torch.abs(x0), p, dim=2)
54 | s = append_dims(s, x0.dim())
55 | x0 = torch.clamp(x0, -s, s) / s
56 | return x0.to(origin_dtype)
57 |
58 |
59 | def dynamic_thresholding3(x0):
60 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
61 | origin_dtype = x0.dtype
62 | x0 = x0.to(torch.float32)
63 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
64 | s = append_dims(torch.maximum(s,
65 | torch.ones_like(s).to(s.device)), x0.dim())
66 | x0 = torch.clamp(x0, -s, s) # / s
67 | return x0.to(origin_dtype)
68 |
69 |
70 | class DynamicThresholding:
71 |
72 | def __call__(self, uncond, cond, scale):
73 | mean = uncond.mean()
74 | std = uncond.std()
75 | result = uncond + scale * (cond - uncond)
76 | result_mean, result_std = result.mean(), result.std()
77 | result = (result - result_mean) / result_std * std
78 | # result = dynamic_thresholding3(result)
79 | return result
80 |
81 |
82 | class DynamicThresholdingV1:
83 |
84 | def __init__(self, scale_factor):
85 | self.scale_factor = scale_factor
86 |
87 | def __call__(self, uncond, cond, scale):
88 | result = uncond + scale * (cond - uncond)
89 | unscaled_result = result / self.scale_factor
90 | B, T, C, H, W = unscaled_result.shape
91 | flattened = rearrange(unscaled_result, 'b t c h w -> b c (t h w)')
92 | means = flattened.mean(dim=2).unsqueeze(2)
93 | recentered = flattened - means
94 | magnitudes = recentered.abs().max()
95 | normalized = recentered / magnitudes
96 | thresholded = latent_dynamic_thresholding(normalized)
97 | denormalized = thresholded * magnitudes
98 | uncentered = denormalized + means
99 | unflattened = rearrange(uncentered,
100 | 'b c (t h w) -> b t c h w',
101 | t=T,
102 | h=H,
103 | w=W)
104 | scaled_result = unflattened * self.scale_factor
105 | return scaled_result
106 |
107 |
108 | class DynamicThresholdingV2:
109 |
110 | def __call__(self, uncond, cond, scale):
111 | B, T, C, H, W = uncond.shape
112 | diff = cond - uncond
113 | mim_target = uncond + diff * 4.0
114 | cfg_target = uncond + diff * 8.0
115 |
116 | mim_flattened = rearrange(mim_target, 'b t c h w -> b c (t h w)')
117 | cfg_flattened = rearrange(cfg_target, 'b t c h w -> b c (t h w)')
118 | mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
119 | cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
120 | mim_centered = mim_flattened - mim_means
121 | cfg_centered = cfg_flattened - cfg_means
122 |
123 | mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
124 | cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
125 |
126 | cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref
127 |
128 | result = cfg_renormalized + cfg_means
129 | unflattened = rearrange(result,
130 | 'b c (t h w) -> b t c h w',
131 | t=T,
132 | h=H,
133 | w=W)
134 |
135 | return unflattened
136 |
137 |
138 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
139 | if order - 1 > i:
140 | raise ValueError(f'Order {order} too high for step {i}')
141 |
142 | def fn(tau):
143 | prod = 1.0
144 | for k in range(order):
145 | if j == k:
146 | continue
147 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
148 | return prod
149 |
150 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
151 |
152 |
153 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
154 | if not eta:
155 | return sigma_to, 0.0
156 | sigma_up = torch.minimum(
157 | sigma_to,
158 | eta * (sigma_to**2 *
159 | (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5,
160 | )
161 | sigma_down = (sigma_to**2 - sigma_up**2)**0.5
162 | return sigma_down, sigma_up
163 |
164 |
165 | def to_d(x, sigma, denoised):
166 | return (x - denoised) / append_dims(sigma, x.ndim)
167 |
168 |
169 | def to_neg_log_sigma(sigma):
170 | return sigma.log().neg()
171 |
172 |
173 | def to_sigma(neg_log_sigma):
174 | return neg_log_sigma.neg().exp()
175 |
--------------------------------------------------------------------------------
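A self-contained sketch of the quantile-based clamp implemented by dynamic_thresholding2/3 above (the Imagen-style trick): take the p-th quantile of |x0| per sample, floor it at 1 so the clamp never tightens below [-1, 1], and clip to [-s, s]. The input here is random and only for illustration.

import torch

def dynamic_threshold(x0, p=0.995):
    s = torch.quantile(torch.abs(x0).reshape(x0.shape[0], -1), p, dim=1)
    s = torch.maximum(s, torch.ones_like(s))       # never clamp tighter than [-1, 1]
    s = s.view(-1, *([1] * (x0.ndim - 1)))         # broadcast over non-batch dims
    return torch.clamp(x0, -s, s)

x0 = 3.0 * torch.randn(2, 4, 8, 8)
clamped = dynamic_threshold(x0)
print(x0.abs().max().item(), clamped.abs().max().item())
--------------------------------------------------------------------------------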
/flashvideo/sgm/modules/diffusionmodules/sigma_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed
3 |
4 | from sat import mpu
5 |
6 | from ...util import default, instantiate_from_config
7 |
8 |
9 | class EDMSampling:
10 |
11 | def __init__(self, p_mean=-1.2, p_std=1.2):
12 | self.p_mean = p_mean
13 | self.p_std = p_std
14 |
15 | def __call__(self, n_samples, rand=None):
16 | log_sigma = self.p_mean + self.p_std * default(
17 | rand, torch.randn((n_samples, )))
18 | return log_sigma.exp()
19 |
20 |
21 | class DiscreteSampling:
22 |
23 | def __init__(self,
24 | discretization_config,
25 | num_idx,
26 | do_append_zero=False,
27 | flip=True,
28 | uniform_sampling=False):
29 | self.num_idx = num_idx
30 | self.sigmas = instantiate_from_config(discretization_config)(
31 | num_idx, do_append_zero=do_append_zero, flip=flip)
32 | world_size = mpu.get_data_parallel_world_size()
33 | self.uniform_sampling = uniform_sampling
34 | if self.uniform_sampling:
35 | i = 1
36 | while True:
37 | if world_size % i != 0 or num_idx % (world_size // i) != 0:
38 | i += 1
39 | else:
40 | self.group_num = world_size // i
41 | break
42 |
43 | assert self.group_num > 0
44 | assert world_size % self.group_num == 0
45 |             self.group_width = world_size // self.group_num # the number of ranks in one group
46 | self.sigma_interval = self.num_idx // self.group_num
47 |
48 | def idx_to_sigma(self, idx):
49 | return self.sigmas[idx]
50 |
51 | def __call__(self, n_samples, rand=None, return_idx=False):
52 | if self.uniform_sampling:
53 | rank = mpu.get_data_parallel_rank()
54 | group_index = rank // self.group_width
55 | idx = default(
56 | rand,
57 | torch.randint(group_index * self.sigma_interval,
58 | (group_index + 1) * self.sigma_interval,
59 | (n_samples, )),
60 | )
61 | else:
62 | idx = default(
63 | rand,
64 | torch.randint(0, self.num_idx, (n_samples, )),
65 | )
66 | if return_idx:
67 | return self.idx_to_sigma(idx), idx
68 | else:
69 | return self.idx_to_sigma(idx)
70 |
71 |
72 | class PartialDiscreteSampling:
73 |
74 | def __init__(self,
75 | discretization_config,
76 | total_num_idx,
77 | partial_num_idx,
78 | do_append_zero=False,
79 | flip=True):
80 | self.total_num_idx = total_num_idx
81 | self.partial_num_idx = partial_num_idx
82 | self.sigmas = instantiate_from_config(discretization_config)(
83 | total_num_idx, do_append_zero=do_append_zero, flip=flip)
84 |
85 | def idx_to_sigma(self, idx):
86 | return self.sigmas[idx]
87 |
88 | def __call__(self, n_samples, rand=None):
89 | idx = default(
90 | rand,
91 | # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
92 | torch.randint(0, self.partial_num_idx, (n_samples, )),
93 | )
94 | return self.idx_to_sigma(idx)
95 |
--------------------------------------------------------------------------------
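A short sketch of the bookkeeping DiscreteSampling does when uniform_sampling=True: the data-parallel ranks are split into groups, and each group only draws timestep indices from its own contiguous slice of [0, num_idx), which spreads timestep coverage more evenly across an iteration. world_size and num_idx below are made-up examples.

import torch

world_size, num_idx = 8, 1000

# find the largest group count that divides both world_size and num_idx
i = 1
while world_size % i != 0 or num_idx % (world_size // i) != 0:
    i += 1
group_num = world_size // i
group_width = world_size // group_num         # ranks per group
sigma_interval = num_idx // group_num         # timestep slice per group

for rank in range(world_size):
    group_index = rank // group_width
    lo, hi = group_index * sigma_interval, (group_index + 1) * sigma_interval
    idx = torch.randint(lo, hi, (1,)).item()
    print(f'rank {rank}: group {group_index}, idx range [{lo}, {hi}), sampled {idx}')
--------------------------------------------------------------------------------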
/flashvideo/sgm/modules/diffusionmodules/wrappers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from packaging import version
4 |
5 | OPENAIUNETWRAPPER = 'sgm.modules.diffusionmodules.wrappers.OpenAIWrapper'
6 |
7 |
8 | class IdentityWrapper(nn.Module):
9 |
10 | def __init__(self,
11 | diffusion_model,
12 | compile_model: bool = False,
13 | dtype: torch.dtype = torch.float32):
14 | super().__init__()
15 | compile = (torch.compile if
16 | (version.parse(torch.__version__) >= version.parse('2.0.0'))
17 | and compile_model else lambda x: x)
18 | self.diffusion_model = compile(diffusion_model)
19 | self.dtype = dtype
20 |
21 | def forward(self, *args, **kwargs):
22 | return self.diffusion_model(*args, **kwargs)
23 |
24 |
25 | class OpenAIWrapper(IdentityWrapper):
26 |
27 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict,
28 | **kwargs) -> torch.Tensor:
29 | for key in c:
30 | c[key] = c[key].to(self.dtype)
31 |
32 | if x.dim() == 4:
33 | x = torch.cat((x, c.get('concat',
34 | torch.Tensor([]).type_as(x))),
35 | dim=1)
36 | elif x.dim() == 5:
37 | x = torch.cat((x, c.get('concat',
38 | torch.Tensor([]).type_as(x))),
39 | dim=2)
40 | else:
41 | raise ValueError('Input tensor must be 4D or 5D')
42 |
43 | return self.diffusion_model(
44 | x,
45 | timesteps=t,
46 | context=c.get('crossattn', None),
47 | y=c.get('vector', None),
48 | **kwargs,
49 | )
50 |
--------------------------------------------------------------------------------
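A tiny sketch of the concat-conditioning step in OpenAIWrapper.forward for 5D video latents: the 'concat' conditioning is appended along the channel dimension (dim=2 for B, T, C, H, W) before calling the diffusion model. The tensors and shapes are stand-ins for illustration.

import torch

x = torch.randn(1, 13, 16, 60, 90)               # (B, T, C, H, W) noisy latent
c = {'concat': torch.randn(1, 13, 16, 60, 90)}   # stand-in conditioning of matching shape

x_in = torch.cat((x, c.get('concat', torch.Tensor([]).type_as(x))), dim=2)
print(x_in.shape)                                # channel dim doubled: (1, 13, 32, 60, 90)
--------------------------------------------------------------------------------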
/flashvideo/sgm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 |
7 | def sample(self):
8 | raise NotImplementedError()
9 |
10 | def mode(self):
11 | raise NotImplementedError()
12 |
13 |
14 | class DiracDistribution(AbstractDistribution):
15 |
16 | def __init__(self, value):
17 | self.value = value
18 |
19 | def sample(self):
20 | return self.value
21 |
22 | def mode(self):
23 | return self.value
24 |
25 |
26 | class DiagonalGaussianDistribution:
27 |
28 | def __init__(self, parameters, deterministic=False):
29 | self.parameters = parameters
30 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
31 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
32 | self.deterministic = deterministic
33 | self.std = torch.exp(0.5 * self.logvar)
34 | self.var = torch.exp(self.logvar)
35 | if self.deterministic:
36 | self.var = self.std = torch.zeros_like(
37 | self.mean).to(device=self.parameters.device)
38 |
39 | def sample(self):
40 | # x = self.mean + self.std * torch.randn(self.mean.shape).to(
41 | # device=self.parameters.device
42 | # )
43 | x = self.mean + self.std * torch.randn_like(
44 | self.mean).to(device=self.parameters.device)
45 | return x
46 |
47 | def kl(self, other=None):
48 | if self.deterministic:
49 | return torch.Tensor([0.0])
50 | else:
51 | if other is None:
52 | return 0.5 * torch.sum(
53 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
54 | dim=[1, 2, 3],
55 | )
56 | else:
57 | return 0.5 * torch.sum(
58 | torch.pow(self.mean - other.mean, 2) / other.var +
59 | self.var / other.var - 1.0 - self.logvar + other.logvar,
60 | dim=[1, 2, 3],
61 | )
62 |
63 | def nll(self, sample, dims=[1, 2, 3]):
64 | if self.deterministic:
65 | return torch.Tensor([0.0])
66 | logtwopi = np.log(2.0 * np.pi)
67 | return 0.5 * torch.sum(
68 | logtwopi + self.logvar +
69 | torch.pow(sample - self.mean, 2) / self.var,
70 | dim=dims,
71 | )
72 |
73 | def mode(self):
74 | return self.mean
75 |
76 |
77 | def normal_kl(mean1, logvar1, mean2, logvar2):
78 | """
79 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
80 | Compute the KL divergence between two gaussians.
81 | Shapes are automatically broadcasted, so batches can be compared to
82 | scalars, among other use cases.
83 | """
84 | tensor = None
85 | for obj in (mean1, logvar1, mean2, logvar2):
86 | if isinstance(obj, torch.Tensor):
87 | tensor = obj
88 | break
89 | assert tensor is not None, 'at least one argument must be a Tensor'
90 |
91 | # Force variances to be Tensors. Broadcasting helps convert scalars to
92 | # Tensors, but it does not work for torch.exp().
93 | logvar1, logvar2 = (x if isinstance(x, torch.Tensor) else
94 | torch.tensor(x).to(tensor) for x in (logvar1, logvar2))
95 |
96 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
97 | ((mean1 - mean2)**2) * torch.exp(-logvar2))
98 |
--------------------------------------------------------------------------------
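A compact sketch of what DiagonalGaussianDistribution does with the encoder output: split channels into mean and log-variance, draw a reparameterized sample, and compute the KL to a standard normal. Tensor shapes are illustrative.

import torch

params = torch.randn(2, 8, 16, 16)                 # encoder output: 2 * latent_channels
mean, logvar = torch.chunk(params, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std, var = torch.exp(0.5 * logvar), torch.exp(logvar)

sample = mean + std * torch.randn_like(mean)       # reparameterization trick
kl = 0.5 * torch.sum(mean**2 + var - 1.0 - logvar, dim=[1, 2, 3])
print(sample.shape, kl.shape)                      # (2, 4, 16, 16) and (2,)
--------------------------------------------------------------------------------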
/flashvideo/sgm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 |
7 | def __init__(self, model, decay=0.9999, use_num_upates=True):
8 | super().__init__()
9 | if decay < 0.0 or decay > 1.0:
10 | raise ValueError('Decay must be between 0 and 1')
11 |
12 | self.m_name2s_name = {}
13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
14 | self.register_buffer(
15 | 'num_updates',
16 | torch.tensor(0, dtype=torch.int)
17 | if use_num_upates else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 | # remove as '.'-character is not allowed in buffers
23 | s_name = name.replace('.', '')
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def reset_num_updates(self):
30 | del self.num_updates
31 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
32 |
33 | def forward(self, model):
34 | decay = self.decay
35 |
36 | if self.num_updates >= 0:
37 | self.num_updates += 1
38 | decay = min(self.decay,
39 | (1 + self.num_updates) / (10 + self.num_updates))
40 |
41 | one_minus_decay = 1.0 - decay
42 |
43 | with torch.no_grad():
44 | m_param = dict(model.named_parameters())
45 | shadow_params = dict(self.named_buffers())
46 |
47 | for key in m_param:
48 | if m_param[key].requires_grad:
49 | sname = self.m_name2s_name[key]
50 | shadow_params[sname] = shadow_params[sname].type_as(
51 | m_param[key])
52 | shadow_params[sname].sub_(
53 | one_minus_decay *
54 | (shadow_params[sname] - m_param[key]))
55 | else:
56 | assert not key in self.m_name2s_name
57 |
58 | def copy_to(self, model):
59 | m_param = dict(model.named_parameters())
60 | shadow_params = dict(self.named_buffers())
61 | for key in m_param:
62 | if m_param[key].requires_grad:
63 | m_param[key].data.copy_(
64 | shadow_params[self.m_name2s_name[key]].data)
65 | else:
66 | assert not key in self.m_name2s_name
67 |
68 | def store(self, parameters):
69 | """
70 | Save the current parameters for restoring later.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | temporarily stored.
74 | """
75 | self.collected_params = [param.clone() for param in parameters]
76 |
77 | def restore(self, parameters):
78 | """
79 | Restore the parameters stored with the `store` method.
80 | Useful to validate the model with EMA parameters without affecting the
81 | original optimization process. Store the parameters before the
82 | `copy_to` method. After validation (or model saving), use this to
83 | restore the former parameters.
84 | Args:
85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
86 | updated with the stored parameters.
87 | """
88 | for c_param, param in zip(self.collected_params, parameters):
89 | param.data.copy_(c_param.data)
90 |
--------------------------------------------------------------------------------
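A minimal sketch of the update rule in LitEma.forward: the shadow weights move toward the live weights as shadow -= (1 - d) * (shadow - param), where the effective decay d is warmed up via min(decay, (1 + n) / (10 + n)). The tiny model is only for illustration.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
decay, num_updates = 0.9999, 0

def ema_step():
    global num_updates
    num_updates += 1
    d = min(decay, (1 + num_updates) / (10 + num_updates))   # warm-up decay
    with torch.no_grad():
        for n, p in model.named_parameters():
            shadow[n].sub_((1.0 - d) * (shadow[n] - p))

with torch.no_grad():                 # one fake training step ...
    model.weight.add_(0.1)
ema_step()                            # ... followed by an EMA update
print((shadow['weight'] - model.weight).abs().max().item())
--------------------------------------------------------------------------------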
/flashvideo/sgm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/flashvideo/sgm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | from contextlib import nullcontext
3 | from functools import partial
4 | from typing import Dict, List, Optional, Tuple, Union
5 |
6 | import kornia
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from einops import rearrange, repeat
11 | from omegaconf import ListConfig
12 | from torch.utils.checkpoint import checkpoint
13 | from transformers import T5EncoderModel, T5Tokenizer
14 |
15 | from ...util import (append_dims, autocast, count_params, default,
16 | disabled_train, expand_dims_like, instantiate_from_config)
17 |
18 |
19 | class AbstractEmbModel(nn.Module):
20 |
21 | def __init__(self):
22 | super().__init__()
23 | self._is_trainable = None
24 | self._ucg_rate = None
25 | self._input_key = None
26 |
27 | @property
28 | def is_trainable(self) -> bool:
29 | return self._is_trainable
30 |
31 | @property
32 | def ucg_rate(self) -> Union[float, torch.Tensor]:
33 | return self._ucg_rate
34 |
35 | @property
36 | def input_key(self) -> str:
37 | return self._input_key
38 |
39 | @is_trainable.setter
40 | def is_trainable(self, value: bool):
41 | self._is_trainable = value
42 |
43 | @ucg_rate.setter
44 | def ucg_rate(self, value: Union[float, torch.Tensor]):
45 | self._ucg_rate = value
46 |
47 | @input_key.setter
48 | def input_key(self, value: str):
49 | self._input_key = value
50 |
51 | @is_trainable.deleter
52 | def is_trainable(self):
53 | del self._is_trainable
54 |
55 | @ucg_rate.deleter
56 | def ucg_rate(self):
57 | del self._ucg_rate
58 |
59 | @input_key.deleter
60 | def input_key(self):
61 | del self._input_key
62 |
63 |
64 | class GeneralConditioner(nn.Module):
65 | OUTPUT_DIM2KEYS = {2: 'vector', 3: 'crossattn', 4: 'concat', 5: 'concat'}
66 | KEY2CATDIM = {'vector': 1, 'crossattn': 2, 'concat': 1}
67 |
68 | def __init__(self,
69 | emb_models: Union[List, ListConfig],
70 | cor_embs=[],
71 | cor_p=[]):
72 | super().__init__()
73 | embedders = []
74 | for n, embconfig in enumerate(emb_models):
75 | embedder = instantiate_from_config(embconfig)
76 | assert isinstance(
77 | embedder, AbstractEmbModel
78 | ), f'embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel'
79 | embedder.is_trainable = embconfig.get('is_trainable', False)
80 | embedder.ucg_rate = embconfig.get('ucg_rate', 0.0)
81 | if not embedder.is_trainable:
82 | embedder.train = disabled_train
83 | for param in embedder.parameters():
84 | param.requires_grad = False
85 | embedder.eval()
86 | print(
87 | f'Initialized embedder #{n}: {embedder.__class__.__name__} '
88 | f'with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}'
89 | )
90 |
91 | if 'input_key' in embconfig:
92 | embedder.input_key = embconfig['input_key']
93 | elif 'input_keys' in embconfig:
94 | embedder.input_keys = embconfig['input_keys']
95 | else:
96 | raise KeyError(
97 | f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
98 | )
99 |
100 | embedder.legacy_ucg_val = embconfig.get('legacy_ucg_value', None)
101 | if embedder.legacy_ucg_val is not None:
102 | embedder.ucg_prng = np.random.RandomState()
103 |
104 | embedders.append(embedder)
105 | self.embedders = nn.ModuleList(embedders)
106 |
107 | if len(cor_embs) > 0:
108 | assert len(cor_p) == 2**len(cor_embs)
109 | self.cor_embs = cor_embs
110 | self.cor_p = cor_p
111 |
112 | def possibly_get_ucg_val(self, embedder: AbstractEmbModel,
113 | batch: Dict) -> Dict:
114 | assert embedder.legacy_ucg_val is not None
115 | p = embedder.ucg_rate
116 | val = embedder.legacy_ucg_val
117 | for i in range(len(batch[embedder.input_key])):
118 | if embedder.ucg_prng.choice(2, p=[1 - p, p]):
119 | batch[embedder.input_key][i] = val
120 | return batch
121 |
122 | def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict,
123 | cond_or_not) -> Dict:
124 | assert embedder.legacy_ucg_val is not None
125 | val = embedder.legacy_ucg_val
126 | for i in range(len(batch[embedder.input_key])):
127 | if cond_or_not[i]:
128 | batch[embedder.input_key][i] = val
129 | return batch
130 |
131 | def get_single_embedding(
132 | self,
133 | embedder,
134 | batch,
135 | output,
136 | cond_or_not: Optional[np.ndarray] = None,
137 | force_zero_embeddings: Optional[List] = None,
138 | ):
139 | embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
140 | with embedding_context():
141 | if hasattr(embedder, 'input_key') and (embedder.input_key
142 | is not None):
143 | if embedder.legacy_ucg_val is not None:
144 | if cond_or_not is None:
145 | batch = self.possibly_get_ucg_val(embedder, batch)
146 | else:
147 | batch = self.surely_get_ucg_val(
148 | embedder, batch, cond_or_not)
149 | emb_out = embedder(batch[embedder.input_key])
150 | elif hasattr(embedder, 'input_keys'):
151 | emb_out = embedder(*[batch[k] for k in embedder.input_keys])
152 | assert isinstance(
153 | emb_out, (torch.Tensor, list, tuple)
154 | ), f'encoder outputs must be tensors or a sequence, but got {type(emb_out)}'
155 | if not isinstance(emb_out, (list, tuple)):
156 | emb_out = [emb_out]
157 | for emb in emb_out:
158 | out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
159 | if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
160 | if cond_or_not is None:
161 | emb = (expand_dims_like(
162 | torch.bernoulli(
163 | (1.0 - embedder.ucg_rate) *
164 | torch.ones(emb.shape[0], device=emb.device)),
165 | emb,
166 | ) * emb)
167 | else:
168 | emb = (expand_dims_like(
169 | torch.tensor(1 - cond_or_not,
170 | dtype=emb.dtype,
171 | device=emb.device),
172 | emb,
173 | ) * emb)
174 | if hasattr(embedder, 'input_key'
175 | ) and embedder.input_key in force_zero_embeddings:
176 | emb = torch.zeros_like(emb)
177 | if out_key in output:
178 | output[out_key] = torch.cat((output[out_key], emb),
179 | self.KEY2CATDIM[out_key])
180 | else:
181 | output[out_key] = emb
182 | return output
183 |
184 | def forward(self,
185 | batch: Dict,
186 | force_zero_embeddings: Optional[List] = None) -> Dict:
187 | output = dict()
188 | if force_zero_embeddings is None:
189 | force_zero_embeddings = []
190 |
191 | if len(self.cor_embs) > 0:
192 | batch_size = len(batch[list(batch.keys())[0]])
193 | rand_idx = np.random.choice(len(self.cor_p),
194 | size=(batch_size, ),
195 | p=self.cor_p)
196 | for emb_idx in self.cor_embs:
197 | cond_or_not = rand_idx % 2
198 | rand_idx //= 2
199 | output = self.get_single_embedding(
200 | self.embedders[emb_idx],
201 | batch,
202 | output=output,
203 | cond_or_not=cond_or_not,
204 | force_zero_embeddings=force_zero_embeddings,
205 | )
206 |
207 | for i, embedder in enumerate(self.embedders):
208 | if i in self.cor_embs:
209 | continue
210 | output = self.get_single_embedding(
211 | embedder,
212 | batch,
213 | output=output,
214 | force_zero_embeddings=force_zero_embeddings)
215 | return output
216 |
217 | def get_unconditional_conditioning(self,
218 | batch_c,
219 | batch_uc=None,
220 | force_uc_zero_embeddings=None):
221 | if force_uc_zero_embeddings is None:
222 | force_uc_zero_embeddings = []
223 | ucg_rates = list()
224 | for embedder in self.embedders:
225 | ucg_rates.append(embedder.ucg_rate)
226 | embedder.ucg_rate = 0.0
227 | cor_embs = self.cor_embs
228 | cor_p = self.cor_p
229 | self.cor_embs = []
230 | self.cor_p = []
231 |
232 | c = self(batch_c)
233 | uc = self(batch_c if batch_uc is None else batch_uc,
234 | force_uc_zero_embeddings)
235 |
236 | for embedder, rate in zip(self.embedders, ucg_rates):
237 | embedder.ucg_rate = rate
238 | self.cor_embs = cor_embs
239 | self.cor_p = cor_p
240 |
241 | return c, uc
242 |
243 |
244 | class FrozenT5Embedder(AbstractEmbModel):
245 | """Uses the T5 transformer encoder for text"""
246 |
247 | def __init__(
248 | self,
249 | model_dir='google/t5-v1_1-xxl',
250 | device='cuda',
251 | max_length=77,
252 | freeze=True,
253 | cache_dir=None,
254 | ):
255 | super().__init__()
256 | if model_dir != 'google/t5-v1_1-xxl':
257 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
258 | self.transformer = T5EncoderModel.from_pretrained(model_dir)
259 | else:
260 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir,
261 | cache_dir=cache_dir)
262 | self.transformer = T5EncoderModel.from_pretrained(
263 | model_dir, cache_dir=cache_dir)
264 | self.device = device
265 | self.max_length = max_length
266 | if freeze:
267 | self.freeze()
268 |
269 | def freeze(self):
270 | self.transformer = self.transformer.eval()
271 |
272 | for param in self.parameters():
273 | param.requires_grad = False
274 |
275 | # @autocast
276 | def forward(self, text):
277 | batch_encoding = self.tokenizer(
278 | text,
279 | truncation=True,
280 | max_length=self.max_length,
281 | return_length=True,
282 | return_overflowing_tokens=False,
283 | padding='max_length',
284 | return_tensors='pt',
285 | )
286 | tokens = batch_encoding['input_ids'].to(self.device)
287 | with torch.autocast('cuda', enabled=False):
288 | outputs = self.transformer(input_ids=tokens)
289 | z = outputs.last_hidden_state
290 | return z
291 |
292 | def encode(self, text):
293 | return self(text)
294 |
--------------------------------------------------------------------------------
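A small sketch of how GeneralConditioner routes embedder outputs into the conditioning dict purely by tensor rank (OUTPUT_DIM2KEYS) and concatenates repeated keys along KEY2CATDIM. The random tensors stand in for real embedder outputs, and the shapes are assumptions for illustration.

import torch

OUTPUT_DIM2KEYS = {2: 'vector', 3: 'crossattn', 4: 'concat', 5: 'concat'}
KEY2CATDIM = {'vector': 1, 'crossattn': 2, 'concat': 1}

emb_outputs = [
    torch.randn(2, 77, 4096),    # rank-3, e.g. a T5 text embedding -> 'crossattn'
    torch.randn(2, 768),         # rank-2, a pooled vector          -> 'vector'
    torch.randn(2, 768),         # a second vector is concatenated along dim 1
]

output = {}
for emb in emb_outputs:
    key = OUTPUT_DIM2KEYS[emb.dim()]
    if key in output:
        output[key] = torch.cat((output[key], emb), KEY2CATDIM[key])
    else:
        output[key] = emb

print({k: tuple(v.shape) for k, v in output.items()})
# {'crossattn': (2, 77, 4096), 'vector': (2, 1536)}
--------------------------------------------------------------------------------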
/flashvideo/sgm/modules/video_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..modules.attention import *
4 | from ..modules.diffusionmodules.util import (AlphaBlender, linear,
5 | timestep_embedding)
6 |
7 |
8 | class TimeMixSequential(nn.Sequential):
9 |
10 | def forward(self, x, context=None, timesteps=None):
11 | for layer in self:
12 | x = layer(x, context, timesteps)
13 |
14 | return x
15 |
16 |
17 | class VideoTransformerBlock(nn.Module):
18 | ATTENTION_MODES = {
19 | 'softmax': CrossAttention,
20 | 'softmax-xformers': MemoryEfficientCrossAttention,
21 | }
22 |
23 | def __init__(
24 | self,
25 | dim,
26 | n_heads,
27 | d_head,
28 | dropout=0.0,
29 | context_dim=None,
30 | gated_ff=True,
31 | checkpoint=True,
32 | timesteps=None,
33 | ff_in=False,
34 | inner_dim=None,
35 | attn_mode='softmax',
36 | disable_self_attn=False,
37 | disable_temporal_crossattention=False,
38 | switch_temporal_ca_to_sa=False,
39 | ):
40 | super().__init__()
41 |
42 | attn_cls = self.ATTENTION_MODES[attn_mode]
43 |
44 | self.ff_in = ff_in or inner_dim is not None
45 | if inner_dim is None:
46 | inner_dim = dim
47 |
48 | assert int(n_heads * d_head) == inner_dim
49 |
50 | self.is_res = inner_dim == dim
51 |
52 | if self.ff_in:
53 | self.norm_in = nn.LayerNorm(dim)
54 | self.ff_in = FeedForward(dim,
55 | dim_out=inner_dim,
56 | dropout=dropout,
57 | glu=gated_ff)
58 |
59 | self.timesteps = timesteps
60 | self.disable_self_attn = disable_self_attn
61 | if self.disable_self_attn:
62 | self.attn1 = attn_cls(
63 | query_dim=inner_dim,
64 | heads=n_heads,
65 | dim_head=d_head,
66 | context_dim=context_dim,
67 | dropout=dropout,
68 | ) # is a cross-attention
69 | else:
70 | self.attn1 = attn_cls(query_dim=inner_dim,
71 | heads=n_heads,
72 | dim_head=d_head,
73 | dropout=dropout) # is a self-attention
74 |
75 | self.ff = FeedForward(inner_dim,
76 | dim_out=dim,
77 | dropout=dropout,
78 | glu=gated_ff)
79 |
80 | if disable_temporal_crossattention:
81 | if switch_temporal_ca_to_sa:
82 | raise ValueError
83 | else:
84 | self.attn2 = None
85 | else:
86 | self.norm2 = nn.LayerNorm(inner_dim)
87 | if switch_temporal_ca_to_sa:
88 | self.attn2 = attn_cls(query_dim=inner_dim,
89 | heads=n_heads,
90 | dim_head=d_head,
91 | dropout=dropout) # is a self-attention
92 | else:
93 | self.attn2 = attn_cls(
94 | query_dim=inner_dim,
95 | context_dim=context_dim,
96 | heads=n_heads,
97 | dim_head=d_head,
98 | dropout=dropout,
99 | ) # is self-attn if context is none
100 |
101 | self.norm1 = nn.LayerNorm(inner_dim)
102 | self.norm3 = nn.LayerNorm(inner_dim)
103 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
104 |
105 | self.checkpoint = checkpoint
106 | if self.checkpoint:
107 | print(f'{self.__class__.__name__} is using checkpointing')
108 |
109 | def forward(self,
110 | x: torch.Tensor,
111 |                 context: Optional[torch.Tensor] = None,
112 |                 timesteps: Optional[int] = None) -> torch.Tensor:
113 | if self.checkpoint:
114 | return checkpoint(self._forward, x, context, timesteps)
115 | else:
116 | return self._forward(x, context, timesteps=timesteps)
117 |
118 | def _forward(self, x, context=None, timesteps=None):
119 | assert self.timesteps or timesteps
120 | assert not (self.timesteps
121 | and timesteps) or self.timesteps == timesteps
122 | timesteps = self.timesteps or timesteps
123 | B, S, C = x.shape
124 | x = rearrange(x, '(b t) s c -> (b s) t c', t=timesteps)
125 |
126 | if self.ff_in:
127 | x_skip = x
128 | x = self.ff_in(self.norm_in(x))
129 | if self.is_res:
130 | x += x_skip
131 |
132 | if self.disable_self_attn:
133 | x = self.attn1(self.norm1(x), context=context) + x
134 | else:
135 | x = self.attn1(self.norm1(x)) + x
136 |
137 | if self.attn2 is not None:
138 | if self.switch_temporal_ca_to_sa:
139 | x = self.attn2(self.norm2(x)) + x
140 | else:
141 | x = self.attn2(self.norm2(x), context=context) + x
142 | x_skip = x
143 | x = self.ff(self.norm3(x))
144 | if self.is_res:
145 | x += x_skip
146 |
147 | x = rearrange(x,
148 | '(b s) t c -> (b t) s c',
149 | s=S,
150 | b=B // timesteps,
151 | c=C,
152 | t=timesteps)
153 | return x
154 |
155 | def get_last_layer(self):
156 | return self.ff.net[-1].weight
157 |
158 |
159 | str_to_dtype = {
160 | 'fp32': torch.float32,
161 | 'fp16': torch.float16,
162 | 'bf16': torch.bfloat16
163 | }
164 |
165 |
166 | class SpatialVideoTransformer(SpatialTransformer):
167 |
168 | def __init__(
169 | self,
170 | in_channels,
171 | n_heads,
172 | d_head,
173 | depth=1,
174 | dropout=0.0,
175 | use_linear=False,
176 | context_dim=None,
177 | use_spatial_context=False,
178 | timesteps=None,
179 | merge_strategy: str = 'fixed',
180 | merge_factor: float = 0.5,
181 | time_context_dim=None,
182 | ff_in=False,
183 | checkpoint=False,
184 | time_depth=1,
185 | attn_mode='softmax',
186 | disable_self_attn=False,
187 | disable_temporal_crossattention=False,
188 | max_time_embed_period: int = 10000,
189 | dtype='fp32',
190 | ):
191 | super().__init__(
192 | in_channels,
193 | n_heads,
194 | d_head,
195 | depth=depth,
196 | dropout=dropout,
197 | attn_type=attn_mode,
198 | use_checkpoint=checkpoint,
199 | context_dim=context_dim,
200 | use_linear=use_linear,
201 | disable_self_attn=disable_self_attn,
202 | )
203 | self.time_depth = time_depth
204 | self.depth = depth
205 | self.max_time_embed_period = max_time_embed_period
206 |
207 | time_mix_d_head = d_head
208 | n_time_mix_heads = n_heads
209 |
210 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
211 |
212 | inner_dim = n_heads * d_head
213 | if use_spatial_context:
214 | time_context_dim = context_dim
215 |
216 | self.time_stack = nn.ModuleList([
217 | VideoTransformerBlock(
218 | inner_dim,
219 | n_time_mix_heads,
220 | time_mix_d_head,
221 | dropout=dropout,
222 | context_dim=time_context_dim,
223 | timesteps=timesteps,
224 | checkpoint=checkpoint,
225 | ff_in=ff_in,
226 | inner_dim=time_mix_inner_dim,
227 | attn_mode=attn_mode,
228 | disable_self_attn=disable_self_attn,
229 | disable_temporal_crossattention=disable_temporal_crossattention,
230 | ) for _ in range(self.depth)
231 | ])
232 |
233 | assert len(self.time_stack) == len(self.transformer_blocks)
234 |
235 | self.use_spatial_context = use_spatial_context
236 | self.in_channels = in_channels
237 |
238 | time_embed_dim = self.in_channels * 4
239 | self.time_pos_embed = nn.Sequential(
240 | linear(self.in_channels, time_embed_dim),
241 | nn.SiLU(),
242 | linear(time_embed_dim, self.in_channels),
243 | )
244 |
245 | self.time_mixer = AlphaBlender(alpha=merge_factor,
246 | merge_strategy=merge_strategy)
247 | self.dtype = str_to_dtype[dtype]
248 |
249 | def forward(
250 | self,
251 | x: torch.Tensor,
252 | context: Optional[torch.Tensor] = None,
253 | time_context: Optional[torch.Tensor] = None,
254 | timesteps: Optional[int] = None,
255 | image_only_indicator: Optional[torch.Tensor] = None,
256 | ) -> torch.Tensor:
257 | _, _, h, w = x.shape
258 | x_in = x
259 | spatial_context = None
260 | if exists(context):
261 | spatial_context = context
262 |
263 | if self.use_spatial_context:
264 | assert context.ndim == 3, f'n dims of spatial context should be 3 but are {context.ndim}'
265 |
266 | time_context = context
267 | time_context_first_timestep = time_context[::timesteps]
268 | time_context = repeat(time_context_first_timestep,
269 | 'b ... -> (b n) ...',
270 | n=h * w)
271 | elif time_context is not None and not self.use_spatial_context:
272 | time_context = repeat(time_context, 'b ... -> (b n) ...', n=h * w)
273 | if time_context.ndim == 2:
274 | time_context = rearrange(time_context, 'b c -> b 1 c')
275 |
276 | x = self.norm(x)
277 | if not self.use_linear:
278 | x = self.proj_in(x)
279 | x = rearrange(x, 'b c h w -> b (h w) c')
280 | if self.use_linear:
281 | x = self.proj_in(x)
282 |
283 | num_frames = torch.arange(timesteps, device=x.device)
284 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
285 | num_frames = rearrange(num_frames, 'b t -> (b t)')
286 | t_emb = timestep_embedding(
287 | num_frames,
288 | self.in_channels,
289 | repeat_only=False,
290 | max_period=self.max_time_embed_period,
291 | dtype=self.dtype,
292 | )
293 | emb = self.time_pos_embed(t_emb)
294 | emb = emb[:, None, :]
295 |
296 | for it_, (block, mix_block) in enumerate(
297 | zip(self.transformer_blocks, self.time_stack)):
298 | x = block(
299 | x,
300 | context=spatial_context,
301 | )
302 |
303 | x_mix = x
304 | x_mix = x_mix + emb
305 |
306 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
307 | x = self.time_mixer(
308 | x_spatial=x,
309 | x_temporal=x_mix,
310 | image_only_indicator=image_only_indicator,
311 | )
312 | if self.use_linear:
313 | x = self.proj_out(x)
314 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
315 | if not self.use_linear:
316 | x = self.proj_out(x)
317 | out = x + x_in
318 | return out
319 |
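320 | # Example usage (untested sketch; the sizes are assumptions chosen so that
321 | # n_heads * d_head == in_channels, and the batch axis packs batch * timesteps
322 | # frames as the forward pass above expects):
323 | #
324 | #   block = SpatialVideoTransformer(in_channels=320, n_heads=8, d_head=40,
325 | #                                   context_dim=1024, use_linear=True)
326 | #   x = torch.randn(2 * 16, 320, 32, 32)           # (b*t, c, h, w)
327 | #   ctx = torch.randn(2 * 16, 77, 1024)            # per-frame text context
328 | #   out = block(x, context=ctx, timesteps=16)      # same shape as x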
--------------------------------------------------------------------------------
/flashvideo/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Union
3 |
4 | import numpy as np
5 | import torch
6 | from omegaconf import ListConfig
7 | from sgm.util import instantiate_from_config
8 |
9 |
10 | def read_from_file(p, rank=0, world_size=1):
11 | with open(p) as fin:
12 | cnt = -1
13 |         for line in fin:
14 |             cnt += 1
15 |             if cnt % world_size != rank:
16 |                 continue
17 |             yield line.strip(), cnt
18 |
19 |
20 | def disable_all_init():
21 | """Disable all redundant torch default initialization to accelerate model
22 | creation."""
23 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
24 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
25 | setattr(torch.nn.modules.sparse.Embedding, 'reset_parameters',
26 | lambda self: None)
27 | setattr(torch.nn.modules.conv.Conv2d, 'reset_parameters',
28 | lambda self: None)
29 | setattr(torch.nn.modules.normalization.GroupNorm, 'reset_parameters',
30 | lambda self: None)
31 |
32 |
33 | def get_unique_embedder_keys_from_conditioner(conditioner):
34 | return list({x.input_key for x in conditioner.embedders})
35 |
36 |
37 | def get_batch(keys,
38 | value_dict,
39 | N: Union[List, ListConfig],
40 | T=None,
41 | device='cuda'):
42 | batch = {}
43 | batch_uc = {}
44 |
45 | for key in keys:
46 | if key == 'txt':
47 | batch['txt'] = np.repeat([value_dict['prompt']],
48 | repeats=math.prod(N)).reshape(N).tolist()
49 | batch_uc['txt'] = np.repeat(
50 | [value_dict['negative_prompt']],
51 | repeats=math.prod(N)).reshape(N).tolist()
52 | else:
53 | batch[key] = value_dict[key]
54 |
55 | if T is not None:
56 | batch['num_video_frames'] = T
57 |
58 | for key in batch.keys():
59 | if key not in batch_uc and isinstance(batch[key], torch.Tensor):
60 | batch_uc[key] = torch.clone(batch[key])
61 | return batch, batch_uc
62 |
63 |
64 | def decode(first_stage_model, latent):
65 | first_stage_model.to(torch.float16)
66 | latent = latent.to(torch.float16)
67 | recons = []
68 | T = latent.shape[2]
69 | if T > 2:
70 | loop_num = (T - 1) // 2
71 | for i in range(loop_num):
72 | if i == 0:
73 | start_frame, end_frame = 0, 3
74 | else:
75 | start_frame, end_frame = i * 2 + 1, i * 2 + 3
76 | if i == loop_num - 1:
77 | clear_fake_cp_cache = True
78 | else:
79 | clear_fake_cp_cache = False
80 | with torch.no_grad():
81 | recon = first_stage_model.decode(
82 | latent[:, :, start_frame:end_frame].contiguous(),
83 | clear_fake_cp_cache=clear_fake_cp_cache)
84 |
85 | recons.append(recon)
86 | else:
87 |
88 | clear_fake_cp_cache = True
89 | if latent.shape[2] > 1:
90 | for m in first_stage_model.modules():
91 | m.force_split = True
92 | recon = first_stage_model.decode(
93 | latent.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
94 | recons.append(recon)
95 | recon = torch.cat(recons, dim=2).to(torch.float32)
96 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
97 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
98 | samples = (samples * 255).squeeze(0).permute(0, 2, 3, 1)
99 | save_frames = samples
100 |
101 | return save_frames
102 |
103 |
104 | def save_mem_decode(first_stage_model, latent):
105 |
106 | l_h, l_w = latent.shape[3], latent.shape[4]
107 | T = latent.shape[2]
108 | F = 8
109 |     # split spatially along H and W
110 | num_h_splits = 1
111 | num_w_splits = 2
112 | ori_video = torch.zeros((1, 3, 1 + 4 * (T - 1), l_h * 8, l_w * 8),
113 | device=latent.device)
114 | for h_idx in range(num_h_splits):
115 | for w_idx in range(num_w_splits):
116 | start_h = h_idx * latent.shape[3] // num_h_splits
117 | end_h = (h_idx + 1) * latent.shape[3] // num_h_splits
118 | start_w = w_idx * latent.shape[4] // num_w_splits
119 | end_w = (w_idx + 1) * latent.shape[4] // num_w_splits
120 |
121 | latent_overlap = 16
122 | if (start_h - latent_overlap >= 0) and (num_h_splits > 1):
123 | real_start_h = start_h - latent_overlap
124 | h_start_overlap = latent_overlap * F
125 | else:
126 | h_start_overlap = 0
127 | real_start_h = start_h
128 | if (end_h + latent_overlap <= l_h) and (num_h_splits > 1):
129 | real_end_h = end_h + latent_overlap
130 | h_end_overlap = latent_overlap * F
131 | else:
132 | h_end_overlap = 0
133 | real_end_h = end_h
134 |
135 | if (start_w - latent_overlap >= 0) and (num_w_splits > 1):
136 | real_start_w = start_w - latent_overlap
137 | w_start_overlap = latent_overlap * F
138 | else:
139 | w_start_overlap = 0
140 | real_start_w = start_w
141 |
142 | if (end_w + latent_overlap <= l_w) and (num_w_splits > 1):
143 | real_end_w = end_w + latent_overlap
144 | w_end_overlap = latent_overlap * F
145 | else:
146 | w_end_overlap = 0
147 | real_end_w = end_w
148 |
149 | latent_slice = latent[:, :, :, real_start_h:real_end_h,
150 | real_start_w:real_end_w]
151 | recon = decode(first_stage_model, latent_slice)
152 |
153 | recon = recon.permute(3, 0, 1, 2).contiguous()[None]
154 |
155 | recon = recon[:, :, :,
156 | h_start_overlap:recon.shape[3] - h_end_overlap,
157 | w_start_overlap:recon.shape[4] - w_end_overlap]
158 | ori_video[:, :, :, start_h * 8:end_h * 8,
159 | start_w * 8:end_w * 8] = recon
160 | ori_video = ori_video.squeeze(0)
161 | ori_video = ori_video.permute(1, 2, 3, 0).contiguous().cpu()
162 | return ori_video
163 |
164 |
165 | def prepare_input(text, model, T, negative_prompt=None, pos_prompt=None):
166 |
167 | if negative_prompt is None:
168 | negative_prompt = ''
169 | if pos_prompt is None:
170 | pos_prompt = ''
171 | value_dict = {
172 | 'prompt': text + pos_prompt,
173 | 'negative_prompt': negative_prompt,
174 | 'num_frames': torch.tensor(T).unsqueeze(0),
175 | }
176 | print(value_dict)
177 | batch, batch_uc = get_batch(
178 | get_unique_embedder_keys_from_conditioner(model.conditioner),
179 | value_dict, [1])
180 |
181 | for key in batch:
182 | if isinstance(batch[key], torch.Tensor):
183 | print(key, batch[key].shape)
184 | elif isinstance(batch[key], list):
185 |             print(key, [len(item) for item in batch[key]])
186 | else:
187 | print(key, batch[key])
188 | c, uc = model.conditioner.get_unconditional_conditioning(
189 | batch,
190 | batch_uc=batch_uc,
191 | force_uc_zero_embeddings=['txt'],
192 | )
193 |
194 | for k in c:
195 |         if k != 'crossattn':
196 | c[k], uc[k] = map(lambda y: y[k][:math.prod([1])].to('cuda'),
197 | (c, uc))
198 | return c, uc
199 |
200 |
201 | def save_memory_encode_first_stage(x, model):
202 | splits_x = torch.split(x, [17, 16, 16], dim=2)
203 | all_out = []
204 |
205 | with torch.autocast('cuda', enabled=False):
206 | for idx, input_x in enumerate(splits_x):
207 | if idx == len(splits_x) - 1:
208 | clear_fake_cp_cache = True
209 | else:
210 | clear_fake_cp_cache = False
211 | out = model.first_stage_model.encode(
212 | input_x.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
213 | all_out.append(out)
214 |
215 | z = torch.cat(all_out, dim=2)
216 | z = model.scale_factor * z
217 | return z
218 |
219 |
220 | def seed_everything(seed: int = 42):
221 | import os
222 | import random
223 |
224 | import numpy as np
225 | import torch
226 |
227 | # Python random module
228 | random.seed(seed)
229 |
230 | # Numpy
231 | np.random.seed(seed)
232 |
233 | # PyTorch
234 | torch.manual_seed(seed)
235 |
236 | # If using CUDA
237 | torch.cuda.manual_seed(seed)
238 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
239 |
240 | # # CuDNN
241 | # torch.backends.cudnn.deterministic = True
242 | # torch.backends.cudnn.benchmark = False
243 |
244 | # OS environment
245 | os.environ['PYTHONHASHSEED'] = str(seed)
246 |
247 |
248 | def get_time_slice_vae():
249 | vae_config = {
250 | 'target': 'vae_modules.autoencoder.VideoAutoencoderInferenceWrapper',
251 | 'params': {
252 | 'cp_size': 1,
253 | 'ckpt_path': './checkpoints/3d-vae.pt',
254 | 'ignore_keys': ['loss'],
255 | 'loss_config': {
256 | 'target': 'torch.nn.Identity'
257 | },
258 | 'regularizer_config': {
259 | 'target':
260 | 'vae_modules.regularizers.DiagonalGaussianRegularizer'
261 | },
262 | 'encoder_config': {
263 | 'target':
264 | 'vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D',
265 | 'params': {
266 | 'double_z': True,
267 | 'z_channels': 16,
268 | 'resolution': 256,
269 | 'in_channels': 3,
270 | 'out_ch': 3,
271 | 'ch': 128,
272 | 'ch_mult': [1, 2, 2, 4],
273 | 'attn_resolutions': [],
274 | 'num_res_blocks': 3,
275 | 'dropout': 0.0,
276 | 'gather_norm': False
277 | }
278 | },
279 | 'decoder_config': {
280 | 'target': 'vae_modules.cp_enc_dec.ContextParallelDecoder3D',
281 | 'params': {
282 | 'double_z': True,
283 | 'z_channels': 16,
284 | 'resolution': 256,
285 | 'in_channels': 3,
286 | 'out_ch': 3,
287 | 'ch': 128,
288 | 'ch_mult': [1, 2, 2, 4],
289 | 'attn_resolutions': [],
290 | 'num_res_blocks': 3,
291 | 'dropout': 0.0,
292 | 'gather_norm': False
293 | }
294 | }
295 | }
296 | }
297 |
298 | vae = instantiate_from_config(vae_config).eval().half().cuda()
299 | return vae
300 |
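301 | # Example usage (untested sketch; assumes the 3D-VAE weights referenced above at
302 | # ./checkpoints/3d-vae.pt are available and that `latent` is produced by the
303 | # diffusion sampler with shape (B, C, T, h, w)):
304 | #
305 | #   seed_everything(42)
306 | #   disable_all_init()
307 | #   vae = get_time_slice_vae()
308 | #   frames = save_mem_decode(vae, latent)
309 | #   # frames: (1 + 4 * (T - 1), h * 8, w * 8, 3) tensor in [0, 255] on the CPU,
310 | #   # ready to be written out frame by frame.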
--------------------------------------------------------------------------------
/flashvideo/vae_modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 |
7 | def __init__(self, model, decay=0.9999, use_num_upates=True):
8 | super().__init__()
9 | if decay < 0.0 or decay > 1.0:
10 | raise ValueError('Decay must be between 0 and 1')
11 |
12 | self.m_name2s_name = {}
13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
14 | self.register_buffer(
15 | 'num_updates',
16 | torch.tensor(0, dtype=torch.int)
17 | if use_num_upates else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 |                 # remove the '.' character, as it is not allowed in buffer names
23 | s_name = name.replace('.', '')
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def reset_num_updates(self):
30 | del self.num_updates
31 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
32 |
33 | def forward(self, model):
34 | decay = self.decay
35 |
36 | if self.num_updates >= 0:
37 | self.num_updates += 1
38 | decay = min(self.decay,
39 | (1 + self.num_updates) / (10 + self.num_updates))
40 |
41 | one_minus_decay = 1.0 - decay
42 |
43 | with torch.no_grad():
44 | m_param = dict(model.named_parameters())
45 | shadow_params = dict(self.named_buffers())
46 |
47 | for key in m_param:
48 | if m_param[key].requires_grad:
49 | sname = self.m_name2s_name[key]
50 | shadow_params[sname] = shadow_params[sname].type_as(
51 | m_param[key])
52 | shadow_params[sname].sub_(
53 | one_minus_decay *
54 | (shadow_params[sname] - m_param[key]))
55 | else:
56 |                     assert key not in self.m_name2s_name
57 |
58 | def copy_to(self, model):
59 | m_param = dict(model.named_parameters())
60 | shadow_params = dict(self.named_buffers())
61 | for key in m_param:
62 | if m_param[key].requires_grad:
63 | m_param[key].data.copy_(
64 | shadow_params[self.m_name2s_name[key]].data)
65 | else:
66 |                 assert key not in self.m_name2s_name
67 |
68 | def store(self, parameters):
69 | """
70 | Save the current parameters for restoring later.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | temporarily stored.
74 | """
75 | self.collected_params = [param.clone() for param in parameters]
76 |
77 | def restore(self, parameters):
78 | """
79 | Restore the parameters stored with the `store` method.
80 | Useful to validate the model with EMA parameters without affecting the
81 | original optimization process. Store the parameters before the
82 | `copy_to` method. After validation (or model saving), use this to
83 | restore the former parameters.
84 | Args:
85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
86 | updated with the stored parameters.
87 | """
88 | for c_param, param in zip(self.collected_params, parameters):
89 | param.data.copy_(c_param.data)
90 |
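91 | # Example usage (untested sketch; `model`, `train_step`, `loader` and
92 | # `validate` are hypothetical placeholders for an ordinary training loop):
93 | #
94 | #   ema = LitEma(model, decay=0.9999)
95 | #   for batch in loader:
96 | #       train_step(model, batch)
97 | #       ema(model)                       # update the shadow weights
98 | #   ema.store(model.parameters())        # keep the raw weights
99 | #   ema.copy_to(model)                   # evaluate with EMA weights
100 | #   validate(model)
101 | #   ema.restore(model.parameters())      # put the raw weights back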
--------------------------------------------------------------------------------
/flashvideo/vae_modules/regularizers.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class DiagonalGaussianDistribution:
11 |
12 | def __init__(self, parameters, deterministic=False):
13 | self.parameters = parameters
14 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
15 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
16 | self.deterministic = deterministic
17 | self.std = torch.exp(0.5 * self.logvar)
18 | self.var = torch.exp(self.logvar)
19 | if self.deterministic:
20 | self.var = self.std = torch.zeros_like(
21 | self.mean).to(device=self.parameters.device)
22 |
23 | def sample(self):
24 | # x = self.mean + self.std * torch.randn(self.mean.shape).to(
25 | # device=self.parameters.device
26 | # )
27 | x = self.mean + self.std * torch.randn_like(self.mean)
28 | return x
29 |
30 | def kl(self, other=None):
31 | if self.deterministic:
32 | return torch.Tensor([0.0])
33 | else:
34 | if other is None:
35 | return 0.5 * torch.sum(
36 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
37 | dim=[1, 2, 3],
38 | )
39 | else:
40 | return 0.5 * torch.sum(
41 | torch.pow(self.mean - other.mean, 2) / other.var +
42 | self.var / other.var - 1.0 - self.logvar + other.logvar,
43 | dim=[1, 2, 3],
44 | )
45 |
46 | def nll(self, sample, dims=[1, 2, 3]):
47 | if self.deterministic:
48 | return torch.Tensor([0.0])
49 | logtwopi = np.log(2.0 * np.pi)
50 | return 0.5 * torch.sum(
51 | logtwopi + self.logvar +
52 | torch.pow(sample - self.mean, 2) / self.var,
53 | dim=dims,
54 | )
55 |
56 | def mode(self):
57 | return self.mean
58 |
59 |
60 | class AbstractRegularizer(nn.Module):
61 |
62 | def __init__(self):
63 | super().__init__()
64 |
65 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
66 | raise NotImplementedError()
67 |
68 | @abstractmethod
69 | def get_trainable_parameters(self) -> Any:
70 | raise NotImplementedError()
71 |
72 |
73 | class IdentityRegularizer(AbstractRegularizer):
74 |
75 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
76 | return z, dict()
77 |
78 | def get_trainable_parameters(self) -> Any:
79 | yield from ()
80 |
81 |
82 | def measure_perplexity(
83 | predicted_indices: torch.Tensor,
84 | num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
85 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
86 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
87 | encodings = F.one_hot(predicted_indices,
88 | num_centroids).float().reshape(-1, num_centroids)
89 | avg_probs = encodings.mean(0)
90 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
91 | cluster_use = torch.sum(avg_probs > 0)
92 | return perplexity, cluster_use
93 |
94 |
95 | class DiagonalGaussianRegularizer(AbstractRegularizer):
96 |
97 | def __init__(self, sample: bool = True):
98 | super().__init__()
99 | self.sample = sample
100 |
101 | def get_trainable_parameters(self) -> Any:
102 | yield from ()
103 |
104 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
105 | log = dict()
106 | posterior = DiagonalGaussianDistribution(z)
107 | if self.sample:
108 | z = posterior.sample()
109 | else:
110 | z = posterior.mode()
111 | kl_loss = posterior.kl()
112 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
113 | log['kl_loss'] = kl_loss
114 | return z, log
115 |
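116 | # Example usage (untested sketch): the regularizer expects mean and log-variance
117 | # stacked along the channel axis, i.e. 2 * z_channels channels.
118 | #
119 | #   reg = DiagonalGaussianRegularizer(sample=True)
120 | #   enc_out = torch.randn(1, 2 * 16, 8, 32, 32)   # (B, 2 * z_channels, T, H, W)
121 | #   z, log = reg(enc_out)                         # z: (1, 16, 8, 32, 32)
122 | #   print(log['kl_loss'])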
--------------------------------------------------------------------------------
/inf_270_1080p.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=8 \
2 | --nnodes=1 \
3 | --node_rank=0 \
4 | --master_port=20023 flashvideo/dist_inf_text_file.py \
5 | --base "flashvideo/configs/stage1.yaml" \
6 | --second "flashvideo/configs/stage2.yaml" \
7 | --inf-ckpt ./checkpoints/stage1.pt \
8 | --inf-ckpt2 ./checkpoints/stage2.pt \
9 | --input-file ./example.txt \
10 | --output-dir ./vis_270p_1080p_example
11 |
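12 | # Notes (assumptions, adjust to your setup): the command above expects a single
13 | # node with 8 GPUs; lower --nproc_per_node to match the available devices. The
14 | # stage-1/stage-2 weights and the 3D-VAE (./checkpoints/3d-vae.pt) are expected
15 | # under ./checkpoints/, prompts are read line by line from ./example.txt, and
16 | # the generated videos are written to ./vis_270p_1080p_example.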
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested
2 | diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested
3 | gradio>=4.42.0 # For HF gradio demo
4 | imageio==2.34.2 # For diffusers inference export video
5 | imageio-ffmpeg==0.5.1 # For diffusers inference export video
6 | moviepy==1.0.3 # For export video
7 | numpy==1.26.0
8 | openai>=1.42.0 # For prompt refiner
9 | pillow==9.5.0
10 | sentencepiece>=0.2.0 # Required by the T5 tokenizer
11 | streamlit>=1.38.0 # For streamlit web demo
12 | SwissArmyTransformer>=0.4.12
13 | torch>=2.4.0 # Tested with 2.2, 2.3, 2.4 and 2.5; the development team works with version 2.4.0.
14 | torchvision>=0.19.0 # The development team works with version 0.19.0.
15 | transformers>=4.44.2 # The development team works with version 4.44.2.
16 |
--------------------------------------------------------------------------------