├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── example.txt ├── figs ├── 22_0.jpg ├── logo.png ├── methodv3.png └── pipeline.png ├── flashvideo ├── arguments.py ├── base_model.py ├── base_transformer.py ├── configs │ ├── stage1.yaml │ └── stage2.yaml ├── demo.ipynb ├── diffusion_video.py ├── dist_inf_text_file.py ├── dit_video_concat.py ├── extra_models │ └── dit_res_adapter.py ├── flow_video.py ├── sgm │ ├── __init__.py │ ├── lr_scheduler.py │ ├── models │ │ ├── __init__.py │ │ └── autoencoder.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoding │ │ │ ├── __init__.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_loss.py │ │ │ │ ├── lpips.py │ │ │ │ └── video_loss.py │ │ │ ├── lpips │ │ │ │ ├── __init__.py │ │ │ │ ├── loss │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lpips.py │ │ │ │ ├── model │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── model.py │ │ │ │ ├── util.py │ │ │ │ └── vqperceptual.py │ │ │ ├── magvit2_pytorch.py │ │ │ ├── regularizers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── finite_scalar_quantization.py │ │ │ │ ├── lookup_free_quantization.py │ │ │ │ └── quantize.py │ │ │ ├── temporal_ae.py │ │ │ └── vqvae │ │ │ │ ├── movq_dec_3d.py │ │ │ │ ├── movq_dec_3d_dev.py │ │ │ │ ├── movq_enc_3d.py │ │ │ │ ├── movq_modules.py │ │ │ │ ├── quantize.py │ │ │ │ └── vqvae_blocks.py │ │ ├── cp_enc_dec.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── denoiser.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── discretizer.py │ │ │ ├── guiders.py │ │ │ ├── lora.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── sampling.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── util.py │ │ │ └── wrappers.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ └── video_attention.py │ ├── util.py │ └── webds.py ├── utils.py └── vae_modules │ ├── attention.py │ ├── autoencoder.py │ ├── cp_enc_dec.py │ ├── ema.py │ ├── regularizers.py │ └── utils.py ├── inf_270_1080p.sh └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | # *.png 3 | *.pt 4 | 1080_rope_ana 5 | 1080_rope_ana/ 6 | samples*/ 7 | runs/ 8 | checkpoints/ 9 | master_ip 10 | logs/ 11 | *.DS_Store 12 | .idea 13 | output* 14 | test* 15 | cogvideox-5b-sat 16 | cogvideox-5b-sat/ 17 | *.log 18 | */mini_tools/* 19 | *.mp4 20 | *vis_* 21 | *.pkl 22 | *.mp4 23 | vis_* 24 | checkpoints 25 | checkpoints/ 26 | __pycache__/* 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: ^tests/data/ 2 | repos: 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 5.0.4 5 | hooks: 6 | - id: flake8 7 | - repo: https://github.com/PyCQA/isort 8 | rev: 5.11.5 9 | hooks: 10 | - id: isort 11 | - repo: https://github.com/pre-commit/mirrors-yapf 12 | rev: v0.32.0 13 | hooks: 14 | - id: yapf 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v4.3.0 17 | hooks: 18 | - id: trailing-whitespace 19 | - id: check-yaml 20 | - id: end-of-file-fixer 21 | - id: requirements-txt-fixer 22 | - id: double-quote-string-fixer 23 | - id: check-merge-conflict 24 | - id: fix-encoding-pragma 25 | args: ["--remove"] 26 
| - id: mixed-line-ending 27 | args: ["--fix=lf"] 28 | - repo: https://github.com/codespell-project/codespell 29 | rev: v2.2.1 30 | hooks: 31 | - id: codespell 32 | - repo: https://github.com/executablebooks/mdformat 33 | rev: 0.7.9 34 | hooks: 35 | - id: mdformat 36 | args: ["--number"] 37 | additional_dependencies: 38 | - mdformat-openmmlab 39 | - mdformat_frontmatter 40 | - linkify-it-py 41 | - repo: https://github.com/asottile/pyupgrade 42 | rev: v3.0.0 43 | hooks: 44 | - id: pyupgrade 45 | args: ["--py36-plus"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 CogVideo Model Team @ Zhipu AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The CogVideoX License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means the CogVideoX Model Team that distributes its Software. 6 | 7 | “Software” means the CogVideoX model parameters made available under this license. 8 | 9 | 2. 
License Grant 10 | 11 | Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws. 12 | This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form . 13 | Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month. 14 | If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses. 15 | The above copyright statement and this license statement should be included in all copies or significant portions of this software. 16 | 17 | 3. Restriction 18 | 19 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes. 20 | 21 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 22 | 23 | 4. Disclaimer 24 | 25 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 26 | 27 | 5. Limitation of Liability 28 | 29 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 30 | 31 | 6. Dispute Resolution 32 | 33 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 34 | 35 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn. 36 | 37 | 1. 定义 38 | 39 | “许可方”是指分发其软件的 CogVideoX 模型团队。 40 | 41 | “软件”是指根据本许可提供的 CogVideoX 模型参数。 42 | 43 | 2. 
许可授予 44 | 45 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。 46 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。 47 | 48 | 经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 49 | 在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。 50 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 51 | 52 | 3.限制 53 | 54 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 55 | 56 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 57 | 58 | 4.免责声明 59 | 60 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 61 | 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 62 | 63 | 5. 责任限制 64 | 65 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 66 | 67 | 6.争议解决 68 | 69 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 70 | 71 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
5 | 6 |
7 | 8 | # Flowing Fidelity to Detail for Efficient High-Resolution Video Generation 9 | 10 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2502.05179-b31b1b.svg)](https://arxiv.org/abs/2502.05179) 11 | [![project page](https://img.shields.io/badge/Project_page-More_visualizations-green)](https://jshilong.github.io/flashvideo-page/)  12 | 13 |

20 | 21 | > [**FlashVideo: Flowing Fidelity to Detail for Efficient High-Resolution Video Generation**](https://arxiv.org/abs/2502.05179)
22 | > [Shilong Zhang](https://jshilong.github.io/), [Wenbo Li](https://scholar.google.com/citations?user=foGn_TIAAAAJ&hl=en), [Shoufa Chen](https://www.shoufachen.com/), [Chongjian Ge](https://chongjiange.github.io/), [Peize Sun](https://peizesun.github.io/),
[Yida Zhang](<>), [Yi Jiang](https://enjoyyi.github.io/), [Zehuan Yuan](https://shallowyuan.github.io/), [Bingyue Peng](<>), [Ping Luo](http://luoping.me/), 23 | >
HKU, CUHK, ByteDance
24 | 25 | ## 🤗 More video examples 👀 can be accessed at the [![project page](https://img.shields.io/badge/Project_page-More_visualizations-green)](https://jshilong.github.io/flashvideo-page/) 26 | 27 | 31 | 32 | #### ⚡⚡ User Prompt to 270p, NFE = 50, Takes ~30s ⚡⚡ 33 | #### ⚡⚡ 270p to 1080p , NFE = 4, Takes ~72s ⚡⚡ 34 | 35 | [![]()](https://github.com/FoundationVision/flashvideo-page/blob/main/static/images/output.gif) 36 | 37 | 38 | 42 | 43 | 44 |

48 | 49 | 53 | 54 | 55 | ## 🔥 Update 56 | 57 | - \[2025.02.10\] 🔥 🔥 🔥 Inference code and model [weights](https://huggingface.co/FoundationVision/FlashVideo/tree/main) for both stages have been released. 58 | 59 | ## 🌿 Introduction 60 | In this repository, we provide: 61 | 62 | - [x] The stage-I weight for 270P video generation. 63 | - [x] The stage-II weight for enhancing 270P videos to 1080P. 64 | - [x] Inference code for both stages. 65 | - [ ] Training code and related augmentation. Work in progress: [PR#12](https://github.com/FoundationVision/FlashVideo/pull/12) 66 | - [x] Loss function 67 | - [ ] Dataset and augmentation 68 | - [ ] Configuration and training script 69 | - [ ] Implementation with diffusers. 70 | - [ ] Gradio. 71 | 72 | 73 | ## Install 74 | 75 | ### 1. Environment Setup 76 | 77 | This repository is tested with PyTorch 2.4.0+cu121 and Python 3.11.11. You can install the necessary dependencies with the following command: 78 | 79 | ```shell 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | ### 2. Preparing the Checkpoints 84 | 85 | Download the 3D VAE (identical to the CogVideoX VAE) along with the Stage-I and Stage-II weights, and set them up as follows: 86 | 87 | ```shell 88 | cd FlashVideo 89 | mkdir -p ./checkpoints 90 | huggingface-cli download --local-dir ./checkpoints FoundationVision/FlashVideo 91 | ``` 92 | 93 | The checkpoints should be organized as shown below: 94 | 95 | ``` 96 | ├── 3d-vae.pt 97 | ├── stage1.pt 98 | └── stage2.pt 99 | ``` 100 | 101 | ## 🚀 Text to Video Generation 102 | 103 | #### ⚠️ IMPORTANT NOTICE ⚠️ : Both stage-I and stage-II are trained with long prompts only. To achieve the best results, include comprehensive and detailed descriptions in your prompts, akin to the example provided in [example.txt](./example.txt). 104 | 105 | ### Jupyter Notebook 106 | 107 | You can conveniently provide user prompts in our Jupyter notebook. The default configuration for spatial and temporal slices in the VAE Decoder is tailored for an 80G GPU. For GPUs with less memory, consider increasing the number of [spatial and temporal slices](https://github.com/FoundationVision/FlashVideo/blob/400a9c1ef905eab3a1cb6b9f5a5a4c331378e4b5/sat/utils.py#L110). 108 | 109 | 110 | ```python 111 | flashvideo/demo.ipynb 112 | ``` 113 | 114 | ### Inferring from a Text File Containing Prompts 115 | 116 | You can conveniently provide user prompts in a text file and generate videos with multiple GPUs. 117 | 118 | ```shell 119 | bash inf_270_1080p.sh 120 | ``` 121 | 122 | ## License 123 | 124 | This project is developed based on [CogVideoX](https://github.com/THUDM/CogVideo). Please refer to their original [license](https://github.com/THUDM/CogVideo?tab=readme-ov-file#model-license) for usage details.
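## Two-Stage Sampling at a Glance

The sketch below summarizes how the two stages hand off to each other, following `FlowEngine.sample` in `flashvideo/flow_video.py` and the released configs. It is a minimal illustration, not the actual API: the stage-I call and the conditioning dictionaries are placeholders, model construction is omitted, and any resizing of the stage-I latent to the target resolution is left out.

```python
def two_stage_generation(stage1_engine, flow_engine, cond, uc):
    # Stage I: ~50 denoising steps (NFE = 50) produce a low-resolution (270p) latent.
    low_res_latent = stage1_engine.sample(cond, uc)  # placeholder call, not the real interface

    # Stage II: FlowEngine.sample re-noises the stage-I latent at
    # sample_ref_noise_step (675 in configs/stage2.yaml), caches it as `ref_x`,
    # and integrates an ODE with torchdiffeq for ~4 steps (NFE = 4).
    high_res_latent = flow_engine.sample(
        low_res_latent, cond, uc,
        num_steps=4, method='euler',
    )

    # Both stages decode with the shared CogVideoX 3D VAE (checkpoints/3d-vae.pt).
    return flow_engine.decode_first_stage(high_res_latent)
```

In the released pipeline these pieces are driven by `flashvideo/demo.ipynb` and `inf_270_1080p.sh`; the sketch only names the objects involved.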
125 | 126 | ## BibTeX 127 | 128 | ```bibtex 129 | @article{zhang2025flashvideo, 130 | title={FlashVideo: Flowing Fidelity to Detail for Efficient High-Resolution Video Generation}, 131 | author={Zhang, Shilong and Li, Wenbo and Chen, Shoufa and Ge, Chongjian and Sun, Peize and Zhang, Yida and Jiang, Yi and Yuan, Zehuan and Peng, Binyue and Luo, Ping}, 132 | journal={arXiv preprint arXiv:2502.05179}, 133 | year={2025} 134 | } 135 | ``` 136 | -------------------------------------------------------------------------------- /figs/22_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/22_0.jpg -------------------------------------------------------------------------------- /figs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/logo.png -------------------------------------------------------------------------------- /figs/methodv3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/methodv3.png -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/figs/pipeline.png -------------------------------------------------------------------------------- /flashvideo/configs/stage1.yaml: -------------------------------------------------------------------------------- 1 | share_cache_args: 2 | disable_ref : True 3 | num_vis_img: 4 4 | vis_ddpm: True 5 | eval_interval_list: [1, 50, 100, 1000] 6 | save_interval_list: [2000] 7 | 8 | args: 9 | checkpoint_activations: False # using gradient checkpointing 10 | model_parallel_size: 1 11 | experiment_name: lora-disney 12 | mode: finetune 13 | load: "" 14 | no_load_rng: True 15 | train_iters: 10000000000 # Suggest more than 1000 For Lora and SFT For 500 is enough 16 | eval_iters: 100000000 17 | eval_interval: 1000000000000 18 | eval_batch_size: 1 19 | save: ./ 20 | save_interval: 1000 21 | log_interval: 20 22 | train_data: [ "disney" ] # Train data path 23 | valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended) 24 | split: 1,0,0 25 | num_workers: 2 26 | force_train: True 27 | only_log_video_latents: True 28 | 29 | 30 | deepspeed: 31 | # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs 32 | train_micro_batch_size_per_gpu: 1 33 | gradient_accumulation_steps: 1 34 | steps_per_print: 50 35 | gradient_clipping: 0.1 36 | zero_optimization: 37 | stage: 2 38 | cpu_offload: false 39 | contiguous_gradients: false 40 | overlap_comm: true 41 | reduce_scatter: true 42 | reduce_bucket_size: 1000000000 43 | allgather_bucket_size: 1000000000 44 | load_from_fp32_weights: false 45 | zero_allow_untested_optimizer: true 46 | bf16: 47 | enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True 48 | fp16: 49 | enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False 50 | loss_scale: 0 51 | loss_scale_window: 400 52 | hysteresis: 2 53 | min_loss_scale: 1 54 | 55 | 56 | 57 | model: 58 | scale_factor: 0.7 59 | 
disable_first_stage_autocast: true 60 | log_keys: 61 | - txt 62 | 63 | denoiser_config: 64 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 65 | params: 66 | num_idx: 1000 67 | quantize_c_noise: False 68 | 69 | weighting_config: 70 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 71 | scaling_config: 72 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 73 | discretization_config: 74 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 75 | params: 76 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 77 | 78 | network_config: 79 | target: extra_models.dit_res_adapter.SMALLDiffusionTransformer 80 | params: 81 | time_embed_dim: 512 82 | elementwise_affine: True 83 | num_frames: 64 84 | time_compressed_rate: 4 85 | # latent_width: 90 86 | # latent_height: 60 87 | latent_width: 512 88 | latent_height: 512 89 | num_layers: 42 # different from cogvideox_2b_infer.yaml 90 | patch_size: 2 91 | in_channels: 16 92 | out_channels: 16 93 | hidden_size: 3072 # different from cogvideox_2b_infer.yaml 94 | adm_in_channels: 256 95 | num_attention_heads: 48 # different from cogvideox_2b_infer.yaml 96 | 97 | transformer_args: 98 | checkpoint_activations: False 99 | vocab_size: 1 100 | max_sequence_length: 64 101 | layernorm_order: pre 102 | skip_init: false 103 | model_parallel_size: 1 104 | is_decoder: false 105 | 106 | modules: 107 | pos_embed_config: 108 | target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml 109 | params: 110 | hidden_size_head: 64 111 | text_length: 226 112 | 113 | lora_config: 114 | target: extra_models.dit_res_adapter.ResLoraMixin 115 | params: 116 | r: 128 117 | 118 | patch_embed_config: 119 | target: dit_video_concat.ImagePatchEmbeddingMixin 120 | params: 121 | text_hidden_size: 4096 122 | 123 | adaln_layer_config: 124 | target: dit_video_concat.AdaLNMixin 125 | params: 126 | qk_ln: True 127 | 128 | final_layer_config: 129 | target: dit_video_concat.FinalLayerMixin 130 | 131 | conditioner_config: 132 | target: sgm.modules.GeneralConditioner 133 | params: 134 | emb_models: 135 | - is_trainable: false 136 | input_key: txt 137 | ucg_rate: 0.1 138 | target: sgm.modules.encoders.modules.FrozenT5Embedder 139 | params: 140 | model_dir: "google/t5-v1_1-xxl" 141 | max_length: 226 142 | 143 | first_stage_config: 144 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 145 | params: 146 | cp_size: 1 147 | ckpt_path: "checkpoints/3d-vae.pt" 148 | ignore_keys: [ 'loss' ] 149 | 150 | loss_config: 151 | target: torch.nn.Identity 152 | 153 | regularizer_config: 154 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 155 | 156 | encoder_config: 157 | target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D 158 | params: 159 | double_z: true 160 | z_channels: 16 161 | resolution: 256 162 | in_channels: 3 163 | out_ch: 3 164 | ch: 128 165 | ch_mult: [ 1, 2, 2, 4 ] 166 | attn_resolutions: [ ] 167 | num_res_blocks: 3 168 | dropout: 0.0 169 | gather_norm: True 170 | 171 | decoder_config: 172 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 173 | params: 174 | double_z: True 175 | z_channels: 16 176 | resolution: 256 177 | in_channels: 3 178 | out_ch: 3 179 | ch: 128 180 | ch_mult: [ 1, 2, 2, 4 ] 181 | attn_resolutions: [ ] 182 | num_res_blocks: 3 183 | dropout: 0.0 184 | gather_norm: False 185 | 186 | loss_fn_config: 187 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 188 | params: 189 | 
offset_noise_level: 0 190 | sigma_sampler_config: 191 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 192 | params: 193 | uniform_sampling: True 194 | num_idx: 1000 195 | discretization_config: 196 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 197 | params: 198 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 199 | 200 | sampler_config: 201 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 202 | params: 203 | num_steps: 51 204 | verbose: True 205 | 206 | discretization_config: 207 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 208 | params: 209 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 210 | 211 | guider_config: 212 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG 213 | params: 214 | # TODO check this cfg 215 | scale: 8 216 | exp: 5 217 | num_steps: 51 218 | -------------------------------------------------------------------------------- /flashvideo/configs/stage2.yaml: -------------------------------------------------------------------------------- 1 | custom_args: 2 | reload: "" 3 | 4 | share_cache_args: 5 | sample_ref_noise_step: 675 6 | time_size_embedding: True 7 | 8 | args: 9 | checkpoint_activations: True # using gradient checkpointing 10 | model_parallel_size: 1 11 | experiment_name: lora-disney 12 | mode: finetune 13 | load: "" # This is for Full model without lora adapter " # This is for Full model without lora adapter 14 | no_load_rng: True 15 | train_iters: 100000 # Suggest more than 1000 For Lora and SFT For 500 is enough 16 | eval_iters: 100000000 17 | eval_interval: [1, 200] 18 | eval_batch_size: 1 19 | save: 20 | # for debug 21 | save_interval: 250 22 | log_interval: 5 23 | train_data: [ "disney" ] # Train data path 24 | valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended) 25 | split: 1,0,0 26 | num_workers: 1 27 | force_train: True 28 | only_log_video_latents: True 29 | 30 | 31 | deepspeed: 32 | # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs 33 | train_micro_batch_size_per_gpu: 1 34 | gradient_accumulation_steps: 1 35 | steps_per_print: 50 36 | gradient_clipping: 0.1 37 | zero_optimization: 38 | stage: 2 39 | cpu_offload: false 40 | contiguous_gradients: false 41 | overlap_comm: true 42 | reduce_scatter: true 43 | reduce_bucket_size: 1000000000 44 | allgather_bucket_size: 1000000000 45 | load_from_fp32_weights: false 46 | zero_allow_untested_optimizer: true 47 | bf16: 48 | enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True 49 | fp16: 50 | enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False 51 | loss_scale: 0 52 | loss_scale_window: 400 53 | hysteresis: 2 54 | min_loss_scale: 1 55 | 56 | 57 | model: 58 | scale_factor: 1.15258426 59 | disable_first_stage_autocast: true 60 | log_keys: 61 | - txt 62 | 63 | denoiser_config: 64 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 65 | params: 66 | num_idx: 1000 67 | quantize_c_noise: False 68 | 69 | weighting_config: 70 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 71 | scaling_config: 72 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 73 | discretization_config: 74 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 75 | params: 76 | shift_scale: 3.0 # different from cogvideox_2b_infer.yaml 77 | 78 | network_config: 79 | target: 
extra_models.dit_res_adapter.SMALLDiffusionTransformer 80 | params: 81 | time_embed_dim: 512 82 | elementwise_affine: True 83 | num_frames: 64 84 | time_compressed_rate: 4 85 | # latent_width: 90 86 | # latent_height: 60 87 | latent_width: 512 88 | latent_height: 512 89 | num_layers: 30 # different from cogvideox_2b_infer.yaml 90 | patch_size: 2 91 | in_channels: 16 92 | out_channels: 16 93 | hidden_size: 1920 # different from cogvideox_2b_infer.yaml 94 | adm_in_channels: 256 95 | num_attention_heads: 30 # different from cogvideox_2b_infer.yaml 96 | 97 | transformer_args: 98 | checkpoint_activations: True 99 | vocab_size: 1 100 | max_sequence_length: 64 101 | layernorm_order: pre 102 | skip_init: false 103 | model_parallel_size: 1 104 | is_decoder: false 105 | 106 | modules: 107 | pos_embed_config: 108 | target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml 109 | params: 110 | hidden_size_head: 64 111 | text_length: 226 112 | 113 | patch_embed_config: 114 | target: dit_video_concat.ImagePatchEmbeddingMixin 115 | params: 116 | text_hidden_size: 4096 117 | 118 | adaln_layer_config: 119 | target: dit_video_concat.AdaLNMixin 120 | params: 121 | qk_ln: True 122 | 123 | final_layer_config: 124 | target: dit_video_concat.FinalLayerMixin 125 | 126 | conditioner_config: 127 | target: sgm.modules.GeneralConditioner 128 | params: 129 | emb_models: 130 | - is_trainable: false 131 | input_key: txt 132 | ucg_rate: 0.1 133 | target: sgm.modules.encoders.modules.FrozenT5Embedder 134 | params: 135 | model_dir: "google/t5-v1_1-xxl" 136 | max_length: 226 137 | 138 | first_stage_config: 139 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 140 | params: 141 | cp_size: 1 142 | ckpt_path: "checkpoints/3d-vae.pt" 143 | ignore_keys: [ 'loss' ] 144 | 145 | loss_config: 146 | target: torch.nn.Identity 147 | 148 | regularizer_config: 149 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 150 | 151 | encoder_config: 152 | target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D 153 | params: 154 | double_z: true 155 | z_channels: 16 156 | resolution: 256 157 | in_channels: 3 158 | out_ch: 3 159 | ch: 128 160 | ch_mult: [ 1, 2, 2, 4 ] 161 | attn_resolutions: [ ] 162 | num_res_blocks: 3 163 | dropout: 0.0 164 | gather_norm: True 165 | 166 | decoder_config: 167 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 168 | params: 169 | double_z: True 170 | z_channels: 16 171 | resolution: 256 172 | in_channels: 3 173 | out_ch: 3 174 | ch: 128 175 | ch_mult: [ 1, 2, 2, 4 ] 176 | attn_resolutions: [ ] 177 | num_res_blocks: 3 178 | dropout: 0.0 179 | gather_norm: False 180 | 181 | loss_fn_config: 182 | target: flow_video.FlowVideoDiffusionLoss 183 | params: 184 | offset_noise_level: 0 185 | sigma_sampler_config: 186 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 187 | params: 188 | uniform_sampling: False 189 | num_idx: 1000 190 | discretization_config: 191 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 192 | params: 193 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 194 | 195 | sampler_config: 196 | target: sgm.modules.diffusionmodules.sampling.CascadeVPSDEDPMPP2MSampler 197 | params: 198 | num_steps: 50 199 | verbose: True 200 | 201 | discretization_config: 202 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 203 | params: 204 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 205 | 206 | guider_config: 207 | target: 
sgm.modules.diffusionmodules.guiders.DynamicCFG 208 | params: 209 | scale: 6 210 | exp: 5 211 | num_steps: 50 212 | -------------------------------------------------------------------------------- /flashvideo/flow_video.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | from sgm.modules import UNCONDITIONAL_CONFIG 8 | from sgm.modules.autoencoding.temporal_ae import VideoDecoder 9 | from sgm.modules.diffusionmodules.loss import StandardDiffusionLoss 10 | from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER 11 | from sgm.util import (append_dims, default, disabled_train, get_obj_from_str, 12 | instantiate_from_config) 13 | from torch import nn 14 | from torchdiffeq import odeint 15 | 16 | 17 | class FlowEngine(nn.Module): 18 | 19 | def __init__(self, args, **kwargs): 20 | super().__init__() 21 | model_config = args.model_config 22 | log_keys = model_config.get('log_keys', None) 23 | input_key = model_config.get('input_key', 'mp4') 24 | network_config = model_config.get('network_config', None) 25 | network_wrapper = model_config.get('network_wrapper', None) 26 | denoiser_config = model_config.get('denoiser_config', None) 27 | sampler_config = model_config.get('sampler_config', None) 28 | conditioner_config = model_config.get('conditioner_config', None) 29 | first_stage_config = model_config.get('first_stage_config', None) 30 | loss_fn_config = model_config.get('loss_fn_config', None) 31 | scale_factor = model_config.get('scale_factor', 1.0) 32 | latent_input = model_config.get('latent_input', False) 33 | disable_first_stage_autocast = model_config.get( 34 | 'disable_first_stage_autocast', False) 35 | no_cond_log = model_config.get('disable_first_stage_autocast', False) 36 | not_trainable_prefixes = model_config.get( 37 | 'not_trainable_prefixes', ['first_stage_model', 'conditioner']) 38 | compile_model = model_config.get('compile_model', False) 39 | en_and_decode_n_samples_a_time = model_config.get( 40 | 'en_and_decode_n_samples_a_time', None) 41 | lr_scale = model_config.get('lr_scale', None) 42 | lora_train = model_config.get('lora_train', False) 43 | self.use_pd = model_config.get('use_pd', False) 44 | 45 | self.log_keys = log_keys 46 | self.input_key = input_key 47 | self.not_trainable_prefixes = not_trainable_prefixes 48 | self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time 49 | self.lr_scale = lr_scale 50 | self.lora_train = lora_train 51 | self.noised_image_input = model_config.get('noised_image_input', False) 52 | self.noised_image_all_concat = model_config.get( 53 | 'noised_image_all_concat', False) 54 | self.noised_image_dropout = model_config.get('noised_image_dropout', 55 | 0.0) 56 | if args.fp16: 57 | dtype = torch.float16 58 | dtype_str = 'fp16' 59 | elif args.bf16: 60 | dtype = torch.bfloat16 61 | dtype_str = 'bf16' 62 | else: 63 | dtype = torch.float32 64 | dtype_str = 'fp32' 65 | self.dtype = dtype 66 | self.dtype_str = dtype_str 67 | 68 | network_config['params']['dtype'] = dtype_str 69 | model = instantiate_from_config(network_config) 70 | self.model = get_obj_from_str( 71 | default(network_wrapper, 72 | OPENAIUNETWRAPPER))(model, 73 | compile_model=compile_model, 74 | dtype=dtype) 75 | 76 | self.denoiser = instantiate_from_config(denoiser_config) 77 | self.sampler = instantiate_from_config( 78 | sampler_config) if sampler_config is not None else None 79 | self.conditioner = 
instantiate_from_config( 80 | default(conditioner_config, UNCONDITIONAL_CONFIG)) 81 | 82 | self._init_first_stage(first_stage_config) 83 | 84 | self.loss_fn = instantiate_from_config( 85 | loss_fn_config) if loss_fn_config is not None else None 86 | 87 | self.latent_input = latent_input 88 | self.scale_factor = scale_factor 89 | self.disable_first_stage_autocast = disable_first_stage_autocast 90 | self.no_cond_log = no_cond_log 91 | self.device = args.device 92 | 93 | def disable_untrainable_params(self): 94 | pass 95 | 96 | def reinit(self, parent_model=None): 97 | pass 98 | 99 | def _init_first_stage(self, config): 100 | model = instantiate_from_config(config).eval() 101 | model.train = disabled_train 102 | for param in model.parameters(): 103 | param.requires_grad = False 104 | self.first_stage_model = model 105 | 106 | def get_input(self, batch): 107 | return batch[self.input_key].to(self.dtype) 108 | 109 | @torch.no_grad() 110 | def decode_first_stage(self, z): 111 | z = 1.0 / self.scale_factor * z 112 | n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) 113 | n_rounds = math.ceil(z.shape[0] / n_samples) 114 | all_out = [] 115 | with torch.autocast('cuda', 116 | enabled=not self.disable_first_stage_autocast): 117 | for n in range(n_rounds): 118 | if isinstance(self.first_stage_model.decoder, VideoDecoder): 119 | kwargs = { 120 | 'timesteps': len(z[n * n_samples:(n + 1) * n_samples]) 121 | } 122 | else: 123 | kwargs = {} 124 | out = self.first_stage_model.decode( 125 | z[n * n_samples:(n + 1) * n_samples], **kwargs) 126 | all_out.append(out) 127 | out = torch.cat(all_out, dim=0) 128 | return out 129 | 130 | @torch.no_grad() 131 | def encode_first_stage(self, x, batch): 132 | frame = x.shape[2] 133 | 134 | if frame > 1 and self.latent_input: 135 | x = x.permute(0, 2, 1, 3, 4).contiguous() 136 | return x * self.scale_factor # already encoded 137 | 138 | n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) 139 | n_rounds = math.ceil(x.shape[0] / n_samples) 140 | all_out = [] 141 | with torch.autocast('cuda', 142 | enabled=not self.disable_first_stage_autocast): 143 | for n in range(n_rounds): 144 | out = self.first_stage_model.encode(x[n * n_samples:(n + 1) * 145 | n_samples]) 146 | all_out.append(out) 147 | z = torch.cat(all_out, dim=0) 148 | z = self.scale_factor * z 149 | return z 150 | 151 | @torch.no_grad() 152 | def save_memory_encode_first_stage(self, x, batch): 153 | splits_x = torch.split(x, [13, 12, 12, 12], dim=2) 154 | 155 | all_out = [] 156 | 157 | with torch.autocast('cuda', enabled=False): 158 | for idx, input_x in enumerate(splits_x): 159 | if idx == len(splits_x) - 1: 160 | clear_fake_cp_cache = True 161 | else: 162 | clear_fake_cp_cache = False 163 | out = self.first_stage_model.encode( 164 | input_x.contiguous(), 165 | clear_fake_cp_cache=clear_fake_cp_cache) 166 | all_out.append(out) 167 | 168 | z = torch.cat(all_out, dim=2) 169 | z = self.scale_factor * z 170 | return z 171 | 172 | def single_function_evaluation(self, 173 | t, 174 | x, 175 | cond=None, 176 | uc=None, 177 | cfg=1, 178 | **kwargs): 179 | start_time = time.time() 180 | # for CFG 181 | x = torch.cat([x] * 2) 182 | t = t.reshape(1).to(x.dtype).to(x.device) 183 | t = torch.cat([t] * 2) 184 | idx = 1000 - (t * 1000) 185 | 186 | real_cond = dict() 187 | for k, v in cond.items(): 188 | uncond_v = uc[k] 189 | real_cond[k] = torch.cat([v, uncond_v]) 190 | 191 | vt = self.model(x, t=idx, c=real_cond, idx=idx) 192 | vt, uc_vt = vt.chunk(2) 193 | vt = uc_vt + cfg * (vt - uc_vt) 194 
| end_time = time.time() 195 | print(f'single_function_evaluation time at {t}', end_time - start_time) 196 | return vt 197 | 198 | @torch.no_grad() 199 | def sample( 200 | self, 201 | ref_x, 202 | cond, 203 | uc, 204 | **sample_kwargs, 205 | ): 206 | """Stage 2 Sampling, start from the first stage results `ref_x` 207 | 208 | Args: 209 | ref_x (_type_): Stage1 low resolution video 210 | cond (dict): Dict contains condtion embeddings 211 | uc (dict): Dict contains uncondition embedding 212 | 213 | Returns: 214 | Tensor: Secondary stage results 215 | """ 216 | 217 | sample_kwargs = sample_kwargs or {} 218 | print('sample_kwargs', sample_kwargs) 219 | # timesteps 220 | num_steps = sample_kwargs.get('num_steps', 4) 221 | t = torch.linspace(0, 1, num_steps + 1, 222 | dtype=ref_x.dtype).to(ref_x.device) 223 | print(self.share_cache['shift_t']) 224 | shift_t = float(self.share_cache['shift_t']) 225 | t = 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t)) 226 | 227 | print('sample:', t) 228 | t = t 229 | single_function_evaluation = partial(self.single_function_evaluation, 230 | cond=cond, 231 | uc=uc, 232 | cfg=sample_kwargs.get('cfg', 1)) 233 | 234 | ref_noise_step = self.share_cache['sample_ref_noise_step'] 235 | print(f'ref_noise_step : {ref_noise_step}') 236 | 237 | ref_alphas_cumprod_sqrt = self.loss_fn.sigma_sampler.idx_to_sigma( 238 | torch.zeros(ref_x.shape[0]).fill_(ref_noise_step).long()) 239 | ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(ref_x.device) 240 | ori_dtype = ref_x.dtype 241 | 242 | ref_noise = torch.randn_like(ref_x) 243 | print('weight', ref_alphas_cumprod_sqrt, flush=True) 244 | 245 | ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \ 246 | + ref_noise * append_dims( 247 | (1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim 248 | ) 249 | ref_x = ref_noised_input.to(ori_dtype) 250 | self.share_cache['ref_x'] = ref_x 251 | 252 | results = odeint(single_function_evaluation, 253 | ref_x, 254 | t, 255 | method=sample_kwargs.get('method', 'euler'), 256 | atol=1e-6, 257 | rtol=1e-3)[-1] 258 | 259 | return results 260 | 261 | 262 | class FlowVideoDiffusionLoss(StandardDiffusionLoss): 263 | 264 | def __init__(self, 265 | block_scale=None, 266 | block_size=None, 267 | min_snr_value=None, 268 | fixed_frames=0, 269 | **kwargs): 270 | self.fixed_frames = fixed_frames 271 | self.block_scale = block_scale 272 | self.block_size = block_size 273 | self.min_snr_value = min_snr_value 274 | self.schedule = None 275 | super().__init__(**kwargs) 276 | 277 | def __call__(self, network, denoiser, conditioner, input, batch): 278 | pass 279 | -------------------------------------------------------------------------------- /flashvideo/sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = '0.1.0' 5 | -------------------------------------------------------------------------------- /flashvideo/sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | 
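        # max_decay_steps is the step at which the cosine branch bottoms out at
        # lr_min. With a base lr of 1.0, schedule(n) returns
        #   lr_start + (lr_max - lr_start) * n / warm_up_steps       for n < warm_up_steps,
        #   lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t))     afterwards,
        # where t = (n - warm_up_steps) / (max_decay_steps - warm_up_steps), clamped to 1.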
self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print( 30 | f'current step: {n}, recent lr-multiplier: {self.last_lr}') 31 | if n < self.lr_warm_up_steps: 32 | lr = (self.lr_max - 33 | self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - 38 | self.lr_warm_up_steps) 39 | t = min(t, 1.0) 40 | lr = self.lr_min + 0.5 * (self.lr_max - 41 | self.lr_min) * (1 + np.cos(t * np.pi)) 42 | self.last_lr = lr 43 | return lr 44 | 45 | def __call__(self, n, **kwargs): 46 | return self.schedule(n, **kwargs) 47 | 48 | 49 | class LambdaWarmUpCosineScheduler2: 50 | """ 51 | supports repeated iterations, configurable via lists 52 | note: use with a base_lr of 1.0. 53 | """ 54 | 55 | def __init__(self, 56 | warm_up_steps, 57 | f_min, 58 | f_max, 59 | f_start, 60 | cycle_lengths, 61 | verbosity_interval=0): 62 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len( 63 | f_start) == len(cycle_lengths) 64 | self.lr_warm_up_steps = warm_up_steps 65 | self.f_start = f_start 66 | self.f_min = f_min 67 | self.f_max = f_max 68 | self.cycle_lengths = cycle_lengths 69 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 70 | self.last_f = 0.0 71 | self.verbosity_interval = verbosity_interval 72 | 73 | def find_in_interval(self, n): 74 | interval = 0 75 | for cl in self.cum_cycles[1:]: 76 | if n <= cl: 77 | return interval 78 | interval += 1 79 | 80 | def schedule(self, n, **kwargs): 81 | cycle = self.find_in_interval(n) 82 | n = n - self.cum_cycles[cycle] 83 | if self.verbosity_interval > 0: 84 | if n % self.verbosity_interval == 0: 85 | print( 86 | f'current step: {n}, recent lr-multiplier: {self.last_f}, ' 87 | f'current cycle {cycle}') 88 | if n < self.lr_warm_up_steps[cycle]: 89 | f = (self.f_max[cycle] - self.f_start[cycle] 90 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 91 | self.last_f = f 92 | return f 93 | else: 94 | t = (n - self.lr_warm_up_steps[cycle]) / ( 95 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 96 | t = min(t, 1.0) 97 | f = self.f_min[cycle] + 0.5 * ( 98 | self.f_max[cycle] - self.f_min[cycle]) * (1 + 99 | np.cos(t * np.pi)) 100 | self.last_f = f 101 | return f 102 | 103 | def __call__(self, n, **kwargs): 104 | return self.schedule(n, **kwargs) 105 | 106 | 107 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 108 | 109 | def schedule(self, n, **kwargs): 110 | cycle = self.find_in_interval(n) 111 | n = n - self.cum_cycles[cycle] 112 | if self.verbosity_interval > 0: 113 | if n % self.verbosity_interval == 0: 114 | print( 115 | f'current step: {n}, recent lr-multiplier: {self.last_f}, ' 116 | f'current cycle {cycle}') 117 | 118 | if n < self.lr_warm_up_steps[cycle]: 119 | f = (self.f_max[cycle] - self.f_start[cycle] 120 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 121 | self.last_f = f 122 | return f 123 | else: 124 | f = (self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * 125 | (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])) 126 | self.last_f = f 127 | return f 128 | -------------------------------------------------------------------------------- /flashvideo/sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import 
AutoencodingEngine 2 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | 'target': 'sgm.modules.GeneralConditioner', 5 | 'params': { 6 | 'emb_models': [] 7 | }, 8 | } 9 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'GeneralLPIPSWithDiscriminator', 3 | 'LatentLPIPS', 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | from .video_loss import VideoAutoencoderLoss 9 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | 10 | def __init__( 11 | self, 12 | decoder_config, 13 | perceptual_weight=1.0, 14 | latent_weight=1.0, 15 | scale_input_to_tgt_size=False, 16 | scale_tgt_to_input_size=False, 17 | perceptual_weight_on_inputs=0.0, 18 | ): 19 | super().__init__() 20 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 21 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 22 | self.init_decoder(decoder_config) 23 | self.perceptual_loss = LPIPS().eval() 24 | self.perceptual_weight = perceptual_weight 25 | self.latent_weight = latent_weight 26 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 27 | 28 | def init_decoder(self, config): 29 | self.decoder = instantiate_from_config(config) 30 | if hasattr(self.decoder, 'encoder'): 31 | del self.decoder.encoder 32 | 33 | def forward(self, 34 | latent_inputs, 35 | latent_predictions, 36 | image_inputs, 37 | split='train'): 38 | log = dict() 39 | loss = (latent_inputs - latent_predictions)**2 40 | log[f'{split}/latent_l2_loss'] = loss.mean().detach() 41 | image_reconstructions = None 42 | if self.perceptual_weight > 0.0: 43 | image_reconstructions = self.decoder.decode(latent_predictions) 44 | image_targets = self.decoder.decode(latent_inputs) 45 | perceptual_loss = self.perceptual_loss( 46 | image_targets.contiguous(), image_reconstructions.contiguous()) 47 | loss = self.latent_weight * loss.mean( 48 | ) + self.perceptual_weight * perceptual_loss.mean() 49 | log[f'{split}/perceptual_loss'] = perceptual_loss.mean().detach() 50 | 51 | if self.perceptual_weight_on_inputs > 0.0: 52 | image_reconstructions = default( 53 | image_reconstructions, self.decoder.decode(latent_predictions)) 54 | if self.scale_input_to_tgt_size: 55 | image_inputs = torch.nn.functional.interpolate( 56 | image_inputs, 57 | image_reconstructions.shape[2:], 58 | mode='bicubic', 59 | antialias=True, 60 | ) 61 | elif self.scale_tgt_to_input_size: 62 | image_reconstructions 
= torch.nn.functional.interpolate( 63 | image_reconstructions, 64 | image_inputs.shape[2:], 65 | mode='bicubic', 66 | antialias=True, 67 | ) 68 | 69 | perceptual_loss2 = self.perceptual_loss( 70 | image_inputs.contiguous(), image_reconstructions.contiguous()) 71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean( 72 | ) 73 | log[f'{split}/perceptual_loss_on_inputs'] = perceptual_loss2.mean( 74 | ).detach() 75 | return loss, log 76 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth 2 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name='vgg_lpips'): 29 | ckpt = get_ckpt_path(name, 'sgm/modules/autoencoding/lpips/loss') 30 | self.load_state_dict(torch.load(ckpt, 31 | map_location=torch.device('cpu')), 32 | strict=False) 33 | print(f'loaded pretrained LPIPS loss from {ckpt}') 34 | 35 | @classmethod 36 | def from_pretrained(cls, name='vgg_lpips'): 37 | if name != 'vgg_lpips': 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict(torch.load(ckpt, 42 | map_location=torch.device('cpu')), 43 | strict=False) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), 48 | self.scaling_layer(target)) 49 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 50 | feats0, feats1, diffs = {}, {}, {} 51 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 52 | for kk in range(len(self.chns)): 53 | feats0[kk], feats1[kk] = normalize_tensor( 54 | outs0[kk]), normalize_tensor(outs1[kk]) 55 | diffs[kk] = (feats0[kk] - feats1[kk])**2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | 69 | def __init__(self): 70 | super().__init__() 71 | self.register_buffer( 72 | 'shift', 73 | torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) 74 | self.register_buffer( 75 | 'scale', 76 | torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) 77 | 78 | def forward(self, inp): 79 | return (inp - self.shift) / self.scale 80 | 81 | 82 | class NetLinLayer(nn.Module): 83 | """A single linear layer which does a 1x1 conv""" 84 | 85 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 86 | super().__init__() 87 | layers = ([ 88 | nn.Dropout(), 89 | ] if (use_dropout) else []) 
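# the 1x1 Conv2d appended below maps the chn_in channel-wise squared feature differences
# to chn_out (default 1) channels with no bias, i.e. a learned per-channel weighting of the LPIPS distance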
90 | layers += [ 91 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 92 | ] 93 | self.model = nn.Sequential(*layers) 94 | 95 | 96 | class vgg16(torch.nn.Module): 97 | 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super().__init__() 100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple( 133 | 'VggOutputs', 134 | ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 135 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, 136 | h_relu5_3) 137 | return out 138 | 139 | 140 | def normalize_tensor(x, eps=1e-10): 141 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 142 | return x / (norm_factor + eps) 143 | 144 | 145 | def spatial_average(x, keepdim=True): 146 | return x.mean([2, 3], keepdim=keepdim) 147 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
59 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | try: 12 | nn.init.normal_(m.weight.data, 0.0, 0.02) 13 | except: 14 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02) 15 | elif classname.find('BatchNorm') != -1: 16 | nn.init.normal_(m.weight.data, 1.0, 0.02) 17 | nn.init.constant_(m.bias.data, 0) 18 | 19 | 20 | class NLayerDiscriminator(nn.Module): 21 | """Defines a PatchGAN discriminator as in Pix2Pix 22 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 23 | """ 24 | 25 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 26 | """Construct a PatchGAN discriminator 27 | Parameters: 28 | input_nc (int) -- the number of channels in input images 29 | ndf (int) -- the number of filters in the last conv layer 30 | n_layers (int) -- the number of conv layers in the discriminator 31 | norm_layer -- normalization layer 32 | """ 33 | super().__init__() 34 | if not use_actnorm: 35 | norm_layer = nn.BatchNorm2d 36 | else: 37 | norm_layer = ActNorm 38 | if type( 39 | norm_layer 40 | ) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 41 | use_bias = norm_layer.func != nn.BatchNorm2d 42 | else: 43 | use_bias = norm_layer != nn.BatchNorm2d 44 | 45 | kw = 4 46 | padw = 1 47 | sequence = [ 48 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 49 | nn.LeakyReLU(0.2, True), 50 | ] 51 | nf_mult = 1 52 | nf_mult_prev = 1 53 | for n in range(1, 54 | n_layers): # gradually increase the number of filters 55 | nf_mult_prev = nf_mult 56 | nf_mult = min(2**n, 8) 57 | sequence += [ 58 | nn.Conv2d( 59 | ndf * nf_mult_prev, 60 | ndf * nf_mult, 61 | kernel_size=kw, 62 | stride=2, 63 | padding=padw, 64 | bias=use_bias, 65 | ), 66 | norm_layer(ndf * nf_mult), 67 | nn.LeakyReLU(0.2, True), 68 | ] 69 | 70 | nf_mult_prev = nf_mult 71 | nf_mult = min(2**n_layers, 8) 72 | sequence += [ 73 | nn.Conv2d( 74 | ndf * nf_mult_prev, 75 | ndf * nf_mult, 76 | kernel_size=kw, 77 | stride=1, 78 | padding=padw, 79 | bias=use_bias, 80 | ), 81 | norm_layer(ndf * nf_mult), 82 | nn.LeakyReLU(0.2, True), 83 | ] 84 | 85 | sequence += [ 86 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 87 | ] # output 1 channel prediction map 88 | self.main = nn.Sequential(*sequence) 89 | 90 | def forward(self, input): 91 | """Standard forward.""" 92 | return self.main(input) 93 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = { 10 | 'vgg_lpips': 11 | 
'https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1' 12 | } 13 | 14 | CKPT_MAP = {'vgg_lpips': 'vgg.pth'} 15 | 16 | MD5_MAP = {'vgg_lpips': 'd507d7349b931f0638a25a48a722f98a'} 17 | 18 | 19 | def download(url, local_path, chunk_size=1024): 20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 21 | with requests.get(url, stream=True) as r: 22 | total_size = int(r.headers.get('content-length', 0)) 23 | with tqdm(total=total_size, unit='B', unit_scale=True) as pbar: 24 | with open(local_path, 'wb') as f: 25 | for data in r.iter_content(chunk_size=chunk_size): 26 | if data: 27 | f.write(data) 28 | pbar.update(chunk_size) 29 | 30 | 31 | def md5_hash(path): 32 | with open(path, 'rb') as f: 33 | content = f.read() 34 | return hashlib.md5(content).hexdigest() 35 | 36 | 37 | def get_ckpt_path(name, root, check=False): 38 | assert name in URL_MAP 39 | path = os.path.join(root, CKPT_MAP[name]) 40 | if not os.path.exists(path) or (check 41 | and not md5_hash(path) == MD5_MAP[name]): 42 | print('Downloading {} model from {} to {}'.format( 43 | name, URL_MAP[name], path)) 44 | download(URL_MAP[name], path) 45 | md5 = md5_hash(path) 46 | assert md5 == MD5_MAP[name], md5 47 | return path 48 | 49 | 50 | class ActNorm(nn.Module): 51 | 52 | def __init__(self, 53 | num_features, 54 | logdet=False, 55 | affine=True, 56 | allow_reverse_init=False): 57 | assert affine 58 | super().__init__() 59 | self.logdet = logdet 60 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 61 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 62 | self.allow_reverse_init = allow_reverse_init 63 | 64 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 65 | 66 | def initialize(self, input): 67 | with torch.no_grad(): 68 | flatten = input.permute(1, 0, 2, 69 | 3).contiguous().view(input.shape[1], -1) 70 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze( 71 | 3).permute(1, 0, 2, 3) 72 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze( 73 | 3).permute(1, 0, 2, 3) 74 | 75 | self.loc.data.copy_(-mean) 76 | self.scale.data.copy_(1 / (std + 1e-6)) 77 | 78 | def forward(self, input, reverse=False): 79 | if reverse: 80 | return self.reverse(input) 81 | if len(input.shape) == 2: 82 | input = input[:, :, None, None] 83 | squeeze = True 84 | else: 85 | squeeze = False 86 | 87 | _, _, height, width = input.shape 88 | 89 | if self.training and self.initialized.item() == 0: 90 | self.initialize(input) 91 | self.initialized.fill_(1) 92 | 93 | h = self.scale * (input + self.loc) 94 | 95 | if squeeze: 96 | h = h.squeeze(-1).squeeze(-1) 97 | 98 | if self.logdet: 99 | log_abs = torch.log(torch.abs(self.scale)) 100 | logdet = height * width * torch.sum(log_abs) 101 | logdet = logdet * torch.ones(input.shape[0]).to(input) 102 | return h, logdet 103 | 104 | return h 105 | 106 | def reverse(self, output): 107 | if self.training and self.initialized.item() == 0: 108 | if not self.allow_reverse_init: 109 | raise RuntimeError( 110 | 'Initializing ActNorm in reverse direction is ' 111 | 'disabled by default. Use allow_reverse_init=True to enable.' 
112 | ) 113 | else: 114 | self.initialize(output) 115 | self.initialized.fill_(1) 116 | 117 | if len(output.shape) == 2: 118 | output = output[:, :, None, None] 119 | squeeze = True 120 | else: 121 | squeeze = False 122 | 123 | h = output / self.scale - self.loc 124 | 125 | if squeeze: 126 | h = h.squeeze(-1).squeeze(-1) 127 | return h 128 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * (torch.mean(torch.nn.functional.softplus(-logits_real)) + 14 | torch.mean(torch.nn.functional.softplus(logits_fake))) 15 | return d_loss 16 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | 15 | def __init__(self, sample: bool = True): 16 | super().__init__() 17 | self.sample = sample 18 | 19 | def get_trainable_parameters(self) -> Any: 20 | yield from () 21 | 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | log = dict() 24 | posterior = DiagonalGaussianDistribution(z) 25 | if self.sample: 26 | z = posterior.sample() 27 | else: 28 | z = posterior.mode() 29 | kl_loss = posterior.kl() 30 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 31 | log['kl_loss'] = kl_loss 32 | return z, log 33 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def get_trainable_parameters(self) -> Any: 19 | raise NotImplementedError() 20 | 21 | 22 | class IdentityRegularizer(AbstractRegularizer): 23 | 24 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 25 | return z, dict() 26 | 27 | def get_trainable_parameters(self) -> Any: 28 | yield from () 29 | 30 | 31 | def measure_perplexity( 32 | predicted_indices: torch.Tensor, 33 | num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 34 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 35 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 36 | encodings = F.one_hot(predicted_indices, 37 | num_centroids).float().reshape(-1, num_centroids) 38 | avg_probs = encodings.mean(0) 39 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 40 | cluster_use = torch.sum(avg_probs > 0) 41 | return perplexity, cluster_use 42 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 3 | Code adapted from Jax version in Appendix A.1 4 | """ 5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | from einops import pack, rearrange, unpack 11 | from torch import Tensor, int32 12 | from torch.cuda.amp import autocast 13 | from torch.nn import Module 14 | 15 | # helper functions 16 | 17 | 18 | def exists(v): 19 | return v is not None 20 | 21 | 22 | def default(*args): 23 | for arg in args: 24 | if exists(arg): 25 | return arg 26 | return None 27 | 28 | 29 | def pack_one(t, pattern): 30 | return pack([t], pattern) 31 | 32 | 33 | def unpack_one(t, ps, pattern): 34 | return unpack(t, ps, pattern)[0] 35 | 36 | 37 | # tensor helpers 38 | 39 | 40 | def round_ste(z: Tensor) -> Tensor: 41 | """Round with straight through gradients.""" 42 | zhat = z.round() 43 | return z + (zhat - z).detach() 44 | 45 | 46 | # main class 47 | 48 | 49 | class FSQ(Module): 50 | 51 | def __init__( 52 | self, 53 | levels: List[int], 54 | dim: Optional[int] = None, 55 | num_codebooks=1, 56 | keep_num_codebooks_dim: Optional[bool] = None, 57 | scale: Optional[float] = None, 58 | ): 59 | super().__init__() 60 | _levels = torch.tensor(levels, dtype=int32) 61 | self.register_buffer('_levels', _levels, persistent=False) 62 | 63 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), 64 | dim=0, 65 | dtype=int32) 66 | self.register_buffer('_basis', _basis, persistent=False) 67 | 68 | self.scale = scale 69 | 70 | codebook_dim = len(levels) 71 | self.codebook_dim = codebook_dim 72 | 73 | effective_codebook_dim = codebook_dim * num_codebooks 74 | self.num_codebooks = num_codebooks 75 | self.effective_codebook_dim = effective_codebook_dim 76 | 77 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, 78 | num_codebooks > 1) 79 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 80 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 81 | 82 | self.dim = default(dim, len(_levels) * num_codebooks) 83 | 84 | has_projections = self.dim != effective_codebook_dim 85 | self.project_in = nn.Linear( 86 | self.dim, 87 | effective_codebook_dim) if has_projections else nn.Identity() 88 | self.project_out = nn.Linear( 89 | effective_codebook_dim, 90 | self.dim) if has_projections else nn.Identity() 91 | self.has_projections = has_projections 92 | 93 | self.codebook_size = self._levels.prod().item() 94 | 95 | implicit_codebook = self.indices_to_codes(torch.arange( 96 | self.codebook_size), 97 | project_out=False) 98 | self.register_buffer('implicit_codebook', 99 | implicit_codebook, 100 | persistent=False) 101 | 102 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: 103 | """Bound `z`, an array of shape (..., d).""" 104 | half_l = (self._levels - 1) * (1 + eps) / 2 105 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) 106 | shift = (offset / half_l).atanh() 
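# the return below squashes each dimension of z into an interval of width ~(l - 1) for its level count l;
# the 0.5 offset for even l shifts that interval by half a step so that round_ste() in quantize() lands on
# exactly l integers, and the atanh shift keeps z = 0 mapped to 0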
107 | return (z + shift).tanh() * half_l - offset 108 | 109 | def quantize(self, z: Tensor) -> Tensor: 110 | """Quantizes z, returns quantized zhat, same shape as z.""" 111 | quantized = round_ste(self.bound(z)) 112 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 113 | return quantized / half_width 114 | 115 | def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor: 116 | half_width = self._levels // 2 117 | return (zhat_normalized * half_width) + half_width 118 | 119 | def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor: 120 | half_width = self._levels // 2 121 | return (zhat - half_width) / half_width 122 | 123 | def codes_to_indices(self, zhat: Tensor) -> Tensor: 124 | """Converts a `code` to an index in the codebook.""" 125 | assert zhat.shape[-1] == self.codebook_dim 126 | zhat = self._scale_and_shift(zhat) 127 | return (zhat * self._basis).sum(dim=-1).to(int32) 128 | 129 | def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor: 130 | """Inverse of `codes_to_indices`.""" 131 | 132 | is_img_or_video = indices.ndim >= (3 + 133 | int(self.keep_num_codebooks_dim)) 134 | 135 | indices = rearrange(indices, '... -> ... 1') 136 | codes_non_centered = (indices // self._basis) % self._levels 137 | codes = self._scale_and_shift_inverse(codes_non_centered) 138 | 139 | if self.keep_num_codebooks_dim: 140 | codes = rearrange(codes, '... c d -> ... (c d)') 141 | 142 | if project_out: 143 | codes = self.project_out(codes) 144 | 145 | if is_img_or_video: 146 | codes = rearrange(codes, 'b ... d -> b d ...') 147 | 148 | return codes 149 | 150 | @autocast(enabled=False) 151 | def forward(self, z: Tensor) -> Tensor: 152 | """ 153 | einstein notation 154 | b - batch 155 | n - sequence (or flattened spatial dimensions) 156 | d - feature dimension 157 | c - number of codebook dim 158 | """ 159 | 160 | is_img_or_video = z.ndim >= 4 161 | 162 | # standardize image or video into (batch, seq, dimension) 163 | 164 | if is_img_or_video: 165 | z = rearrange(z, 'b d ... -> b ... d') 166 | z, ps = pack_one(z, 'b * d') 167 | 168 | assert z.shape[ 169 | -1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' 170 | 171 | z = self.project_in(z) 172 | 173 | z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks) 174 | 175 | codes = self.quantize(z) 176 | indices = self.codes_to_indices(codes) 177 | 178 | codes = rearrange(codes, 'b n c d -> b n (c d)') 179 | 180 | out = self.project_out(codes) 181 | 182 | # reconstitute image or video dimensions 183 | 184 | if is_img_or_video: 185 | out = unpack_one(out, ps, 'b * d') 186 | out = rearrange(out, 'b ... d -> b d ...') 187 | 188 | indices = unpack_one(indices, ps, 'b * c') 189 | 190 | if not self.keep_num_codebooks_dim: 191 | indices = rearrange(indices, '... 1 -> ...') 192 | 193 | return out, indices 194 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lookup Free Quantization 3 | Proposed in https://arxiv.org/abs/2310.05737 4 | 5 | In the simplest setup, each dimension is quantized into {-1, 1}. 6 | An entropy penalty is used to encourage utilization. 
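In this implementation the quantizer is simply the sign of each (projected) latent dimension, scaled by
codebook_scale, and the code index is the resulting binary string; the auxiliary entropy loss lowers
per-sample entropy (confident assignments) while raising batch-level codebook entropy (uniform code usage),
weighted by diversity_gamma.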
7 | """ 8 | 9 | from collections import namedtuple 10 | from math import ceil, log2 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from einops import pack, rearrange, reduce, unpack 15 | from torch import einsum, nn 16 | from torch.cuda.amp import autocast 17 | from torch.nn import Module 18 | 19 | # constants 20 | 21 | Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss']) 22 | 23 | LossBreakdown = namedtuple( 24 | 'LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment']) 25 | 26 | # helper functions 27 | 28 | 29 | def exists(v): 30 | return v is not None 31 | 32 | 33 | def default(*args): 34 | for arg in args: 35 | if exists(arg): 36 | return arg() if callable(arg) else arg 37 | return None 38 | 39 | 40 | def pack_one(t, pattern): 41 | return pack([t], pattern) 42 | 43 | 44 | def unpack_one(t, ps, pattern): 45 | return unpack(t, ps, pattern)[0] 46 | 47 | 48 | # entropy 49 | 50 | 51 | def log(t, eps=1e-5): 52 | return t.clamp(min=eps).log() 53 | 54 | 55 | def entropy(prob): 56 | return (-prob * log(prob)).sum(dim=-1) 57 | 58 | 59 | # class 60 | 61 | 62 | class LFQ(Module): 63 | 64 | def __init__( 65 | self, 66 | *, 67 | dim=None, 68 | codebook_size=None, 69 | entropy_loss_weight=0.1, 70 | commitment_loss_weight=0.25, 71 | diversity_gamma=1.0, 72 | straight_through_activation=nn.Identity(), 73 | num_codebooks=1, 74 | keep_num_codebooks_dim=None, 75 | codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer 76 | frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy 77 | ): 78 | super().__init__() 79 | 80 | # some assert validations 81 | 82 | assert exists(dim) or exists( 83 | codebook_size 84 | ), 'either dim or codebook_size must be specified for LFQ' 85 | assert ( 86 | not exists(codebook_size) or log2(codebook_size).is_integer() 87 | ), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})' 88 | 89 | codebook_size = default(codebook_size, lambda: 2**dim) 90 | codebook_dim = int(log2(codebook_size)) 91 | 92 | codebook_dims = codebook_dim * num_codebooks 93 | dim = default(dim, codebook_dims) 94 | 95 | has_projections = dim != codebook_dims 96 | self.project_in = nn.Linear( 97 | dim, codebook_dims) if has_projections else nn.Identity() 98 | self.project_out = nn.Linear( 99 | codebook_dims, dim) if has_projections else nn.Identity() 100 | self.has_projections = has_projections 101 | 102 | self.dim = dim 103 | self.codebook_dim = codebook_dim 104 | self.num_codebooks = num_codebooks 105 | 106 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, 107 | num_codebooks > 1) 108 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 109 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 110 | 111 | # straight through activation 112 | 113 | self.activation = straight_through_activation 114 | 115 | # entropy aux loss related weights 116 | 117 | assert 0 < frac_per_sample_entropy <= 1.0 118 | self.frac_per_sample_entropy = frac_per_sample_entropy 119 | 120 | self.diversity_gamma = diversity_gamma 121 | self.entropy_loss_weight = entropy_loss_weight 122 | 123 | # codebook scale 124 | 125 | self.codebook_scale = codebook_scale 126 | 127 | # commitment loss 128 | 129 | self.commitment_loss_weight = commitment_loss_weight 130 | 131 | # for no auxiliary loss, during inference 132 | 133 | self.register_buffer('mask', 2**torch.arange(codebook_dim - 1, -1, -1)) 134 | 
self.register_buffer('zero', torch.tensor(0.0), persistent=False) 135 | 136 | # codes 137 | 138 | all_codes = torch.arange(codebook_size) 139 | bits = ((all_codes[..., None].int() & self.mask) != 0).float() 140 | codebook = self.bits_to_codes(bits) 141 | 142 | self.register_buffer('codebook', codebook, persistent=False) 143 | 144 | def bits_to_codes(self, bits): 145 | return bits * self.codebook_scale * 2 - self.codebook_scale 146 | 147 | @property 148 | def dtype(self): 149 | return self.codebook.dtype 150 | 151 | def indices_to_codes(self, indices, project_out=True): 152 | is_img_or_video = indices.ndim >= (3 + 153 | int(self.keep_num_codebooks_dim)) 154 | 155 | if not self.keep_num_codebooks_dim: 156 | indices = rearrange(indices, '... -> ... 1') 157 | 158 | # indices to codes, which are bits of either -1 or 1 159 | 160 | bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) 161 | 162 | codes = self.bits_to_codes(bits) 163 | 164 | codes = rearrange(codes, '... c d -> ... (c d)') 165 | 166 | # whether to project codes out to original dimensions 167 | # if the input feature dimensions were not log2(codebook size) 168 | 169 | if project_out: 170 | codes = self.project_out(codes) 171 | 172 | # rearrange codes back to original shape 173 | 174 | if is_img_or_video: 175 | codes = rearrange(codes, 'b ... d -> b d ...') 176 | 177 | return codes 178 | 179 | @autocast(enabled=False) 180 | def forward( 181 | self, 182 | x, 183 | inv_temperature=100.0, 184 | return_loss_breakdown=False, 185 | mask=None, 186 | ): 187 | """ 188 | einstein notation 189 | b - batch 190 | n - sequence (or flattened spatial dimensions) 191 | d - feature dimension, which is also log2(codebook size) 192 | c - number of codebook dim 193 | """ 194 | 195 | x = x.float() 196 | 197 | is_img_or_video = x.ndim >= 4 198 | 199 | # standardize image or video into (batch, seq, dimension) 200 | 201 | if is_img_or_video: 202 | x = rearrange(x, 'b d ... -> b ... d') 203 | x, ps = pack_one(x, 'b * d') 204 | 205 | assert x.shape[ 206 | -1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}' 207 | 208 | x = self.project_in(x) 209 | 210 | # split out number of codebooks 211 | 212 | x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks) 213 | 214 | # quantize by eq 3. 215 | 216 | original_input = x 217 | 218 | codebook_value = torch.ones_like(x) * self.codebook_scale 219 | quantized = torch.where(x > 0, codebook_value, -codebook_value) 220 | 221 | # use straight-through gradients (optionally with custom activation fn) if training 222 | 223 | if self.training: 224 | x = self.activation(x) 225 | x = x + (quantized - x).detach() 226 | else: 227 | x = quantized 228 | 229 | # calculate indices 230 | 231 | indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 232 | 'sum') 233 | 234 | # entropy aux loss 235 | 236 | if self.training: 237 | # the same as euclidean distance up to a constant 238 | distance = -2 * einsum('... i d, j d -> ... i j', original_input, 239 | self.codebook) 240 | 241 | prob = (-distance * inv_temperature).softmax(dim=-1) 242 | 243 | # account for mask 244 | 245 | if exists(mask): 246 | prob = prob[mask] 247 | else: 248 | prob = rearrange(prob, 'b n ... 
-> (b n) ...') 249 | 250 | # whether to only use a fraction of probs, for reducing memory 251 | 252 | if self.frac_per_sample_entropy < 1.0: 253 | num_tokens = prob.shape[0] 254 | num_sampled_tokens = int(num_tokens * 255 | self.frac_per_sample_entropy) 256 | rand_mask = torch.randn(num_tokens).argsort( 257 | dim=-1) < num_sampled_tokens 258 | per_sample_probs = prob[rand_mask] 259 | else: 260 | per_sample_probs = prob 261 | 262 | # calculate per sample entropy 263 | 264 | per_sample_entropy = entropy(per_sample_probs).mean() 265 | 266 | # distribution over all available tokens in the batch 267 | 268 | avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean') 269 | codebook_entropy = entropy(avg_prob).mean() 270 | 271 | # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions 272 | # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch 273 | 274 | entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy 275 | else: 276 | # if not training, just return dummy 0 277 | entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero 278 | 279 | # commit loss 280 | 281 | if self.training: 282 | commit_loss = F.mse_loss(original_input, 283 | quantized.detach(), 284 | reduction='none') 285 | 286 | if exists(mask): 287 | commit_loss = commit_loss[mask] 288 | 289 | commit_loss = commit_loss.mean() 290 | else: 291 | commit_loss = self.zero 292 | 293 | # merge back codebook dim 294 | 295 | x = rearrange(x, 'b n c d -> b n (c d)') 296 | 297 | # project out to feature dimension if needed 298 | 299 | x = self.project_out(x) 300 | 301 | # reconstitute image or video dimensions 302 | 303 | if is_img_or_video: 304 | x = unpack_one(x, ps, 'b * d') 305 | x = rearrange(x, 'b ... d -> b d ...') 306 | 307 | indices = unpack_one(indices, ps, 'b * c') 308 | 309 | # whether to remove single codebook dim 310 | 311 | if not self.keep_num_codebooks_dim: 312 | indices = rearrange(indices, '... 
1 -> ...') 313 | 314 | # complete aux loss 315 | 316 | aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight 317 | 318 | ret = Return(x, indices, aux_loss) 319 | 320 | if not return_loss_breakdown: 321 | return ret 322 | 323 | return ret, LossBreakdown(per_sample_entropy, codebook_entropy, 324 | commit_loss) 325 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/temporal_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Union 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE, 6 | AttnBlock, Decoder, 7 | MemoryEfficientAttnBlock, 8 | ResnetBlock) 9 | from sgm.modules.diffusionmodules.openaimodel import (ResBlock, 10 | timestep_embedding) 11 | from sgm.modules.video_attention import VideoTransformerBlock 12 | from sgm.util import partialclass 13 | 14 | 15 | class VideoResBlock(ResnetBlock): 16 | 17 | def __init__( 18 | self, 19 | out_channels, 20 | *args, 21 | dropout=0.0, 22 | video_kernel_size=3, 23 | alpha=0.0, 24 | merge_strategy='learned', 25 | **kwargs, 26 | ): 27 | super().__init__(out_channels=out_channels, 28 | dropout=dropout, 29 | *args, 30 | **kwargs) 31 | if video_kernel_size is None: 32 | video_kernel_size = [3, 1, 1] 33 | self.time_stack = ResBlock( 34 | channels=out_channels, 35 | emb_channels=0, 36 | dropout=dropout, 37 | dims=3, 38 | use_scale_shift_norm=False, 39 | use_conv=False, 40 | up=False, 41 | down=False, 42 | kernel_size=video_kernel_size, 43 | use_checkpoint=False, 44 | skip_t_emb=True, 45 | ) 46 | 47 | self.merge_strategy = merge_strategy 48 | if self.merge_strategy == 'fixed': 49 | self.register_buffer('mix_factor', torch.Tensor([alpha])) 50 | elif self.merge_strategy == 'learned': 51 | self.register_parameter('mix_factor', 52 | torch.nn.Parameter(torch.Tensor([alpha]))) 53 | else: 54 | raise ValueError(f'unknown merge strategy {self.merge_strategy}') 55 | 56 | def get_alpha(self, bs): 57 | if self.merge_strategy == 'fixed': 58 | return self.mix_factor 59 | elif self.merge_strategy == 'learned': 60 | return torch.sigmoid(self.mix_factor) 61 | else: 62 | raise NotImplementedError() 63 | 64 | def forward(self, x, temb, skip_video=False, timesteps=None): 65 | if timesteps is None: 66 | timesteps = self.timesteps 67 | 68 | b, c, h, w = x.shape 69 | 70 | x = super().forward(x, temb) 71 | 72 | if not skip_video: 73 | x_mix = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps) 74 | 75 | x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps) 76 | 77 | x = self.time_stack(x, temb) 78 | 79 | alpha = self.get_alpha(bs=b // timesteps) 80 | x = alpha * x + (1.0 - alpha) * x_mix 81 | 82 | x = rearrange(x, 'b c t h w -> (b t) c h w') 83 | return x 84 | 85 | 86 | class AE3DConv(torch.nn.Conv2d): 87 | 88 | def __init__(self, 89 | in_channels, 90 | out_channels, 91 | video_kernel_size=3, 92 | *args, 93 | **kwargs): 94 | super().__init__(in_channels, out_channels, *args, **kwargs) 95 | if isinstance(video_kernel_size, Iterable): 96 | padding = [int(k // 2) for k in video_kernel_size] 97 | else: 98 | padding = int(video_kernel_size // 2) 99 | 100 | self.time_mix_conv = torch.nn.Conv3d( 101 | in_channels=out_channels, 102 | out_channels=out_channels, 103 | kernel_size=video_kernel_size, 104 | padding=padding, 105 | ) 106 | 107 | def forward(self, input, timesteps, skip_video=False): 108 | x = 
super().forward(input) 109 | if skip_video: 110 | return x 111 | x = rearrange(x, '(b t) c h w -> b c t h w', t=timesteps) 112 | x = self.time_mix_conv(x) 113 | return rearrange(x, 'b c t h w -> (b t) c h w') 114 | 115 | 116 | class VideoBlock(AttnBlock): 117 | 118 | def __init__(self, 119 | in_channels: int, 120 | alpha: float = 0, 121 | merge_strategy: str = 'learned'): 122 | super().__init__(in_channels) 123 | # no context, single headed, as in base class 124 | self.time_mix_block = VideoTransformerBlock( 125 | dim=in_channels, 126 | n_heads=1, 127 | d_head=in_channels, 128 | checkpoint=False, 129 | ff_in=True, 130 | attn_mode='softmax', 131 | ) 132 | 133 | time_embed_dim = self.in_channels * 4 134 | self.video_time_embed = torch.nn.Sequential( 135 | torch.nn.Linear(self.in_channels, time_embed_dim), 136 | torch.nn.SiLU(), 137 | torch.nn.Linear(time_embed_dim, self.in_channels), 138 | ) 139 | 140 | self.merge_strategy = merge_strategy 141 | if self.merge_strategy == 'fixed': 142 | self.register_buffer('mix_factor', torch.Tensor([alpha])) 143 | elif self.merge_strategy == 'learned': 144 | self.register_parameter('mix_factor', 145 | torch.nn.Parameter(torch.Tensor([alpha]))) 146 | else: 147 | raise ValueError(f'unknown merge strategy {self.merge_strategy}') 148 | 149 | def forward(self, x, timesteps, skip_video=False): 150 | if skip_video: 151 | return super().forward(x) 152 | 153 | x_in = x 154 | x = self.attention(x) 155 | h, w = x.shape[2:] 156 | x = rearrange(x, 'b c h w -> b (h w) c') 157 | 158 | x_mix = x 159 | num_frames = torch.arange(timesteps, device=x.device) 160 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps) 161 | num_frames = rearrange(num_frames, 'b t -> (b t)') 162 | t_emb = timestep_embedding(num_frames, 163 | self.in_channels, 164 | repeat_only=False) 165 | emb = self.video_time_embed(t_emb) # b, n_channels 166 | emb = emb[:, None, :] 167 | x_mix = x_mix + emb 168 | 169 | alpha = self.get_alpha() 170 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 171 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 172 | 173 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 174 | x = self.proj_out(x) 175 | 176 | return x_in + x 177 | 178 | def get_alpha(self, ): 179 | if self.merge_strategy == 'fixed': 180 | return self.mix_factor 181 | elif self.merge_strategy == 'learned': 182 | return torch.sigmoid(self.mix_factor) 183 | else: 184 | raise NotImplementedError( 185 | f'unknown merge strategy {self.merge_strategy}') 186 | 187 | 188 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): 189 | 190 | def __init__(self, 191 | in_channels: int, 192 | alpha: float = 0, 193 | merge_strategy: str = 'learned'): 194 | super().__init__(in_channels) 195 | # no context, single headed, as in base class 196 | self.time_mix_block = VideoTransformerBlock( 197 | dim=in_channels, 198 | n_heads=1, 199 | d_head=in_channels, 200 | checkpoint=False, 201 | ff_in=True, 202 | attn_mode='softmax-xformers', 203 | ) 204 | 205 | time_embed_dim = self.in_channels * 4 206 | self.video_time_embed = torch.nn.Sequential( 207 | torch.nn.Linear(self.in_channels, time_embed_dim), 208 | torch.nn.SiLU(), 209 | torch.nn.Linear(time_embed_dim, self.in_channels), 210 | ) 211 | 212 | self.merge_strategy = merge_strategy 213 | if self.merge_strategy == 'fixed': 214 | self.register_buffer('mix_factor', torch.Tensor([alpha])) 215 | elif self.merge_strategy == 'learned': 216 | self.register_parameter('mix_factor', 217 | torch.nn.Parameter(torch.Tensor([alpha]))) 218 | else: 219 | raise 
ValueError(f'unknown merge strategy {self.merge_strategy}') 220 | 221 | def forward(self, x, timesteps, skip_time_block=False): 222 | if skip_time_block: 223 | return super().forward(x) 224 | 225 | x_in = x 226 | x = self.attention(x) 227 | h, w = x.shape[2:] 228 | x = rearrange(x, 'b c h w -> b (h w) c') 229 | 230 | x_mix = x 231 | num_frames = torch.arange(timesteps, device=x.device) 232 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps) 233 | num_frames = rearrange(num_frames, 'b t -> (b t)') 234 | t_emb = timestep_embedding(num_frames, 235 | self.in_channels, 236 | repeat_only=False) 237 | emb = self.video_time_embed(t_emb) # b, n_channels 238 | emb = emb[:, None, :] 239 | x_mix = x_mix + emb 240 | 241 | alpha = self.get_alpha() 242 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 243 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 244 | 245 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 246 | x = self.proj_out(x) 247 | 248 | return x_in + x 249 | 250 | def get_alpha(self, ): 251 | if self.merge_strategy == 'fixed': 252 | return self.mix_factor 253 | elif self.merge_strategy == 'learned': 254 | return torch.sigmoid(self.mix_factor) 255 | else: 256 | raise NotImplementedError( 257 | f'unknown merge strategy {self.merge_strategy}') 258 | 259 | 260 | def make_time_attn( 261 | in_channels, 262 | attn_type='vanilla', 263 | attn_kwargs=None, 264 | alpha: float = 0, 265 | merge_strategy: str = 'learned', 266 | ): 267 | assert attn_type in [ 268 | 'vanilla', 269 | 'vanilla-xformers', 270 | ], f'attn_type {attn_type} not supported for spatio-temporal attention' 271 | print( 272 | f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" 273 | ) 274 | if not XFORMERS_IS_AVAILABLE and attn_type == 'vanilla-xformers': 275 | print( 276 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " 277 | f'This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}' 278 | ) 279 | attn_type = 'vanilla' 280 | 281 | if attn_type == 'vanilla': 282 | assert attn_kwargs is None 283 | return partialclass(VideoBlock, 284 | in_channels, 285 | alpha=alpha, 286 | merge_strategy=merge_strategy) 287 | elif attn_type == 'vanilla-xformers': 288 | print( 289 | f'building MemoryEfficientAttnBlock with {in_channels} in_channels...' 
290 | ) 291 | return partialclass( 292 | MemoryEfficientVideoBlock, 293 | in_channels, 294 | alpha=alpha, 295 | merge_strategy=merge_strategy, 296 | ) 297 | else: 298 | return NotImplementedError() 299 | 300 | 301 | class Conv2DWrapper(torch.nn.Conv2d): 302 | 303 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 304 | return super().forward(input) 305 | 306 | 307 | class VideoDecoder(Decoder): 308 | available_time_modes = ['all', 'conv-only', 'attn-only'] 309 | 310 | def __init__( 311 | self, 312 | *args, 313 | video_kernel_size: Union[int, list] = 3, 314 | alpha: float = 0.0, 315 | merge_strategy: str = 'learned', 316 | time_mode: str = 'conv-only', 317 | **kwargs, 318 | ): 319 | self.video_kernel_size = video_kernel_size 320 | self.alpha = alpha 321 | self.merge_strategy = merge_strategy 322 | self.time_mode = time_mode 323 | assert ( 324 | self.time_mode in self.available_time_modes 325 | ), f'time_mode parameter has to be in {self.available_time_modes}' 326 | super().__init__(*args, **kwargs) 327 | 328 | def get_last_layer(self, skip_time_mix=False, **kwargs): 329 | if self.time_mode == 'attn-only': 330 | raise NotImplementedError('TODO') 331 | else: 332 | return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight 333 | 334 | def _make_attn(self) -> Callable: 335 | if self.time_mode not in ['conv-only', 'only-last-conv']: 336 | return partialclass( 337 | make_time_attn, 338 | alpha=self.alpha, 339 | merge_strategy=self.merge_strategy, 340 | ) 341 | else: 342 | return super()._make_attn() 343 | 344 | def _make_conv(self) -> Callable: 345 | if self.time_mode != 'attn-only': 346 | return partialclass(AE3DConv, 347 | video_kernel_size=self.video_kernel_size) 348 | else: 349 | return Conv2DWrapper 350 | 351 | def _make_resblock(self) -> Callable: 352 | if self.time_mode not in ['attn-only', 'only-last-conv']: 353 | return partialclass( 354 | VideoResBlock, 355 | video_kernel_size=self.video_kernel_size, 356 | alpha=self.alpha, 357 | merge_strategy=self.merge_strategy, 358 | ) 359 | else: 360 | return super()._make_resblock() 361 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/autoencoding/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import einsum 7 | 8 | 9 | class VectorQuantizer2(nn.Module): 10 | """ 11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 13 | """ 14 | 15 | # NOTE: due to a bug the beta term was applied to the wrong term. for 16 | # backwards compatibility we use the buggy version by default, but you can 17 | # specify legacy=False to fix it. 
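# concretely (see the loss computation in forward() below): with legacy=True, beta scales
# mean((z_q - z.detach()) ** 2), whose gradient updates the codebook embeddings, while legacy=False
# applies beta to the commitment term mean((z_q.detach() - z) ** 2) that pulls the encoder output
# toward its nearest code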
18 | def __init__(self, 19 | n_e, 20 | e_dim, 21 | beta, 22 | remap=None, 23 | unknown_index='random', 24 | sane_index_shape=False, 25 | legacy=True): 26 | super().__init__() 27 | self.n_e = n_e 28 | self.e_dim = e_dim 29 | self.beta = beta 30 | self.legacy = legacy 31 | 32 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 33 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 34 | 35 | self.remap = remap 36 | if self.remap is not None: 37 | self.register_buffer('used', torch.tensor(np.load(self.remap))) 38 | self.re_embed = self.used.shape[0] 39 | self.unknown_index = unknown_index # "random" or "extra" or integer 40 | if self.unknown_index == 'extra': 41 | self.unknown_index = self.re_embed 42 | self.re_embed = self.re_embed + 1 43 | print(f'Remapping {self.n_e} indices to {self.re_embed} indices. ' 44 | f'Using {self.unknown_index} for unknown indices.') 45 | else: 46 | self.re_embed = n_e 47 | 48 | self.sane_index_shape = sane_index_shape 49 | 50 | def remap_to_used(self, inds): 51 | ishape = inds.shape 52 | assert len(ishape) > 1 53 | inds = inds.reshape(ishape[0], -1) 54 | used = self.used.to(inds) 55 | match = (inds[:, :, None] == used[None, None, ...]).long() 56 | new = match.argmax(-1) 57 | unknown = match.sum(2) < 1 58 | if self.unknown_index == 'random': 59 | new[unknown] = torch.randint( 60 | 0, self.re_embed, 61 | size=new[unknown].shape).to(device=new.device) 62 | else: 63 | new[unknown] = self.unknown_index 64 | return new.reshape(ishape) 65 | 66 | def unmap_to_all(self, inds): 67 | ishape = inds.shape 68 | assert len(ishape) > 1 69 | inds = inds.reshape(ishape[0], -1) 70 | used = self.used.to(inds) 71 | if self.re_embed > self.used.shape[0]: # extra token 72 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 73 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 74 | return back.reshape(ishape) 75 | 76 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 77 | assert temp is None or temp == 1.0, 'Only for interface compatible with Gumbel' 78 | assert rescale_logits == False, 'Only for interface compatible with Gumbel' 79 | assert return_logits == False, 'Only for interface compatible with Gumbel' 80 | # reshape z -> (batch, height, width, channel) and flatten 81 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 82 | z_flattened = z.view(-1, self.e_dim) 83 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 84 | 85 | d = (torch.sum(z_flattened**2, dim=1, keepdim=True) + 86 | torch.sum(self.embedding.weight**2, dim=1) - 87 | 2 * torch.einsum('bd,dn->bn', z_flattened, 88 | rearrange(self.embedding.weight, 'n d -> d n'))) 89 | 90 | min_encoding_indices = torch.argmin(d, dim=1) 91 | z_q = self.embedding(min_encoding_indices).view(z.shape) 92 | perplexity = None 93 | min_encodings = None 94 | 95 | # compute loss for embedding 96 | if not self.legacy: 97 | loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean( 98 | (z_q - z.detach())**2) 99 | else: 100 | loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean( 101 | (z_q - z.detach())**2) 102 | 103 | # preserve gradients 104 | z_q = z + (z_q - z).detach() 105 | 106 | # reshape back to match original input shape 107 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 108 | 109 | if self.remap is not None: 110 | min_encoding_indices = min_encoding_indices.reshape( 111 | z.shape[0], -1) # add batch axis 112 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 113 | min_encoding_indices = 
min_encoding_indices.reshape(-1, 114 | 1) # flatten 115 | 116 | if self.sane_index_shape: 117 | min_encoding_indices = min_encoding_indices.reshape( 118 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 119 | 120 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 121 | 122 | def get_codebook_entry(self, indices, shape): 123 | # shape specifying (batch, height, width, channel) 124 | if self.remap is not None: 125 | indices = indices.reshape(shape[0], -1) # add batch axis 126 | indices = self.unmap_to_all(indices) 127 | indices = indices.reshape(-1) # flatten again 128 | 129 | # get quantized latent vectors 130 | z_q = self.embedding(indices) 131 | 132 | if shape is not None: 133 | z_q = z_q.view(shape) 134 | # reshape back to match original input shape 135 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 136 | 137 | return z_q 138 | 139 | 140 | class GumbelQuantize(nn.Module): 141 | """ 142 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 143 | Gumbel Softmax trick quantizer 144 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 145 | https://arxiv.org/abs/1611.01144 146 | """ 147 | 148 | def __init__( 149 | self, 150 | num_hiddens, 151 | embedding_dim, 152 | n_embed, 153 | straight_through=True, 154 | kl_weight=5e-4, 155 | temp_init=1.0, 156 | use_vqinterface=True, 157 | remap=None, 158 | unknown_index='random', 159 | ): 160 | super().__init__() 161 | 162 | self.embedding_dim = embedding_dim 163 | self.n_embed = n_embed 164 | 165 | self.straight_through = straight_through 166 | self.temperature = temp_init 167 | self.kl_weight = kl_weight 168 | 169 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 170 | self.embed = nn.Embedding(n_embed, embedding_dim) 171 | 172 | self.use_vqinterface = use_vqinterface 173 | 174 | self.remap = remap 175 | if self.remap is not None: 176 | self.register_buffer('used', torch.tensor(np.load(self.remap))) 177 | self.re_embed = self.used.shape[0] 178 | self.unknown_index = unknown_index # "random" or "extra" or integer 179 | if self.unknown_index == 'extra': 180 | self.unknown_index = self.re_embed 181 | self.re_embed = self.re_embed + 1 182 | print( 183 | f'Remapping {self.n_embed} indices to {self.re_embed} indices. ' 184 | f'Using {self.unknown_index} for unknown indices.') 185 | else: 186 | self.re_embed = n_embed 187 | 188 | def remap_to_used(self, inds): 189 | ishape = inds.shape 190 | assert len(ishape) > 1 191 | inds = inds.reshape(ishape[0], -1) 192 | used = self.used.to(inds) 193 | match = (inds[:, :, None] == used[None, None, ...]).long() 194 | new = match.argmax(-1) 195 | unknown = match.sum(2) < 1 196 | if self.unknown_index == 'random': 197 | new[unknown] = torch.randint( 198 | 0, self.re_embed, 199 | size=new[unknown].shape).to(device=new.device) 200 | else: 201 | new[unknown] = self.unknown_index 202 | return new.reshape(ishape) 203 | 204 | def unmap_to_all(self, inds): 205 | ishape = inds.shape 206 | assert len(ishape) > 1 207 | inds = inds.reshape(ishape[0], -1) 208 | used = self.used.to(inds) 209 | if self.re_embed > self.used.shape[0]: # extra token 210 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 211 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 212 | return back.reshape(ishape) 213 | 214 | def forward(self, z, temp=None, return_logits=False): 215 | # force hard = True when we are in eval mode, as we must quantize. 
actually, always true seems to work 216 | hard = self.straight_through if self.training else True 217 | temp = self.temperature if temp is None else temp 218 | 219 | logits = self.proj(z) 220 | if self.remap is not None: 221 | # continue only with used logits 222 | full_zeros = torch.zeros_like(logits) 223 | logits = logits[:, self.used, ...] 224 | 225 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 226 | if self.remap is not None: 227 | # go back to all entries but unused set to zero 228 | full_zeros[:, self.used, ...] = soft_one_hot 229 | soft_one_hot = full_zeros 230 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, 231 | self.embed.weight) 232 | 233 | # + kl divergence to the prior loss 234 | qy = F.softmax(logits, dim=1) 235 | diff = self.kl_weight * torch.sum( 236 | qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 237 | 238 | ind = soft_one_hot.argmax(dim=1) 239 | if self.remap is not None: 240 | ind = self.remap_to_used(ind) 241 | if self.use_vqinterface: 242 | if return_logits: 243 | return z_q, diff, (None, None, ind), logits 244 | return z_q, diff, (None, None, ind) 245 | return z_q, diff, ind 246 | 247 | def get_codebook_entry(self, indices, shape): 248 | b, h, w, c = shape 249 | assert b * h * w == indices.shape[0] 250 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) 251 | if self.remap is not None: 252 | indices = self.unmap_to_all(indices) 253 | one_hot = F.one_hot(indices, 254 | num_classes=self.n_embed).permute(0, 3, 1, 255 | 2).float() 256 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) 257 | return z_q 258 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/cp_enc_dec.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributed 5 | import torch.nn as nn 6 | 7 | from ..util import (get_context_parallel_group, get_context_parallel_rank, 8 | get_context_parallel_world_size) 9 | 10 | _USE_CP = True 11 | 12 | 13 | def cast_tuple(t, length=1): 14 | return t if isinstance(t, tuple) else ((t, ) * length) 15 | 16 | 17 | def divisible_by(num, den): 18 | return (num % den) == 0 19 | 20 | 21 | def is_odd(n): 22 | return not divisible_by(n, 2) 23 | 24 | 25 | def exists(v): 26 | return v is not None 27 | 28 | 29 | def pair(t): 30 | return t if isinstance(t, tuple) else (t, t) 31 | 32 | 33 | def get_timestep_embedding(timesteps, embedding_dim): 34 | """ 35 | This matches the implementation in Denoising Diffusion Probabilistic Models: 36 | From Fairseq. 37 | Build sinusoidal embeddings. 38 | This matches the implementation in tensor2tensor, but differs slightly 39 | from the description in Section 3.5 of "Attention Is All You Need". 
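Concretely, with half_dim = embedding_dim // 2 and w_k = 10000 ** (-k / (half_dim - 1)), each timestep t
is mapped to [sin(t * w_0), ..., sin(t * w_{half_dim - 1}), cos(t * w_0), ..., cos(t * w_{half_dim - 1})];
an odd embedding_dim is zero-padded by one column.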
40 | """ 41 | assert len(timesteps.shape) == 1 42 | 43 | half_dim = embedding_dim // 2 44 | emb = math.log(10000) / (half_dim - 1) 45 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 46 | emb = emb.to(device=timesteps.device) 47 | emb = timesteps.float()[:, None] * emb[None, :] 48 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 49 | if embedding_dim % 2 == 1: # zero pad 50 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 51 | return emb 52 | 53 | 54 | def nonlinearity(x): 55 | # swish 56 | return x * torch.sigmoid(x) 57 | 58 | 59 | def leaky_relu(p=0.1): 60 | return nn.LeakyReLU(p) 61 | 62 | 63 | def _split(input_, dim): 64 | cp_world_size = get_context_parallel_world_size() 65 | 66 | if cp_world_size == 1: 67 | return input_ 68 | 69 | cp_rank = get_context_parallel_rank() 70 | 71 | # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) 72 | 73 | inpu_first_frame_ = input_.transpose(0, 74 | dim)[:1].transpose(0, 75 | dim).contiguous() 76 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 77 | dim_size = input_.size()[dim] // cp_world_size 78 | 79 | input_list = torch.split(input_, dim_size, dim=dim) 80 | output = input_list[cp_rank] 81 | 82 | if cp_rank == 0: 83 | output = torch.cat([inpu_first_frame_, output], dim=dim) 84 | output = output.contiguous() 85 | 86 | # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) 87 | 88 | return output 89 | 90 | 91 | def _gather(input_, dim): 92 | cp_world_size = get_context_parallel_world_size() 93 | 94 | # Bypass the function if context parallel is 1 95 | if cp_world_size == 1: 96 | return input_ 97 | 98 | group = get_context_parallel_group() 99 | cp_rank = get_context_parallel_rank() 100 | 101 | # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 102 | 103 | input_first_frame_ = input_.transpose(0, 104 | dim)[:1].transpose(0, 105 | dim).contiguous() 106 | if cp_rank == 0: 107 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 108 | 109 | tensor_list = [ 110 | torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim)) 111 | ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)] 112 | 113 | if cp_rank == 0: 114 | input_ = torch.cat([input_first_frame_, input_], dim=dim) 115 | 116 | tensor_list[cp_rank] = input_ 117 | torch.distributed.all_gather(tensor_list, input_, group=group) 118 | 119 | output = torch.cat(tensor_list, dim=dim).contiguous() 120 | 121 | # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) 122 | 123 | return output 124 | 125 | 126 | def _conv_split(input_, dim, kernel_size): 127 | cp_world_size = get_context_parallel_world_size() 128 | 129 | # Bypass the function if context parallel is 1 130 | if cp_world_size == 1: 131 | return input_ 132 | 133 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) 134 | 135 | cp_rank = get_context_parallel_rank() 136 | 137 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size 138 | 139 | if cp_rank == 0: 140 | output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose( 141 | dim, 0) 142 | else: 143 | output = input_.transpose( 144 | dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + 145 | kernel_size].transpose(dim, 0) 146 | output = output.contiguous() 147 | 148 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) 149 | 150 | return output 151 | 152 | 153 | def _conv_gather(input_, dim, kernel_size): 154 | cp_world_size = get_context_parallel_world_size() 155 | 156 | # 
Bypass the function if context parallel is 1 157 | if cp_world_size == 1: 158 | return input_ 159 | 160 | group = get_context_parallel_group() 161 | cp_rank = get_context_parallel_rank() 162 | 163 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 164 | 165 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose( 166 | 0, dim).contiguous() 167 | if cp_rank == 0: 168 | input_ = input_.transpose(0, dim)[kernel_size:].transpose( 169 | 0, dim).contiguous() 170 | else: 171 | input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose( 172 | 0, dim).contiguous() 173 | 174 | tensor_list = [ 175 | torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim)) 176 | ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)] 177 | if cp_rank == 0: 178 | input_ = torch.cat([input_first_kernel_, input_], dim=dim) 179 | 180 | tensor_list[cp_rank] = input_ 181 | torch.distributed.all_gather(tensor_list, input_, group=group) 182 | 183 | # Note: torch.cat already creates a contiguous tensor. 184 | output = torch.cat(tensor_list, dim=dim).contiguous() 185 | 186 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) 187 | 188 | return output 189 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .model import Decoder, Encoder, Model 4 | from .openaimodel import UNetModel 5 | from .sampling import BaseDiffusionSampler 6 | from .wrappers import OpenAIWrapper 7 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | 8 | 9 | class Denoiser(nn.Module): 10 | 11 | def __init__(self, weighting_config, scaling_config): 12 | super().__init__() 13 | 14 | self.weighting = instantiate_from_config(weighting_config) 15 | self.scaling = instantiate_from_config(scaling_config) 16 | 17 | def possibly_quantize_sigma(self, sigma): 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise): 21 | return c_noise 22 | 23 | def w(self, sigma): 24 | return self.weighting(sigma) 25 | 26 | def forward( 27 | self, 28 | network: nn.Module, 29 | input: torch.Tensor, 30 | sigma: torch.Tensor, 31 | cond: Dict, 32 | **additional_model_inputs, 33 | ) -> torch.Tensor: 34 | sigma = self.possibly_quantize_sigma(sigma) 35 | sigma_shape = sigma.shape 36 | sigma = append_dims(sigma, input.ndim) 37 | c_skip, c_out, c_in, c_noise = self.scaling(sigma, 38 | **additional_model_inputs) 39 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 40 | return network(input * c_in, c_noise, cond, ** 41 | additional_model_inputs) * c_out + input * c_skip 42 | 43 | 44 | class DiscreteDenoiser(Denoiser): 45 | 46 | def __init__( 47 | self, 48 | weighting_config, 49 | scaling_config, 50 | num_idx, 51 | discretization_config, 52 | do_append_zero=False, 53 | quantize_c_noise=True, 54 | flip=True, 55 | ): 56 | super().__init__(weighting_config, scaling_config) 57 | sigmas = instantiate_from_config(discretization_config)( 58 | num_idx, do_append_zero=do_append_zero, flip=flip) 59 | self.sigmas = 
sigmas 60 | # self.register_buffer("sigmas", sigmas) 61 | self.quantize_c_noise = quantize_c_noise 62 | 63 | def sigma_to_idx(self, sigma): 64 | dists = sigma - self.sigmas.to(sigma.device)[:, None] 65 | return dists.abs().argmin(dim=0).view(sigma.shape) 66 | 67 | def idx_to_sigma(self, idx): 68 | return self.sigmas.to(idx.device)[idx] 69 | 70 | def possibly_quantize_sigma(self, sigma): 71 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 72 | 73 | def possibly_quantize_c_noise(self, c_noise): 74 | if self.quantize_c_noise: 75 | return self.sigma_to_idx(c_noise) 76 | else: 77 | return c_noise 78 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | 9 | @abstractmethod 10 | def __call__( 11 | self, sigma: torch.Tensor 12 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 13 | pass 14 | 15 | 16 | class EDMScaling: 17 | 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__( 22 | self, sigma: torch.Tensor 23 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 24 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 25 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2)**0.5 26 | c_in = 1 / (sigma**2 + self.sigma_data**2)**0.5 27 | c_noise = 0.25 * sigma.log() 28 | return c_skip, c_out, c_in, c_noise 29 | 30 | 31 | class EpsScaling: 32 | 33 | def __call__( 34 | self, sigma: torch.Tensor 35 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 36 | c_skip = torch.ones_like(sigma, device=sigma.device) 37 | c_out = -sigma 38 | c_in = 1 / (sigma**2 + 1.0)**0.5 39 | c_noise = sigma.clone() 40 | return c_skip, c_out, c_in, c_noise 41 | 42 | 43 | class VScaling: 44 | 45 | def __call__( 46 | self, sigma: torch.Tensor 47 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 48 | c_skip = 1.0 / (sigma**2 + 1.0) 49 | c_out = -sigma / (sigma**2 + 1.0)**0.5 50 | c_in = 1.0 / (sigma**2 + 1.0)**0.5 51 | c_noise = sigma.clone() 52 | return c_skip, c_out, c_in, c_noise 53 | 54 | 55 | class VScalingWithEDMcNoise(DenoiserScaling): 56 | 57 | def __call__( 58 | self, sigma: torch.Tensor 59 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 60 | c_skip = 1.0 / (sigma**2 + 1.0) 61 | c_out = -sigma / (sigma**2 + 1.0)**0.5 62 | c_in = 1.0 / (sigma**2 + 1.0)**0.5 63 | c_noise = 0.25 * sigma.log() 64 | return c_skip, c_out, c_in, c_noise 65 | 66 | 67 | class VideoScaling: # similar to VScaling 68 | 69 | def __call__( 70 | self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs 71 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 72 | c_skip = alphas_cumprod_sqrt 73 | c_out = -((1 - alphas_cumprod_sqrt**2)**0.5) 74 | c_in = torch.ones_like(alphas_cumprod_sqrt, 75 | device=alphas_cumprod_sqrt.device) 76 | c_noise = additional_model_inputs['idx'].clone() 77 | return c_skip, c_out, c_in, c_noise 78 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | 6 | def __call__(self, sigma): 7 | return 
torch.ones_like(sigma, device=sigma.device) 8 | 9 | 10 | class EDMWeighting: 11 | 12 | def __init__(self, sigma_data=0.5): 13 | self.sigma_data = sigma_data 14 | 15 | def __call__(self, sigma): 16 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data)**2 17 | 18 | 19 | class VWeighting(EDMWeighting): 20 | 21 | def __init__(self): 22 | super().__init__(sigma_data=1.0) 23 | 24 | 25 | class EpsWeighting: 26 | 27 | def __call__(self, sigma): 28 | return sigma**-2.0 29 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps(num_substeps: int, 12 | max_step: int) -> np.ndarray: 13 | return np.linspace(max_step - 1, 0, num_substeps, 14 | endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | 19 | def __call__(self, 20 | n, 21 | do_append_zero=True, 22 | device='cpu', 23 | flip=False, 24 | return_idx=False): 25 | if return_idx: 26 | sigmas, idx = self.get_sigmas(n, 27 | device=device, 28 | return_idx=return_idx) 29 | else: 30 | sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) 31 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 32 | if return_idx: 33 | return sigmas if not flip else torch.flip(sigmas, (0, )), idx 34 | else: 35 | return sigmas if not flip else torch.flip(sigmas, (0, )) 36 | 37 | @abstractmethod 38 | def get_sigmas(self, n, device): 39 | pass 40 | 41 | 42 | class EDMDiscretization(Discretization): 43 | 44 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 45 | self.sigma_min = sigma_min 46 | self.sigma_max = sigma_max 47 | self.rho = rho 48 | 49 | def get_sigmas(self, n, device='cpu'): 50 | ramp = torch.linspace(0, 1, n, device=device) 51 | min_inv_rho = self.sigma_min**(1 / self.rho) 52 | max_inv_rho = self.sigma_max**(1 / self.rho) 53 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**self.rho 54 | return sigmas 55 | 56 | 57 | class LegacyDDPMDiscretization(Discretization): 58 | 59 | def __init__( 60 | self, 61 | linear_start=0.00085, 62 | linear_end=0.0120, 63 | num_timesteps=1000, 64 | ): 65 | super().__init__() 66 | self.num_timesteps = num_timesteps 67 | betas = make_beta_schedule('linear', 68 | num_timesteps, 69 | linear_start=linear_start, 70 | linear_end=linear_end) 71 | alphas = 1.0 - betas 72 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 73 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 74 | 75 | def get_sigmas(self, n, device='cpu'): 76 | if n < self.num_timesteps: 77 | timesteps = generate_roughly_equally_spaced_steps( 78 | n, self.num_timesteps) 79 | alphas_cumprod = self.alphas_cumprod[timesteps] 80 | elif n == self.num_timesteps: 81 | alphas_cumprod = self.alphas_cumprod 82 | else: 83 | raise ValueError 84 | 85 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 86 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod)**0.5 87 | return torch.flip(sigmas, (0, )) # sigma_t: 14.4 -> 0.029 88 | 89 | 90 | class ZeroSNRDDPMDiscretization(Discretization): 91 | 92 | def __init__( 93 | self, 94 | linear_start=0.00085, 95 | linear_end=0.0120, 96 | num_timesteps=1000, 97 | shift_scale=1.0, # noise schedule t_n -> 
t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) 98 | keep_start=False, 99 | post_shift=False, 100 | ): 101 | super().__init__() 102 | if keep_start and not post_shift: 103 | linear_start = linear_start / (shift_scale + 104 | (1 - shift_scale) * linear_start) 105 | self.num_timesteps = num_timesteps 106 | betas = make_beta_schedule('linear', 107 | num_timesteps, 108 | linear_start=linear_start, 109 | linear_end=linear_end) 110 | alphas = 1.0 - betas 111 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 112 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 113 | 114 | # SNR shift 115 | if not post_shift: 116 | self.alphas_cumprod = self.alphas_cumprod / ( 117 | shift_scale + (1 - shift_scale) * self.alphas_cumprod) 118 | 119 | self.post_shift = post_shift 120 | self.shift_scale = shift_scale 121 | 122 | def get_sigmas(self, n, device='cpu', return_idx=False): 123 | if n < self.num_timesteps: 124 | timesteps = generate_roughly_equally_spaced_steps( 125 | n, self.num_timesteps) 126 | alphas_cumprod = self.alphas_cumprod[timesteps] 127 | elif n == self.num_timesteps: 128 | alphas_cumprod = self.alphas_cumprod 129 | else: 130 | raise ValueError 131 | 132 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 133 | alphas_cumprod = to_torch(alphas_cumprod) 134 | alphas_cumprod_sqrt = alphas_cumprod.sqrt() 135 | alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() 136 | alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() 137 | 138 | alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T 139 | alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - 140 | alphas_cumprod_sqrt_T) 141 | 142 | if self.post_shift: 143 | alphas_cumprod_sqrt = ( 144 | alphas_cumprod_sqrt**2 / 145 | (self.shift_scale + 146 | (1 - self.shift_scale) * alphas_cumprod_sqrt**2))**0.5 147 | 148 | if return_idx: 149 | return torch.flip(alphas_cumprod_sqrt, (0, )), timesteps 150 | else: 151 | return torch.flip(alphas_cumprod_sqrt, 152 | (0, )) # sqrt(alpha_t): 0 -> 0.99 153 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from abc import ABC, abstractmethod 4 | from functools import partial 5 | from typing import Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | from ...util import append_dims, default, instantiate_from_config 11 | 12 | 13 | class Guider(ABC): 14 | 15 | @abstractmethod 16 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 17 | pass 18 | 19 | def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, 20 | uc: Dict) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG: 25 | """ 26 | implements parallelized CFG 27 | """ 28 | 29 | def __init__(self, scale, dyn_thresh_config=None): 30 | self.scale = scale 31 | scale_schedule = lambda scale, sigma: scale # independent of step 32 | self.scale_schedule = partial(scale_schedule, scale) 33 | self.dyn_thresh = instantiate_from_config( 34 | default( 35 | dyn_thresh_config, 36 | { 37 | 'target': 38 | 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding' 39 | }, 40 | )) 41 | 42 | def __call__(self, x, sigma, scale=None): 43 | x_u, x_c = x.chunk(2) 44 | scale_value = default(scale, self.scale_schedule(sigma)) 45 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 46 | return x_pred 47 | 48 | def prepare_inputs(self, x, s, 
c, uc): 49 | c_out = dict() 50 | 51 | for k in c: 52 | if k in ['vector', 'crossattn', 'concat']: 53 | c_out[k] = torch.cat((uc[k], c[k]), 0) 54 | else: 55 | assert c[k] == uc[k] 56 | c_out[k] = c[k] 57 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 58 | 59 | 60 | # class DynamicCFG(VanillaCFG): 61 | 62 | # def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): 63 | # super().__init__(scale, dyn_thresh_config) 64 | # scale_schedule = (lambda scale, sigma, step_index: 1 + scale * 65 | # (1 - math.cos(math.pi * 66 | # (step_index / num_steps)**exp)) / 2) 67 | # self.scale_schedule = partial(scale_schedule, scale) 68 | # self.dyn_thresh = instantiate_from_config( 69 | # default( 70 | # dyn_thresh_config, 71 | # { 72 | # 'target': 73 | # 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding' 74 | # }, 75 | # )) 76 | 77 | # def __call__(self, x, sigma, step_index, scale=None): 78 | # x_u, x_c = x.chunk(2) 79 | # scale_value = self.scale_schedule(sigma, step_index.item()) 80 | # x_pred = self.dyn_thresh(x_u, x_c, scale_value) 81 | # return x_pred 82 | 83 | 84 | class DynamicCFG(VanillaCFG): 85 | 86 | def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): 87 | super().__init__(scale, dyn_thresh_config) 88 | 89 | self.scale = scale 90 | self.num_steps = num_steps 91 | self.exp = exp 92 | scale_schedule = (lambda scale, sigma, step_index: 1 + scale * 93 | (1 - math.cos(math.pi * 94 | (step_index / num_steps)**exp)) / 2) 95 | 96 | #self.scale_schedule = partial(scale_schedule, scale) 97 | self.dyn_thresh = instantiate_from_config( 98 | default( 99 | dyn_thresh_config, 100 | { 101 | 'target': 102 | 'sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding' 103 | }, 104 | )) 105 | 106 | def scale_schedule_dy(self, sigma, step_index): 107 | # print(self.scale) 108 | return 1 + self.scale * ( 109 | 1 - math.cos(math.pi * 110 | (step_index / self.num_steps)**self.exp)) / 2 111 | 112 | def __call__(self, x, sigma, step_index, scale=None): 113 | x_u, x_c = x.chunk(2) 114 | scale_value = self.scale_schedule_dy(sigma, step_index.item()) 115 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 116 | return x_pred 117 | 118 | 119 | class IdentityGuider: 120 | 121 | def __call__(self, x, sigma): 122 | return x 123 | 124 | def prepare_inputs(self, x, s, c, uc): 125 | c_out = dict() 126 | 127 | for k in c: 128 | c_out[k] = c[k] 129 | 130 | return x, s, c_out 131 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from omegaconf import ListConfig 6 | 7 | from sat import mpu 8 | 9 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 10 | from ...util import append_dims, instantiate_from_config 11 | 12 | 13 | class StandardDiffusionLoss(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | sigma_sampler_config, 18 | type='l2', 19 | offset_noise_level=0.0, 20 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, 21 | ): 22 | super().__init__() 23 | 24 | assert type in ['l2', 'l1', 'lpips'] 25 | 26 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 27 | 28 | self.type = type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if type == 'lpips': 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if 
isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def __call__(self, network, denoiser, conditioner, input, batch): 43 | cond = conditioner(batch) 44 | additional_model_inputs = { 45 | key: batch[key] 46 | for key in self.batch2model_keys.intersection(batch) 47 | } 48 | 49 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 50 | noise = torch.randn_like(input) 51 | if self.offset_noise_level > 0.0: 52 | noise = (noise + append_dims( 53 | torch.randn(input.shape[0]).to(input.device), input.ndim) * 54 | self.offset_noise_level) 55 | noise = noise.to(input.dtype) 56 | noised_input = input.float() + noise * append_dims(sigmas, input.ndim) 57 | model_output = denoiser(network, noised_input, sigmas, cond, 58 | **additional_model_inputs) 59 | w = append_dims(denoiser.w(sigmas), input.ndim) 60 | return self.get_loss(model_output, input, w) 61 | 62 | def get_loss(self, model_output, target, w): 63 | if self.type == 'l2': 64 | return torch.mean( 65 | (w * (model_output - target)**2).reshape(target.shape[0], 66 | -1), 1) 67 | elif self.type == 'l1': 68 | return torch.mean((w * (model_output - target).abs()).reshape( 69 | target.shape[0], -1), 1) 70 | elif self.type == 'lpips': 71 | loss = self.lpips(model_output, target).reshape(-1) 72 | return loss 73 | 74 | 75 | class VideoDiffusionLoss(StandardDiffusionLoss): 76 | 77 | def __init__(self, 78 | block_scale=None, 79 | block_size=None, 80 | min_snr_value=None, 81 | fixed_frames=0, 82 | **kwargs): 83 | self.fixed_frames = fixed_frames 84 | self.block_scale = block_scale 85 | self.block_size = block_size 86 | self.min_snr_value = min_snr_value 87 | super().__init__(**kwargs) 88 | 89 | def __call__(self, network, denoiser, conditioner, input, batch): 90 | cond = conditioner(batch) 91 | additional_model_inputs = { 92 | key: batch[key] 93 | for key in self.batch2model_keys.intersection(batch) 94 | } 95 | 96 | alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], 97 | return_idx=True) 98 | #tensor([0.8585]) 99 | 100 | if 'ref_noise_step' in self.share_cache: 101 | 102 | print(self.share_cache['ref_noise_step']) 103 | ref_noise_step = self.share_cache['ref_noise_step'] 104 | ref_alphas_cumprod_sqrt = self.sigma_sampler.idx_to_sigma( 105 | torch.zeros(input.shape[0]).fill_(ref_noise_step).long()) 106 | ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(input.device) 107 | ref_x = self.share_cache['ref_x'] 108 | ref_noise = torch.randn_like(ref_x) 109 | 110 | # *0.8505 + noise * 0.5128 sqrt(1-0.8505^2)**0.5 111 | ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \ 112 | + ref_noise * append_dims( 113 | (1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim 114 | ) 115 | self.share_cache['ref_x'] = ref_noised_input 116 | 117 | alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) 118 | idx = idx.to(input.device) 119 | 120 | noise = torch.randn_like(input) 121 | 122 | # broadcast noise 123 | mp_size = mpu.get_model_parallel_world_size() 124 | global_rank = torch.distributed.get_rank() // mp_size 125 | src = global_rank * mp_size 126 | torch.distributed.broadcast(idx, 127 | src=src, 128 | group=mpu.get_model_parallel_group()) 129 | torch.distributed.broadcast(noise, 130 | src=src, 131 | group=mpu.get_model_parallel_group()) 132 | torch.distributed.broadcast(alphas_cumprod_sqrt, 133 | src=src, 134 | group=mpu.get_model_parallel_group()) 135 | 136 | additional_model_inputs['idx'] = idx 137 | 138 | if self.offset_noise_level 
> 0.0: 139 | noise = (noise + append_dims( 140 | torch.randn(input.shape[0]).to(input.device), input.ndim) * 141 | self.offset_noise_level) 142 | 143 | noised_input = input.float() * append_dims( 144 | alphas_cumprod_sqrt, input.ndim) + noise * append_dims( 145 | (1 - alphas_cumprod_sqrt**2)**0.5, input.ndim) 146 | 147 | if 'concat_images' in batch.keys(): 148 | cond['concat'] = batch['concat_images'] 149 | 150 | # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx']) 151 | model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, 152 | cond, **additional_model_inputs) 153 | w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred 154 | 155 | if self.min_snr_value is not None: 156 | w = min(w, self.min_snr_value) 157 | return self.get_loss(model_output, input, w) 158 | 159 | def get_loss(self, model_output, target, w): 160 | if self.type == 'l2': 161 | # model_output.shape 162 | # torch.Size([1, 2, 16, 60, 88]) 163 | return torch.mean( 164 | (w * (model_output - target)**2).reshape(target.shape[0], 165 | -1), 1) 166 | elif self.type == 'l1': 167 | return torch.mean((w * (model_output - target).abs()).reshape( 168 | target.shape[0], -1), 1) 169 | elif self.type == 'lpips': 170 | loss = self.lpips(model_output, target).reshape(-1) 171 | return loss 172 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from scipy import integrate 4 | 5 | from ...util import append_dims 6 | 7 | 8 | class NoDynamicThresholding: 9 | 10 | def __call__(self, uncond, cond, scale): 11 | scale = append_dims(scale, cond.ndim) if isinstance( 12 | scale, torch.Tensor) else scale 13 | return uncond + scale * (cond - uncond) 14 | 15 | 16 | class StaticThresholding: 17 | 18 | def __call__(self, uncond, cond, scale): 19 | result = uncond + scale * (cond - uncond) 20 | result = torch.clamp(result, min=-1.0, max=1.0) 21 | return result 22 | 23 | 24 | def dynamic_threshold(x, p=0.95): 25 | N, T, C, H, W = x.shape 26 | x = rearrange(x, 'n t c h w -> n c (t h w)') 27 | l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), 28 | dim=-1, 29 | keepdim=True) 30 | s = torch.maximum(-l, r) 31 | threshold_mask = (s > 1).expand(-1, -1, H * W * T) 32 | if threshold_mask.any(): 33 | x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) 34 | x = rearrange(x, 'n c (t h w) -> n t c h w', t=T, h=H, w=W) 35 | return x 36 | 37 | 38 | def dynamic_thresholding2(x0): 39 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 40 | origin_dtype = x0.dtype 41 | x0 = x0.to(torch.float32) 42 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 43 | s = append_dims(torch.maximum(s, 44 | torch.ones_like(s).to(s.device)), x0.dim()) 45 | x0 = torch.clamp(x0, -s, s) # / s 46 | return x0.to(origin_dtype) 47 | 48 | 49 | def latent_dynamic_thresholding(x0): 50 | p = 0.9995 51 | origin_dtype = x0.dtype 52 | x0 = x0.to(torch.float32) 53 | s = torch.quantile(torch.abs(x0), p, dim=2) 54 | s = append_dims(s, x0.dim()) 55 | x0 = torch.clamp(x0, -s, s) / s 56 | return x0.to(origin_dtype) 57 | 58 | 59 | def dynamic_thresholding3(x0): 60 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
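    # Imagen-style dynamic thresholding: take the p-th percentile s of |x0| per sample
    # (clamped to be at least 1), then clip x0 to [-s, s]; the rescaling division by s
    # is left commented out below, so values are clipped but not renormalized.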
61 | origin_dtype = x0.dtype 62 | x0 = x0.to(torch.float32) 63 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 64 | s = append_dims(torch.maximum(s, 65 | torch.ones_like(s).to(s.device)), x0.dim()) 66 | x0 = torch.clamp(x0, -s, s) # / s 67 | return x0.to(origin_dtype) 68 | 69 | 70 | class DynamicThresholding: 71 | 72 | def __call__(self, uncond, cond, scale): 73 | mean = uncond.mean() 74 | std = uncond.std() 75 | result = uncond + scale * (cond - uncond) 76 | result_mean, result_std = result.mean(), result.std() 77 | result = (result - result_mean) / result_std * std 78 | # result = dynamic_thresholding3(result) 79 | return result 80 | 81 | 82 | class DynamicThresholdingV1: 83 | 84 | def __init__(self, scale_factor): 85 | self.scale_factor = scale_factor 86 | 87 | def __call__(self, uncond, cond, scale): 88 | result = uncond + scale * (cond - uncond) 89 | unscaled_result = result / self.scale_factor 90 | B, T, C, H, W = unscaled_result.shape 91 | flattened = rearrange(unscaled_result, 'b t c h w -> b c (t h w)') 92 | means = flattened.mean(dim=2).unsqueeze(2) 93 | recentered = flattened - means 94 | magnitudes = recentered.abs().max() 95 | normalized = recentered / magnitudes 96 | thresholded = latent_dynamic_thresholding(normalized) 97 | denormalized = thresholded * magnitudes 98 | uncentered = denormalized + means 99 | unflattened = rearrange(uncentered, 100 | 'b c (t h w) -> b t c h w', 101 | t=T, 102 | h=H, 103 | w=W) 104 | scaled_result = unflattened * self.scale_factor 105 | return scaled_result 106 | 107 | 108 | class DynamicThresholdingV2: 109 | 110 | def __call__(self, uncond, cond, scale): 111 | B, T, C, H, W = uncond.shape 112 | diff = cond - uncond 113 | mim_target = uncond + diff * 4.0 114 | cfg_target = uncond + diff * 8.0 115 | 116 | mim_flattened = rearrange(mim_target, 'b t c h w -> b c (t h w)') 117 | cfg_flattened = rearrange(cfg_target, 'b t c h w -> b c (t h w)') 118 | mim_means = mim_flattened.mean(dim=2).unsqueeze(2) 119 | cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) 120 | mim_centered = mim_flattened - mim_means 121 | cfg_centered = cfg_flattened - cfg_means 122 | 123 | mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) 124 | cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) 125 | 126 | cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref 127 | 128 | result = cfg_renormalized + cfg_means 129 | unflattened = rearrange(result, 130 | 'b c (t h w) -> b t c h w', 131 | t=T, 132 | h=H, 133 | w=W) 134 | 135 | return unflattened 136 | 137 | 138 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 139 | if order - 1 > i: 140 | raise ValueError(f'Order {order} too high for step {i}') 141 | 142 | def fn(tau): 143 | prod = 1.0 144 | for k in range(order): 145 | if j == k: 146 | continue 147 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 148 | return prod 149 | 150 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 151 | 152 | 153 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 154 | if not eta: 155 | return sigma_to, 0.0 156 | sigma_up = torch.minimum( 157 | sigma_to, 158 | eta * (sigma_to**2 * 159 | (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5, 160 | ) 161 | sigma_down = (sigma_to**2 - sigma_up**2)**0.5 162 | return sigma_down, sigma_up 163 | 164 | 165 | def to_d(x, sigma, denoised): 166 | return (x - denoised) / append_dims(sigma, x.ndim) 167 | 168 | 169 | def to_neg_log_sigma(sigma): 170 | return sigma.log().neg() 171 | 172 | 173 | def to_sigma(neg_log_sigma): 174 | return 
neg_log_sigma.neg().exp() 175 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed 3 | 4 | from sat import mpu 5 | 6 | from ...util import default, instantiate_from_config 7 | 8 | 9 | class EDMSampling: 10 | 11 | def __init__(self, p_mean=-1.2, p_std=1.2): 12 | self.p_mean = p_mean 13 | self.p_std = p_std 14 | 15 | def __call__(self, n_samples, rand=None): 16 | log_sigma = self.p_mean + self.p_std * default( 17 | rand, torch.randn((n_samples, ))) 18 | return log_sigma.exp() 19 | 20 | 21 | class DiscreteSampling: 22 | 23 | def __init__(self, 24 | discretization_config, 25 | num_idx, 26 | do_append_zero=False, 27 | flip=True, 28 | uniform_sampling=False): 29 | self.num_idx = num_idx 30 | self.sigmas = instantiate_from_config(discretization_config)( 31 | num_idx, do_append_zero=do_append_zero, flip=flip) 32 | world_size = mpu.get_data_parallel_world_size() 33 | self.uniform_sampling = uniform_sampling 34 | if self.uniform_sampling: 35 | i = 1 36 | while True: 37 | if world_size % i != 0 or num_idx % (world_size // i) != 0: 38 | i += 1 39 | else: 40 | self.group_num = world_size // i 41 | break 42 | 43 | assert self.group_num > 0 44 | assert world_size % self.group_num == 0 45 | self.group_width = world_size // self.group_num # the number of rank in one group 46 | self.sigma_interval = self.num_idx // self.group_num 47 | 48 | def idx_to_sigma(self, idx): 49 | return self.sigmas[idx] 50 | 51 | def __call__(self, n_samples, rand=None, return_idx=False): 52 | if self.uniform_sampling: 53 | rank = mpu.get_data_parallel_rank() 54 | group_index = rank // self.group_width 55 | idx = default( 56 | rand, 57 | torch.randint(group_index * self.sigma_interval, 58 | (group_index + 1) * self.sigma_interval, 59 | (n_samples, )), 60 | ) 61 | else: 62 | idx = default( 63 | rand, 64 | torch.randint(0, self.num_idx, (n_samples, )), 65 | ) 66 | if return_idx: 67 | return self.idx_to_sigma(idx), idx 68 | else: 69 | return self.idx_to_sigma(idx) 70 | 71 | 72 | class PartialDiscreteSampling: 73 | 74 | def __init__(self, 75 | discretization_config, 76 | total_num_idx, 77 | partial_num_idx, 78 | do_append_zero=False, 79 | flip=True): 80 | self.total_num_idx = total_num_idx 81 | self.partial_num_idx = partial_num_idx 82 | self.sigmas = instantiate_from_config(discretization_config)( 83 | total_num_idx, do_append_zero=do_append_zero, flip=flip) 84 | 85 | def idx_to_sigma(self, idx): 86 | return self.sigmas[idx] 87 | 88 | def __call__(self, n_samples, rand=None): 89 | idx = default( 90 | rand, 91 | # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), 92 | torch.randint(0, self.partial_num_idx, (n_samples, )), 93 | ) 94 | return self.idx_to_sigma(idx) 95 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = 'sgm.modules.diffusionmodules.wrappers.OpenAIWrapper' 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | 10 | def __init__(self, 11 | diffusion_model, 12 | compile_model: bool = False, 13 | dtype: torch.dtype = torch.float32): 14 | super().__init__() 15 | compile = (torch.compile if 16 | 
(version.parse(torch.__version__) >= version.parse('2.0.0')) 17 | and compile_model else lambda x: x) 18 | self.diffusion_model = compile(diffusion_model) 19 | self.dtype = dtype 20 | 21 | def forward(self, *args, **kwargs): 22 | return self.diffusion_model(*args, **kwargs) 23 | 24 | 25 | class OpenAIWrapper(IdentityWrapper): 26 | 27 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, 28 | **kwargs) -> torch.Tensor: 29 | for key in c: 30 | c[key] = c[key].to(self.dtype) 31 | 32 | if x.dim() == 4: 33 | x = torch.cat((x, c.get('concat', 34 | torch.Tensor([]).type_as(x))), 35 | dim=1) 36 | elif x.dim() == 5: 37 | x = torch.cat((x, c.get('concat', 38 | torch.Tensor([]).type_as(x))), 39 | dim=2) 40 | else: 41 | raise ValueError('Input tensor must be 4D or 5D') 42 | 43 | return self.diffusion_model( 44 | x, 45 | timesteps=t, 46 | context=c.get('crossattn', None), 47 | y=c.get('vector', None), 48 | **kwargs, 49 | ) 50 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | 16 | def __init__(self, value): 17 | self.value = value 18 | 19 | def sample(self): 20 | return self.value 21 | 22 | def mode(self): 23 | return self.value 24 | 25 | 26 | class DiagonalGaussianDistribution: 27 | 28 | def __init__(self, parameters, deterministic=False): 29 | self.parameters = parameters 30 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 31 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 32 | self.deterministic = deterministic 33 | self.std = torch.exp(0.5 * self.logvar) 34 | self.var = torch.exp(self.logvar) 35 | if self.deterministic: 36 | self.var = self.std = torch.zeros_like( 37 | self.mean).to(device=self.parameters.device) 38 | 39 | def sample(self): 40 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 41 | # device=self.parameters.device 42 | # ) 43 | x = self.mean + self.std * torch.randn_like( 44 | self.mean).to(device=self.parameters.device) 45 | return x 46 | 47 | def kl(self, other=None): 48 | if self.deterministic: 49 | return torch.Tensor([0.0]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.sum( 53 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 54 | dim=[1, 2, 3], 55 | ) 56 | else: 57 | return 0.5 * torch.sum( 58 | torch.pow(self.mean - other.mean, 2) / other.var + 59 | self.var / other.var - 1.0 - self.logvar + other.logvar, 60 | dim=[1, 2, 3], 61 | ) 62 | 63 | def nll(self, sample, dims=[1, 2, 3]): 64 | if self.deterministic: 65 | return torch.Tensor([0.0]) 66 | logtwopi = np.log(2.0 * np.pi) 67 | return 0.5 * torch.sum( 68 | logtwopi + self.logvar + 69 | torch.pow(sample - self.mean, 2) / self.var, 70 | dim=dims, 71 | ) 72 | 73 | def mode(self): 74 | return self.mean 75 | 76 | 77 | def normal_kl(mean1, logvar1, mean2, logvar2): 78 | """ 79 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 80 | Compute the KL divergence between two gaussians. 81 | Shapes are automatically broadcasted, so batches can be compared to 82 | scalars, among other use cases. 83 | """ 84 | tensor = None 85 | for obj in (mean1, logvar1, mean2, logvar2): 86 | if isinstance(obj, torch.Tensor): 87 | tensor = obj 88 | break 89 | assert tensor is not None, 'at least one argument must be a Tensor' 90 | 91 | # Force variances to be Tensors. Broadcasting helps convert scalars to 92 | # Tensors, but it does not work for torch.exp(). 93 | logvar1, logvar2 = (x if isinstance(x, torch.Tensor) else 94 | torch.tensor(x).to(tensor) for x in (logvar1, logvar2)) 95 | 96 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + 97 | ((mean1 - mean2)**2) * torch.exp(-logvar2)) 98 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | 7 | def __init__(self, model, decay=0.9999, use_num_upates=True): 8 | super().__init__() 9 | if decay < 0.0 or decay > 1.0: 10 | raise ValueError('Decay must be between 0 and 1') 11 | 12 | self.m_name2s_name = {} 13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 14 | self.register_buffer( 15 | 'num_updates', 16 | torch.tensor(0, dtype=torch.int) 17 | if use_num_upates else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace('.', '') 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, 39 | (1 + self.num_updates) / (10 + self.num_updates)) 40 | 41 | one_minus_decay = 1.0 - decay 42 | 43 | with torch.no_grad(): 44 | m_param = dict(model.named_parameters()) 45 | shadow_params = dict(self.named_buffers()) 46 | 47 | for key in m_param: 48 | if m_param[key].requires_grad: 49 | sname = self.m_name2s_name[key] 50 | shadow_params[sname] = shadow_params[sname].type_as( 51 | m_param[key]) 52 | shadow_params[sname].sub_( 53 | one_minus_decay * 54 | (shadow_params[sname] - m_param[key])) 55 | else: 56 | assert not key in self.m_name2s_name 57 | 58 | def copy_to(self, model): 59 | m_param = dict(model.named_parameters()) 60 | shadow_params = dict(self.named_buffers()) 61 | for key in m_param: 62 | if m_param[key].requires_grad: 63 | m_param[key].data.copy_( 64 | shadow_params[self.m_name2s_name[key]].data) 65 | else: 66 | assert not key in self.m_name2s_name 67 | 68 | def store(self, parameters): 69 | """ 70 | Save the current parameters for restoring later. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | temporarily stored. 74 | """ 75 | self.collected_params = [param.clone() for param in parameters] 76 | 77 | def restore(self, parameters): 78 | """ 79 | Restore the parameters stored with the `store` method. 
80 | Useful to validate the model with EMA parameters without affecting the 81 | original optimization process. Store the parameters before the 82 | `copy_to` method. After validation (or model saving), use this to 83 | restore the former parameters. 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/FlashVideo/8de4ae2b2c468e78116ad821e8bcd2339282e2b8/flashvideo/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /flashvideo/sgm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import nullcontext 3 | from functools import partial 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import kornia 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange, repeat 11 | from omegaconf import ListConfig 12 | from torch.utils.checkpoint import checkpoint 13 | from transformers import T5EncoderModel, T5Tokenizer 14 | 15 | from ...util import (append_dims, autocast, count_params, default, 16 | disabled_train, expand_dims_like, instantiate_from_config) 17 | 18 | 19 | class AbstractEmbModel(nn.Module): 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self._is_trainable = None 24 | self._ucg_rate = None 25 | self._input_key = None 26 | 27 | @property 28 | def is_trainable(self) -> bool: 29 | return self._is_trainable 30 | 31 | @property 32 | def ucg_rate(self) -> Union[float, torch.Tensor]: 33 | return self._ucg_rate 34 | 35 | @property 36 | def input_key(self) -> str: 37 | return self._input_key 38 | 39 | @is_trainable.setter 40 | def is_trainable(self, value: bool): 41 | self._is_trainable = value 42 | 43 | @ucg_rate.setter 44 | def ucg_rate(self, value: Union[float, torch.Tensor]): 45 | self._ucg_rate = value 46 | 47 | @input_key.setter 48 | def input_key(self, value: str): 49 | self._input_key = value 50 | 51 | @is_trainable.deleter 52 | def is_trainable(self): 53 | del self._is_trainable 54 | 55 | @ucg_rate.deleter 56 | def ucg_rate(self): 57 | del self._ucg_rate 58 | 59 | @input_key.deleter 60 | def input_key(self): 61 | del self._input_key 62 | 63 | 64 | class GeneralConditioner(nn.Module): 65 | OUTPUT_DIM2KEYS = {2: 'vector', 3: 'crossattn', 4: 'concat', 5: 'concat'} 66 | KEY2CATDIM = {'vector': 1, 'crossattn': 2, 'concat': 1} 67 | 68 | def __init__(self, 69 | emb_models: Union[List, ListConfig], 70 | cor_embs=[], 71 | cor_p=[]): 72 | super().__init__() 73 | embedders = [] 74 | for n, embconfig in enumerate(emb_models): 75 | embedder = instantiate_from_config(embconfig) 76 | assert isinstance( 77 | embedder, AbstractEmbModel 78 | ), f'embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel' 79 | embedder.is_trainable = embconfig.get('is_trainable', False) 80 | embedder.ucg_rate = embconfig.get('ucg_rate', 0.0) 81 | if not embedder.is_trainable: 82 | embedder.train = disabled_train 83 | for param in embedder.parameters(): 84 | param.requires_grad = False 85 | embedder.eval() 86 | print( 87 | f'Initialized 
embedder #{n}: {embedder.__class__.__name__} ' 88 | f'with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}' 89 | ) 90 | 91 | if 'input_key' in embconfig: 92 | embedder.input_key = embconfig['input_key'] 93 | elif 'input_keys' in embconfig: 94 | embedder.input_keys = embconfig['input_keys'] 95 | else: 96 | raise KeyError( 97 | f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" 98 | ) 99 | 100 | embedder.legacy_ucg_val = embconfig.get('legacy_ucg_value', None) 101 | if embedder.legacy_ucg_val is not None: 102 | embedder.ucg_prng = np.random.RandomState() 103 | 104 | embedders.append(embedder) 105 | self.embedders = nn.ModuleList(embedders) 106 | 107 | if len(cor_embs) > 0: 108 | assert len(cor_p) == 2**len(cor_embs) 109 | self.cor_embs = cor_embs 110 | self.cor_p = cor_p 111 | 112 | def possibly_get_ucg_val(self, embedder: AbstractEmbModel, 113 | batch: Dict) -> Dict: 114 | assert embedder.legacy_ucg_val is not None 115 | p = embedder.ucg_rate 116 | val = embedder.legacy_ucg_val 117 | for i in range(len(batch[embedder.input_key])): 118 | if embedder.ucg_prng.choice(2, p=[1 - p, p]): 119 | batch[embedder.input_key][i] = val 120 | return batch 121 | 122 | def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, 123 | cond_or_not) -> Dict: 124 | assert embedder.legacy_ucg_val is not None 125 | val = embedder.legacy_ucg_val 126 | for i in range(len(batch[embedder.input_key])): 127 | if cond_or_not[i]: 128 | batch[embedder.input_key][i] = val 129 | return batch 130 | 131 | def get_single_embedding( 132 | self, 133 | embedder, 134 | batch, 135 | output, 136 | cond_or_not: Optional[np.ndarray] = None, 137 | force_zero_embeddings: Optional[List] = None, 138 | ): 139 | embedding_context = nullcontext if embedder.is_trainable else torch.no_grad 140 | with embedding_context(): 141 | if hasattr(embedder, 'input_key') and (embedder.input_key 142 | is not None): 143 | if embedder.legacy_ucg_val is not None: 144 | if cond_or_not is None: 145 | batch = self.possibly_get_ucg_val(embedder, batch) 146 | else: 147 | batch = self.surely_get_ucg_val( 148 | embedder, batch, cond_or_not) 149 | emb_out = embedder(batch[embedder.input_key]) 150 | elif hasattr(embedder, 'input_keys'): 151 | emb_out = embedder(*[batch[k] for k in embedder.input_keys]) 152 | assert isinstance( 153 | emb_out, (torch.Tensor, list, tuple) 154 | ), f'encoder outputs must be tensors or a sequence, but got {type(emb_out)}' 155 | if not isinstance(emb_out, (list, tuple)): 156 | emb_out = [emb_out] 157 | for emb in emb_out: 158 | out_key = self.OUTPUT_DIM2KEYS[emb.dim()] 159 | if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: 160 | if cond_or_not is None: 161 | emb = (expand_dims_like( 162 | torch.bernoulli( 163 | (1.0 - embedder.ucg_rate) * 164 | torch.ones(emb.shape[0], device=emb.device)), 165 | emb, 166 | ) * emb) 167 | else: 168 | emb = (expand_dims_like( 169 | torch.tensor(1 - cond_or_not, 170 | dtype=emb.dtype, 171 | device=emb.device), 172 | emb, 173 | ) * emb) 174 | if hasattr(embedder, 'input_key' 175 | ) and embedder.input_key in force_zero_embeddings: 176 | emb = torch.zeros_like(emb) 177 | if out_key in output: 178 | output[out_key] = torch.cat((output[out_key], emb), 179 | self.KEY2CATDIM[out_key]) 180 | else: 181 | output[out_key] = emb 182 | return output 183 | 184 | def forward(self, 185 | batch: Dict, 186 | force_zero_embeddings: Optional[List] = None) -> Dict: 187 | output = dict() 188 | if force_zero_embeddings is None: 189 | 
force_zero_embeddings = [] 190 | 191 | if len(self.cor_embs) > 0: 192 | batch_size = len(batch[list(batch.keys())[0]]) 193 | rand_idx = np.random.choice(len(self.cor_p), 194 | size=(batch_size, ), 195 | p=self.cor_p) 196 | for emb_idx in self.cor_embs: 197 | cond_or_not = rand_idx % 2 198 | rand_idx //= 2 199 | output = self.get_single_embedding( 200 | self.embedders[emb_idx], 201 | batch, 202 | output=output, 203 | cond_or_not=cond_or_not, 204 | force_zero_embeddings=force_zero_embeddings, 205 | ) 206 | 207 | for i, embedder in enumerate(self.embedders): 208 | if i in self.cor_embs: 209 | continue 210 | output = self.get_single_embedding( 211 | embedder, 212 | batch, 213 | output=output, 214 | force_zero_embeddings=force_zero_embeddings) 215 | return output 216 | 217 | def get_unconditional_conditioning(self, 218 | batch_c, 219 | batch_uc=None, 220 | force_uc_zero_embeddings=None): 221 | if force_uc_zero_embeddings is None: 222 | force_uc_zero_embeddings = [] 223 | ucg_rates = list() 224 | for embedder in self.embedders: 225 | ucg_rates.append(embedder.ucg_rate) 226 | embedder.ucg_rate = 0.0 227 | cor_embs = self.cor_embs 228 | cor_p = self.cor_p 229 | self.cor_embs = [] 230 | self.cor_p = [] 231 | 232 | c = self(batch_c) 233 | uc = self(batch_c if batch_uc is None else batch_uc, 234 | force_uc_zero_embeddings) 235 | 236 | for embedder, rate in zip(self.embedders, ucg_rates): 237 | embedder.ucg_rate = rate 238 | self.cor_embs = cor_embs 239 | self.cor_p = cor_p 240 | 241 | return c, uc 242 | 243 | 244 | class FrozenT5Embedder(AbstractEmbModel): 245 | """Uses the T5 transformer encoder for text""" 246 | 247 | def __init__( 248 | self, 249 | model_dir='google/t5-v1_1-xxl', 250 | device='cuda', 251 | max_length=77, 252 | freeze=True, 253 | cache_dir=None, 254 | ): 255 | super().__init__() 256 | if model_dir != 'google/t5-v1_1-xxl': 257 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir) 258 | self.transformer = T5EncoderModel.from_pretrained(model_dir) 259 | else: 260 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir, 261 | cache_dir=cache_dir) 262 | self.transformer = T5EncoderModel.from_pretrained( 263 | model_dir, cache_dir=cache_dir) 264 | self.device = device 265 | self.max_length = max_length 266 | if freeze: 267 | self.freeze() 268 | 269 | def freeze(self): 270 | self.transformer = self.transformer.eval() 271 | 272 | for param in self.parameters(): 273 | param.requires_grad = False 274 | 275 | # @autocast 276 | def forward(self, text): 277 | batch_encoding = self.tokenizer( 278 | text, 279 | truncation=True, 280 | max_length=self.max_length, 281 | return_length=True, 282 | return_overflowing_tokens=False, 283 | padding='max_length', 284 | return_tensors='pt', 285 | ) 286 | tokens = batch_encoding['input_ids'].to(self.device) 287 | with torch.autocast('cuda', enabled=False): 288 | outputs = self.transformer(input_ids=tokens) 289 | z = outputs.last_hidden_state 290 | return z 291 | 292 | def encode(self, text): 293 | return self(text) 294 | -------------------------------------------------------------------------------- /flashvideo/sgm/modules/video_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..modules.attention import * 4 | from ..modules.diffusionmodules.util import (AlphaBlender, linear, 5 | timestep_embedding) 6 | 7 | 8 | class TimeMixSequential(nn.Sequential): 9 | 10 | def forward(self, x, context=None, timesteps=None): 11 | for layer in self: 12 | x = layer(x, context, timesteps) 13 | 14 
| return x 15 | 16 | 17 | class VideoTransformerBlock(nn.Module): 18 | ATTENTION_MODES = { 19 | 'softmax': CrossAttention, 20 | 'softmax-xformers': MemoryEfficientCrossAttention, 21 | } 22 | 23 | def __init__( 24 | self, 25 | dim, 26 | n_heads, 27 | d_head, 28 | dropout=0.0, 29 | context_dim=None, 30 | gated_ff=True, 31 | checkpoint=True, 32 | timesteps=None, 33 | ff_in=False, 34 | inner_dim=None, 35 | attn_mode='softmax', 36 | disable_self_attn=False, 37 | disable_temporal_crossattention=False, 38 | switch_temporal_ca_to_sa=False, 39 | ): 40 | super().__init__() 41 | 42 | attn_cls = self.ATTENTION_MODES[attn_mode] 43 | 44 | self.ff_in = ff_in or inner_dim is not None 45 | if inner_dim is None: 46 | inner_dim = dim 47 | 48 | assert int(n_heads * d_head) == inner_dim 49 | 50 | self.is_res = inner_dim == dim 51 | 52 | if self.ff_in: 53 | self.norm_in = nn.LayerNorm(dim) 54 | self.ff_in = FeedForward(dim, 55 | dim_out=inner_dim, 56 | dropout=dropout, 57 | glu=gated_ff) 58 | 59 | self.timesteps = timesteps 60 | self.disable_self_attn = disable_self_attn 61 | if self.disable_self_attn: 62 | self.attn1 = attn_cls( 63 | query_dim=inner_dim, 64 | heads=n_heads, 65 | dim_head=d_head, 66 | context_dim=context_dim, 67 | dropout=dropout, 68 | ) # is a cross-attention 69 | else: 70 | self.attn1 = attn_cls(query_dim=inner_dim, 71 | heads=n_heads, 72 | dim_head=d_head, 73 | dropout=dropout) # is a self-attention 74 | 75 | self.ff = FeedForward(inner_dim, 76 | dim_out=dim, 77 | dropout=dropout, 78 | glu=gated_ff) 79 | 80 | if disable_temporal_crossattention: 81 | if switch_temporal_ca_to_sa: 82 | raise ValueError 83 | else: 84 | self.attn2 = None 85 | else: 86 | self.norm2 = nn.LayerNorm(inner_dim) 87 | if switch_temporal_ca_to_sa: 88 | self.attn2 = attn_cls(query_dim=inner_dim, 89 | heads=n_heads, 90 | dim_head=d_head, 91 | dropout=dropout) # is a self-attention 92 | else: 93 | self.attn2 = attn_cls( 94 | query_dim=inner_dim, 95 | context_dim=context_dim, 96 | heads=n_heads, 97 | dim_head=d_head, 98 | dropout=dropout, 99 | ) # is self-attn if context is none 100 | 101 | self.norm1 = nn.LayerNorm(inner_dim) 102 | self.norm3 = nn.LayerNorm(inner_dim) 103 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 104 | 105 | self.checkpoint = checkpoint 106 | if self.checkpoint: 107 | print(f'{self.__class__.__name__} is using checkpointing') 108 | 109 | def forward(self, 110 | x: torch.Tensor, 111 | context: torch.Tensor = None, 112 | timesteps: int = None) -> torch.Tensor: 113 | if self.checkpoint: 114 | return checkpoint(self._forward, x, context, timesteps) 115 | else: 116 | return self._forward(x, context, timesteps=timesteps) 117 | 118 | def _forward(self, x, context=None, timesteps=None): 119 | assert self.timesteps or timesteps 120 | assert not (self.timesteps 121 | and timesteps) or self.timesteps == timesteps 122 | timesteps = self.timesteps or timesteps 123 | B, S, C = x.shape 124 | x = rearrange(x, '(b t) s c -> (b s) t c', t=timesteps) 125 | 126 | if self.ff_in: 127 | x_skip = x 128 | x = self.ff_in(self.norm_in(x)) 129 | if self.is_res: 130 | x += x_skip 131 | 132 | if self.disable_self_attn: 133 | x = self.attn1(self.norm1(x), context=context) + x 134 | else: 135 | x = self.attn1(self.norm1(x)) + x 136 | 137 | if self.attn2 is not None: 138 | if self.switch_temporal_ca_to_sa: 139 | x = self.attn2(self.norm2(x)) + x 140 | else: 141 | x = self.attn2(self.norm2(x), context=context) + x 142 | x_skip = x 143 | x = self.ff(self.norm3(x)) 144 | if self.is_res: 145 | x += x_skip 146 | 147 | x = 
rearrange(x, 148 | '(b s) t c -> (b t) s c', 149 | s=S, 150 | b=B // timesteps, 151 | c=C, 152 | t=timesteps) 153 | return x 154 | 155 | def get_last_layer(self): 156 | return self.ff.net[-1].weight 157 | 158 | 159 | str_to_dtype = { 160 | 'fp32': torch.float32, 161 | 'fp16': torch.float16, 162 | 'bf16': torch.bfloat16 163 | } 164 | 165 | 166 | class SpatialVideoTransformer(SpatialTransformer): 167 | 168 | def __init__( 169 | self, 170 | in_channels, 171 | n_heads, 172 | d_head, 173 | depth=1, 174 | dropout=0.0, 175 | use_linear=False, 176 | context_dim=None, 177 | use_spatial_context=False, 178 | timesteps=None, 179 | merge_strategy: str = 'fixed', 180 | merge_factor: float = 0.5, 181 | time_context_dim=None, 182 | ff_in=False, 183 | checkpoint=False, 184 | time_depth=1, 185 | attn_mode='softmax', 186 | disable_self_attn=False, 187 | disable_temporal_crossattention=False, 188 | max_time_embed_period: int = 10000, 189 | dtype='fp32', 190 | ): 191 | super().__init__( 192 | in_channels, 193 | n_heads, 194 | d_head, 195 | depth=depth, 196 | dropout=dropout, 197 | attn_type=attn_mode, 198 | use_checkpoint=checkpoint, 199 | context_dim=context_dim, 200 | use_linear=use_linear, 201 | disable_self_attn=disable_self_attn, 202 | ) 203 | self.time_depth = time_depth 204 | self.depth = depth 205 | self.max_time_embed_period = max_time_embed_period 206 | 207 | time_mix_d_head = d_head 208 | n_time_mix_heads = n_heads 209 | 210 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) 211 | 212 | inner_dim = n_heads * d_head 213 | if use_spatial_context: 214 | time_context_dim = context_dim 215 | 216 | self.time_stack = nn.ModuleList([ 217 | VideoTransformerBlock( 218 | inner_dim, 219 | n_time_mix_heads, 220 | time_mix_d_head, 221 | dropout=dropout, 222 | context_dim=time_context_dim, 223 | timesteps=timesteps, 224 | checkpoint=checkpoint, 225 | ff_in=ff_in, 226 | inner_dim=time_mix_inner_dim, 227 | attn_mode=attn_mode, 228 | disable_self_attn=disable_self_attn, 229 | disable_temporal_crossattention=disable_temporal_crossattention, 230 | ) for _ in range(self.depth) 231 | ]) 232 | 233 | assert len(self.time_stack) == len(self.transformer_blocks) 234 | 235 | self.use_spatial_context = use_spatial_context 236 | self.in_channels = in_channels 237 | 238 | time_embed_dim = self.in_channels * 4 239 | self.time_pos_embed = nn.Sequential( 240 | linear(self.in_channels, time_embed_dim), 241 | nn.SiLU(), 242 | linear(time_embed_dim, self.in_channels), 243 | ) 244 | 245 | self.time_mixer = AlphaBlender(alpha=merge_factor, 246 | merge_strategy=merge_strategy) 247 | self.dtype = str_to_dtype[dtype] 248 | 249 | def forward( 250 | self, 251 | x: torch.Tensor, 252 | context: Optional[torch.Tensor] = None, 253 | time_context: Optional[torch.Tensor] = None, 254 | timesteps: Optional[int] = None, 255 | image_only_indicator: Optional[torch.Tensor] = None, 256 | ) -> torch.Tensor: 257 | _, _, h, w = x.shape 258 | x_in = x 259 | spatial_context = None 260 | if exists(context): 261 | spatial_context = context 262 | 263 | if self.use_spatial_context: 264 | assert context.ndim == 3, f'n dims of spatial context should be 3 but are {context.ndim}' 265 | 266 | time_context = context 267 | time_context_first_timestep = time_context[::timesteps] 268 | time_context = repeat(time_context_first_timestep, 269 | 'b ... -> (b n) ...', 270 | n=h * w) 271 | elif time_context is not None and not self.use_spatial_context: 272 | time_context = repeat(time_context, 'b ... 
-> (b n) ...', n=h * w) 273 | if time_context.ndim == 2: 274 | time_context = rearrange(time_context, 'b c -> b 1 c') 275 | 276 | x = self.norm(x) 277 | if not self.use_linear: 278 | x = self.proj_in(x) 279 | x = rearrange(x, 'b c h w -> b (h w) c') 280 | if self.use_linear: 281 | x = self.proj_in(x) 282 | 283 | num_frames = torch.arange(timesteps, device=x.device) 284 | num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps) 285 | num_frames = rearrange(num_frames, 'b t -> (b t)') 286 | t_emb = timestep_embedding( 287 | num_frames, 288 | self.in_channels, 289 | repeat_only=False, 290 | max_period=self.max_time_embed_period, 291 | dtype=self.dtype, 292 | ) 293 | emb = self.time_pos_embed(t_emb) 294 | emb = emb[:, None, :] 295 | 296 | for it_, (block, mix_block) in enumerate( 297 | zip(self.transformer_blocks, self.time_stack)): 298 | x = block( 299 | x, 300 | context=spatial_context, 301 | ) 302 | 303 | x_mix = x 304 | x_mix = x_mix + emb 305 | 306 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) 307 | x = self.time_mixer( 308 | x_spatial=x, 309 | x_temporal=x_mix, 310 | image_only_indicator=image_only_indicator, 311 | ) 312 | if self.use_linear: 313 | x = self.proj_out(x) 314 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 315 | if not self.use_linear: 316 | x = self.proj_out(x) 317 | out = x + x_in 318 | return out 319 | -------------------------------------------------------------------------------- /flashvideo/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import torch 6 | from omegaconf import ListConfig 7 | from sgm.util import instantiate_from_config 8 | 9 | 10 | def read_from_file(p, rank=0, world_size=1): 11 | with open(p) as fin: 12 | cnt = -1 13 | for l in fin: 14 | cnt += 1 15 | if cnt % world_size != rank: 16 | continue 17 | yield l.strip(), cnt 18 | 19 | 20 | def disable_all_init(): 21 | """Disable all redundant torch default initialization to accelerate model 22 | creation.""" 23 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) 24 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) 25 | setattr(torch.nn.modules.sparse.Embedding, 'reset_parameters', 26 | lambda self: None) 27 | setattr(torch.nn.modules.conv.Conv2d, 'reset_parameters', 28 | lambda self: None) 29 | setattr(torch.nn.modules.normalization.GroupNorm, 'reset_parameters', 30 | lambda self: None) 31 | 32 | 33 | def get_unique_embedder_keys_from_conditioner(conditioner): 34 | return list({x.input_key for x in conditioner.embedders}) 35 | 36 | 37 | def get_batch(keys, 38 | value_dict, 39 | N: Union[List, ListConfig], 40 | T=None, 41 | device='cuda'): 42 | batch = {} 43 | batch_uc = {} 44 | 45 | for key in keys: 46 | if key == 'txt': 47 | batch['txt'] = np.repeat([value_dict['prompt']], 48 | repeats=math.prod(N)).reshape(N).tolist() 49 | batch_uc['txt'] = np.repeat( 50 | [value_dict['negative_prompt']], 51 | repeats=math.prod(N)).reshape(N).tolist() 52 | else: 53 | batch[key] = value_dict[key] 54 | 55 | if T is not None: 56 | batch['num_video_frames'] = T 57 | 58 | for key in batch.keys(): 59 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 60 | batch_uc[key] = torch.clone(batch[key]) 61 | return batch, batch_uc 62 | 63 | 64 | def decode(first_stage_model, latent): 65 | first_stage_model.to(torch.float16) 66 | latent = latent.to(torch.float16) 67 | recons = [] 68 | T = latent.shape[2] 69 | if T > 2: 70 | 
loop_num = (T - 1) // 2 71 | for i in range(loop_num): 72 | if i == 0: 73 | start_frame, end_frame = 0, 3 74 | else: 75 | start_frame, end_frame = i * 2 + 1, i * 2 + 3 76 | if i == loop_num - 1: 77 | clear_fake_cp_cache = True 78 | else: 79 | clear_fake_cp_cache = False 80 | with torch.no_grad(): 81 | recon = first_stage_model.decode( 82 | latent[:, :, start_frame:end_frame].contiguous(), 83 | clear_fake_cp_cache=clear_fake_cp_cache) 84 | 85 | recons.append(recon) 86 | else: 87 | 88 | clear_fake_cp_cache = True 89 | if latent.shape[2] > 1: 90 | for m in first_stage_model.modules(): 91 | m.force_split = True 92 | recon = first_stage_model.decode( 93 | latent.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache) 94 | recons.append(recon) 95 | recon = torch.cat(recons, dim=2).to(torch.float32) 96 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() 97 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() 98 | samples = (samples * 255).squeeze(0).permute(0, 2, 3, 1) 99 | save_frames = samples 100 | 101 | return save_frames 102 | 103 | 104 | def save_mem_decode(first_stage_model, latent): 105 | 106 | l_h, l_w = latent.shape[3], latent.shape[4] 107 | T = latent.shape[2] 108 | F = 8 109 | # split spatial along h w 110 | num_h_splits = 1 111 | num_w_splits = 2 112 | ori_video = torch.zeros((1, 3, 1 + 4 * (T - 1), l_h * 8, l_w * 8), 113 | device=latent.device) 114 | for h_idx in range(num_h_splits): 115 | for w_idx in range(num_w_splits): 116 | start_h = h_idx * latent.shape[3] // num_h_splits 117 | end_h = (h_idx + 1) * latent.shape[3] // num_h_splits 118 | start_w = w_idx * latent.shape[4] // num_w_splits 119 | end_w = (w_idx + 1) * latent.shape[4] // num_w_splits 120 | 121 | latent_overlap = 16 122 | if (start_h - latent_overlap >= 0) and (num_h_splits > 1): 123 | real_start_h = start_h - latent_overlap 124 | h_start_overlap = latent_overlap * F 125 | else: 126 | h_start_overlap = 0 127 | real_start_h = start_h 128 | if (end_h + latent_overlap <= l_h) and (num_h_splits > 1): 129 | real_end_h = end_h + latent_overlap 130 | h_end_overlap = latent_overlap * F 131 | else: 132 | h_end_overlap = 0 133 | real_end_h = end_h 134 | 135 | if (start_w - latent_overlap >= 0) and (num_w_splits > 1): 136 | real_start_w = start_w - latent_overlap 137 | w_start_overlap = latent_overlap * F 138 | else: 139 | w_start_overlap = 0 140 | real_start_w = start_w 141 | 142 | if (end_w + latent_overlap <= l_w) and (num_w_splits > 1): 143 | real_end_w = end_w + latent_overlap 144 | w_end_overlap = latent_overlap * F 145 | else: 146 | w_end_overlap = 0 147 | real_end_w = end_w 148 | 149 | latent_slice = latent[:, :, :, real_start_h:real_end_h, 150 | real_start_w:real_end_w] 151 | recon = decode(first_stage_model, latent_slice) 152 | 153 | recon = recon.permute(3, 0, 1, 2).contiguous()[None] 154 | 155 | recon = recon[:, :, :, 156 | h_start_overlap:recon.shape[3] - h_end_overlap, 157 | w_start_overlap:recon.shape[4] - w_end_overlap] 158 | ori_video[:, :, :, start_h * 8:end_h * 8, 159 | start_w * 8:end_w * 8] = recon 160 | ori_video = ori_video.squeeze(0) 161 | ori_video = ori_video.permute(1, 2, 3, 0).contiguous().cpu() 162 | return ori_video 163 | 164 | 165 | def prepare_input(text, model, T, negative_prompt=None, pos_prompt=None): 166 | 167 | if negative_prompt is None: 168 | negative_prompt = '' 169 | if pos_prompt is None: 170 | pos_prompt = '' 171 | value_dict = { 172 | 'prompt': text + pos_prompt, 173 | 'negative_prompt': negative_prompt, 174 | 'num_frames': torch.tensor(T).unsqueeze(0), 175 | } 
176 | print(value_dict) 177 | batch, batch_uc = get_batch( 178 | get_unique_embedder_keys_from_conditioner(model.conditioner), 179 | value_dict, [1]) 180 | 181 | for key in batch: 182 | if isinstance(batch[key], torch.Tensor): 183 | print(key, batch[key].shape) 184 | elif isinstance(batch[key], list): 185 | print(key, [len(l) for l in batch[key]]) 186 | else: 187 | print(key, batch[key]) 188 | c, uc = model.conditioner.get_unconditional_conditioning( 189 | batch, 190 | batch_uc=batch_uc, 191 | force_uc_zero_embeddings=['txt'], 192 | ) 193 | 194 | for k in c: 195 | if not k == 'crossattn': 196 | c[k], uc[k] = map(lambda y: y[k][:math.prod([1])].to('cuda'), 197 | (c, uc)) 198 | return c, uc 199 | 200 | 201 | def save_memory_encode_first_stage(x, model): 202 | splits_x = torch.split(x, [17, 16, 16], dim=2) 203 | all_out = [] 204 | 205 | with torch.autocast('cuda', enabled=False): 206 | for idx, input_x in enumerate(splits_x): 207 | if idx == len(splits_x) - 1: 208 | clear_fake_cp_cache = True 209 | else: 210 | clear_fake_cp_cache = False 211 | out = model.first_stage_model.encode( 212 | input_x.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache) 213 | all_out.append(out) 214 | 215 | z = torch.cat(all_out, dim=2) 216 | z = model.scale_factor * z 217 | return z 218 | 219 | 220 | def seed_everything(seed: int = 42): 221 | import os 222 | import random 223 | 224 | import numpy as np 225 | import torch 226 | 227 | # Python random module 228 | random.seed(seed) 229 | 230 | # Numpy 231 | np.random.seed(seed) 232 | 233 | # PyTorch 234 | torch.manual_seed(seed) 235 | 236 | # If using CUDA 237 | torch.cuda.manual_seed(seed) 238 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 239 | 240 | # # CuDNN 241 | # torch.backends.cudnn.deterministic = True 242 | # torch.backends.cudnn.benchmark = False 243 | 244 | # OS environment 245 | os.environ['PYTHONHASHSEED'] = str(seed) 246 | 247 | 248 | def get_time_slice_vae(): 249 | vae_config = { 250 | 'target': 'vae_modules.autoencoder.VideoAutoencoderInferenceWrapper', 251 | 'params': { 252 | 'cp_size': 1, 253 | 'ckpt_path': './checkpoints/3d-vae.pt', 254 | 'ignore_keys': ['loss'], 255 | 'loss_config': { 256 | 'target': 'torch.nn.Identity' 257 | }, 258 | 'regularizer_config': { 259 | 'target': 260 | 'vae_modules.regularizers.DiagonalGaussianRegularizer' 261 | }, 262 | 'encoder_config': { 263 | 'target': 264 | 'vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D', 265 | 'params': { 266 | 'double_z': True, 267 | 'z_channels': 16, 268 | 'resolution': 256, 269 | 'in_channels': 3, 270 | 'out_ch': 3, 271 | 'ch': 128, 272 | 'ch_mult': [1, 2, 2, 4], 273 | 'attn_resolutions': [], 274 | 'num_res_blocks': 3, 275 | 'dropout': 0.0, 276 | 'gather_norm': False 277 | } 278 | }, 279 | 'decoder_config': { 280 | 'target': 'vae_modules.cp_enc_dec.ContextParallelDecoder3D', 281 | 'params': { 282 | 'double_z': True, 283 | 'z_channels': 16, 284 | 'resolution': 256, 285 | 'in_channels': 3, 286 | 'out_ch': 3, 287 | 'ch': 128, 288 | 'ch_mult': [1, 2, 2, 4], 289 | 'attn_resolutions': [], 290 | 'num_res_blocks': 3, 291 | 'dropout': 0.0, 292 | 'gather_norm': False 293 | } 294 | } 295 | } 296 | } 297 | 298 | vae = instantiate_from_config(vae_config).eval().half().cuda() 299 | return vae 300 | -------------------------------------------------------------------------------- /flashvideo/vae_modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | 7 | def 
__init__(self, model, decay=0.9999, use_num_upates=True): 8 | super().__init__() 9 | if decay < 0.0 or decay > 1.0: 10 | raise ValueError('Decay must be between 0 and 1') 11 | 12 | self.m_name2s_name = {} 13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 14 | self.register_buffer( 15 | 'num_updates', 16 | torch.tensor(0, dtype=torch.int) 17 | if use_num_upates else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace('.', '') 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, 39 | (1 + self.num_updates) / (10 + self.num_updates)) 40 | 41 | one_minus_decay = 1.0 - decay 42 | 43 | with torch.no_grad(): 44 | m_param = dict(model.named_parameters()) 45 | shadow_params = dict(self.named_buffers()) 46 | 47 | for key in m_param: 48 | if m_param[key].requires_grad: 49 | sname = self.m_name2s_name[key] 50 | shadow_params[sname] = shadow_params[sname].type_as( 51 | m_param[key]) 52 | shadow_params[sname].sub_( 53 | one_minus_decay * 54 | (shadow_params[sname] - m_param[key])) 55 | else: 56 | assert not key in self.m_name2s_name 57 | 58 | def copy_to(self, model): 59 | m_param = dict(model.named_parameters()) 60 | shadow_params = dict(self.named_buffers()) 61 | for key in m_param: 62 | if m_param[key].requires_grad: 63 | m_param[key].data.copy_( 64 | shadow_params[self.m_name2s_name[key]].data) 65 | else: 66 | assert not key in self.m_name2s_name 67 | 68 | def store(self, parameters): 69 | """ 70 | Save the current parameters for restoring later. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | temporarily stored. 74 | """ 75 | self.collected_params = [param.clone() for param in parameters] 76 | 77 | def restore(self, parameters): 78 | """ 79 | Restore the parameters stored with the `store` method. 80 | Useful to validate the model with EMA parameters without affecting the 81 | original optimization process. Store the parameters before the 82 | `copy_to` method. After validation (or model saving), use this to 83 | restore the former parameters. 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | -------------------------------------------------------------------------------- /flashvideo/vae_modules/regularizers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class DiagonalGaussianDistribution: 11 | 12 | def __init__(self, parameters, deterministic=False): 13 | self.parameters = parameters 14 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 15 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 16 | self.deterministic = deterministic 17 | self.std = torch.exp(0.5 * self.logvar) 18 | self.var = torch.exp(self.logvar) 19 | if self.deterministic: 20 | self.var = self.std = torch.zeros_like( 21 | self.mean).to(device=self.parameters.device) 22 | 23 | def sample(self): 24 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 25 | # device=self.parameters.device 26 | # ) 27 | x = self.mean + self.std * torch.randn_like(self.mean) 28 | return x 29 | 30 | def kl(self, other=None): 31 | if self.deterministic: 32 | return torch.Tensor([0.0]) 33 | else: 34 | if other is None: 35 | return 0.5 * torch.sum( 36 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 37 | dim=[1, 2, 3], 38 | ) 39 | else: 40 | return 0.5 * torch.sum( 41 | torch.pow(self.mean - other.mean, 2) / other.var + 42 | self.var / other.var - 1.0 - self.logvar + other.logvar, 43 | dim=[1, 2, 3], 44 | ) 45 | 46 | def nll(self, sample, dims=[1, 2, 3]): 47 | if self.deterministic: 48 | return torch.Tensor([0.0]) 49 | logtwopi = np.log(2.0 * np.pi) 50 | return 0.5 * torch.sum( 51 | logtwopi + self.logvar + 52 | torch.pow(sample - self.mean, 2) / self.var, 53 | dim=dims, 54 | ) 55 | 56 | def mode(self): 57 | return self.mean 58 | 59 | 60 | class AbstractRegularizer(nn.Module): 61 | 62 | def __init__(self): 63 | super().__init__() 64 | 65 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 66 | raise NotImplementedError() 67 | 68 | @abstractmethod 69 | def get_trainable_parameters(self) -> Any: 70 | raise NotImplementedError() 71 | 72 | 73 | class IdentityRegularizer(AbstractRegularizer): 74 | 75 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 76 | return z, dict() 77 | 78 | def get_trainable_parameters(self) -> Any: 79 | yield from () 80 | 81 | 82 | def measure_perplexity( 83 | predicted_indices: torch.Tensor, 84 | num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 85 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 86 | # eval cluster perplexity. 
when perplexity == num_embeddings, all clusters are used exactly equally 87 | encodings = F.one_hot(predicted_indices, 88 | num_centroids).float().reshape(-1, num_centroids) 89 | avg_probs = encodings.mean(0) 90 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 91 | cluster_use = torch.sum(avg_probs > 0) 92 | return perplexity, cluster_use 93 | 94 | 95 | class DiagonalGaussianRegularizer(AbstractRegularizer): 96 | 97 | def __init__(self, sample: bool = True): 98 | super().__init__() 99 | self.sample = sample 100 | 101 | def get_trainable_parameters(self) -> Any: 102 | yield from () 103 | 104 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 105 | log = dict() 106 | posterior = DiagonalGaussianDistribution(z) 107 | if self.sample: 108 | z = posterior.sample() 109 | else: 110 | z = posterior.mode() 111 | kl_loss = posterior.kl() 112 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 113 | log['kl_loss'] = kl_loss 114 | return z, log 115 | -------------------------------------------------------------------------------- /inf_270_1080p.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 \ 2 | --nnodes=1 \ 3 | --node_rank=0 \ 4 | --master_port=20023 flashvideo/dist_inf_text_file.py \ 5 | --base "flashvideo/configs/stage1.yaml" \ 6 | --second "flashvideo/configs/stage2.yaml" \ 7 | --inf-ckpt ./checkpoints/stage1.pt \ 8 | --inf-ckpt2 ./checkpoints/stage2.pt \ 9 | --input-file ./example.txt \ 10 | --output-dir ./vis_270p_1080p_example 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.33.0 # git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested 2 | diffusers>=0.30.1 # git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested 3 | gradio>=4.42.0 # For the HF Gradio demo 4 | imageio==2.34.2 # For exporting video during diffusers inference 5 | imageio-ffmpeg==0.5.1 # For exporting video during diffusers inference 6 | moviepy==1.0.3 # For exporting video 7 | numpy==1.26.0 8 | openai>=1.42.0 # For the prompt refiner 9 | pillow==9.5.0 10 | sentencepiece>=0.2.0 # Required by the T5 tokenizer 11 | streamlit>=1.38.0 # For the Streamlit web demo 12 | SwissArmyTransformer>=0.4.12 13 | torch>=2.4.0 # Tested with 2.2, 2.3, 2.4 and 2.5; the development team works with 2.4.0. 14 | torchvision>=0.19.0 # The development team works with version 0.19.0. 15 | transformers>=4.44.2 # The development team works with version 4.44.2. 16 | --------------------------------------------------------------------------------
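A minimal usage sketch (illustrative only, not a file in the repository): it wires together seed_everything, get_time_slice_vae and save_mem_decode from flashvideo/utils.py to decode a video latent into frames. It assumes the flashvideo/ directory is on PYTHONPATH so that utils and sgm import cleanly, that ./checkpoints/3d-vae.pt exists (the path hard-coded in get_time_slice_vae), and that the random latent is only a stand-in for real stage-2 sampler output.

    import torch

    from utils import get_time_slice_vae, save_mem_decode, seed_everything

    seed_everything(42)

    # Build the 3D VAE wrapper; this loads ./checkpoints/3d-vae.pt and puts it on the GPU in fp16.
    vae = get_time_slice_vae()

    # Placeholder latent of shape (B, C=16, T, H/8, W/8); a real latent would come from stage-2 sampling.
    latent = torch.randn(1, 16, 3, 34, 60, device='cuda', dtype=torch.float16)

    # save_mem_decode splits the latent into spatial tiles, decodes each tile via decode(),
    # crops the tile overlap and stitches the pieces back into one frame tensor.
    frames = save_mem_decode(vae, latent)  # (1 + 4*(T-1), H, W, 3) floats in [0, 255]
    print(frames.shape)                    # torch.Size([9, 272, 480, 3])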