├── .gitignore
├── Dockerfile.ds
├── LICENSE
├── README.md
├── README_zh-CN.md
├── __init__.py
├── app.py
├── config
│   ├── easyanimate_image_normal_v1.yaml
│   ├── easyanimate_video_long_sequence_v1.yaml
│   └── easyanimate_video_motion_module_v1.yaml
├── datasets
│   └── put datasets here.txt
├── easyanimate
│   ├── __init__.py
│   ├── data
│   │   ├── bucket_sampler.py
│   │   ├── dataset_image.py
│   │   ├── dataset_image_video.py
│   │   └── dataset_video.py
│   ├── models
│   │   ├── attention.py
│   │   ├── autoencoder_magvit.py
│   │   ├── motion_module.py
│   │   ├── patch.py
│   │   ├── transformer2d.py
│   │   └── transformer3d.py
│   ├── pipeline
│   │   ├── pipeline_easyanimate.py
│   │   ├── pipeline_easyanimate_inpaint.py
│   │   └── pipeline_pixart_magvit.py
│   ├── ui
│   │   └── ui.py
│   └── utils
│       ├── IDDIM.py
│       ├── __init__.py
│       ├── diffusion_utils.py
│       ├── gaussian_diffusion.py
│       ├── lora_utils.py
│       ├── respace.py
│       └── utils.py
├── models
│   └── put models here.txt
├── nodes.py
├── predict_t2i.py
├── predict_t2v.py
├── requirements.txt
├── scripts
│   ├── extra_motion_module.py
│   ├── train_t2i.py
│   ├── train_t2i.sh
│   ├── train_t2i_lora.py
│   ├── train_t2i_lora.sh
│   ├── train_t2iv.py
│   ├── train_t2iv.sh
│   ├── train_t2v.py
│   ├── train_t2v.sh
│   ├── train_t2v_lora.py
│   └── train_t2v_lora.sh
├── wf.json
└── wf.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | models*
3 | output*
4 | samples*
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/Dockerfile.ds:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
2 | ENV DEBIAN_FRONTEND noninteractive
3 |
4 | RUN rm -r /etc/apt/sources.list.d/
5 |
6 | RUN apt-get update -y && apt-get install -y \
7 | libgl1 libglib2.0-0 google-perftools \
8 | sudo wget git git-lfs vim tig pkg-config libcairo2-dev \
9 | telnet curl net-tools iputils-ping wget jq \
10 | python3-pip python-is-python3 python3.10-venv tzdata lsof && \
11 | rm -rf /var/lib/apt/lists/*
12 | RUN pip3 install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple/
13 |
14 | # add all extensions
15 | RUN apt-get update -y && apt-get install -y zip && \
16 | rm -rf /var/lib/apt/lists/*
17 | RUN pip install wandb tqdm GitPython==3.1.32 Pillow==9.5.0 setuptools --upgrade -i https://mirrors.aliyun.com/pypi/simple/
18 |
19 | # reinstall torch to keep it compatible with xformers
20 | RUN pip uninstall -qy torch torchvision && \
21 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
22 | RUN pip uninstall -qy xformers && pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu118
23 |
24 | # install requirements
25 | COPY ./requirements.txt /root/requirements.txt
26 | RUN pip install -r /root/requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
27 | RUN rm -rf /root/requirements.txt
28 |
29 | ENV PYTHONUNBUFFERED 1
30 | ENV NVIDIA_DISABLE_REQUIRE 1
31 |
32 | WORKDIR /root/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-EasyAnimate
2 |
3 | ## workflow
4 |
5 | [basic](https://github.com/chaojie/ComfyUI-EasyAnimate/blob/main/wf.json)
6 |
7 |
8 |
9 | ### 1、Model Weights
10 | | Name | Type | Storage Space | Url | Description |
11 | |--|--|--|--|--|
12 | | easyanimate_v1_mm.safetensors | Motion Module | 4.1GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Motion_Module/easyanimate_v1_mm.safetensors) | ComfyUI/models/checkpoints |
13 | | PixArt-XL-2-512x512.tar | Pixart | 11.4GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Diffusion_Transformer/PixArt-XL-2-512x512.tar)| ComfyUI/models/diffusers (tar -xvf PixArt-XL-2-512x512.tar) |
14 |
15 | ### 2、Optional Model Weights
16 | | Name | Type | Storage Space | Url | Description |
17 | |--|--|--|--|--|
18 | | easyanimate_portrait.safetensors | Checkpoint of Pixart | 2.3GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait.safetensors) | ComfyUI/models/checkpoints |
19 | | easyanimate_portrait_lora.safetensors | Lora of Pixart | 654.0MB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait_lora.safetensors)| ComfyUI/models/checkpoints |
20 |
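If you prefer to script the setup, the sketch below shows one way to fetch the two required weights into a ComfyUI installation and unpack the PixArt archive. It is only an illustration: the ComfyUI root path is an assumption, and you can just as well download the files manually and run `tar -xvf PixArt-XL-2-512x512.tar` as noted in the table above.

```python
# Minimal sketch (assumes ~/ComfyUI is your ComfyUI root -- adjust as needed).
import os
import tarfile
import urllib.request

COMFYUI_ROOT = os.path.expanduser("~/ComfyUI")           # assumption: adjust to your install
CHECKPOINTS = os.path.join(COMFYUI_ROOT, "models", "checkpoints")
DIFFUSERS = os.path.join(COMFYUI_ROOT, "models", "diffusers")
os.makedirs(CHECKPOINTS, exist_ok=True)
os.makedirs(DIFFUSERS, exist_ok=True)

BASE = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate"

# Motion module (4.1GB) -> ComfyUI/models/checkpoints
urllib.request.urlretrieve(
    f"{BASE}/Motion_Module/easyanimate_v1_mm.safetensors",
    os.path.join(CHECKPOINTS, "easyanimate_v1_mm.safetensors"),
)

# PixArt weights (11.4GB tar) -> ComfyUI/models/diffusers, then extract
tar_path = os.path.join(DIFFUSERS, "PixArt-XL-2-512x512.tar")
urllib.request.urlretrieve(f"{BASE}/Diffusion_Transformer/PixArt-XL-2-512x512.tar", tar_path)
with tarfile.open(tar_path) as tf:                       # same as: tar -xvf PixArt-XL-2-512x512.tar
    tf.extractall(DIFFUSERS)
```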
21 | ## [EasyAnimate](https://github.com/aigc-apps/EasyAnimate)
22 |
--------------------------------------------------------------------------------
/README_zh-CN.md:
--------------------------------------------------------------------------------
1 | # EasyAnimate | Your intelligent generator.
2 | 😊 EasyAnimate is a repository for generating long videos and training transformer-based diffusion generators.
3 |
4 | 😊 We build on a SORA-like architecture and DiT, using a transformer as the diffusion backbone for video generation. To keep the design extensible, EasyAnimate is built around a motion module, and we will try more training schemes in the future to further improve the results.
5 |
6 | 😊 Welcome!
7 |
8 | [English](./README.md) | 简体中文
9 |
10 | # Table of Contents
11 | - [Table of Contents](#table-of-contents)
12 | - [Introduction](#introduction)
13 | - [TODO List](#todo-list)
14 | - [Model zoo](#model-zoo)
15 | - [1、Motion Weights](#1motion-weights)
16 | - [2、Other Weights](#2other-weights)
17 | - [Quick Start](#quick-start)
18 | - [1. Cloud usage: AliyunDSW/Docker](#1-cloud-usage-aliyundswdocker)
19 | - [2. Local install: environment check/download/install](#2-local-install-environment-checkdownloadinstall)
20 | - [How to Use](#how-to-use)
21 | - [1. Generation](#1-generation)
22 | - [2. Model training](#2-model-training)
23 | - [Algorithm Details](#algorithm-details)
24 | - [References](#references)
25 | - [License](#license)
26 |
27 | # Introduction
28 | EasyAnimate is a transformer-based pipeline that can be used to generate AI animations and to train baseline Diffusion Transformer models and LoRA models. We support direct prediction from pre-trained EasyAnimate models to generate videos at different resolutions, about 6 seconds long at 12 fps (40 to 80 frames; longer videos will be supported in the future). Users can also train their own baseline and LoRA models for style transfer.
29 |
30 | We will gradually support quick starts from different platforms; see [Quick Start](#quick-start).
31 |
32 | What's new:
33 | - Code created! Windows and Linux are now supported. [ 2024.04.12 ]
34 |
35 | Here are some of our generated results:
36 |
37 | Our UI is shown below:
38 |
39 |
40 | # TODO List
41 | - Support text-to-video generation at higher resolutions.
42 | - Support a magvit-based text-to-video generation model.
43 | - Support a video inpainting model.
44 |
45 | # Model zoo
46 | ### 1、Motion Weights
47 | | Name | Type | Storage Space | Url | Description |
48 | |--|--|--|--|--|
49 | | easyanimate_v1_mm.safetensors | Motion Module | 4.1GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Motion_Module/easyanimate_v1_mm.safetensors) | Training with 80 frames and fps 12 |
50 |
51 | ### 2、Other Weights
52 | | Name | Type | Storage Space | Url | Description |
53 | |--|--|--|--|--|
54 | | PixArt-XL-2-512x512.tar | Pixart | 11.4GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Diffusion_Transformer/PixArt-XL-2-512x512.tar)| Pixart-Alpha official weights |
55 | | easyanimate_portrait.safetensors | Checkpoint of Pixart | 2.3GB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait.safetensors) | Training with internal portrait datasets |
56 | | easyanimate_portrait_lora.safetensors | Lora of Pixart | 654.0MB | [download](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait_lora.safetensors)| Training with internal portrait datasets |
57 |
58 |
59 | # Generation Results
60 | When generating landscape-style animations, we recommend the DPM++ and Euler A samplers. When generating portrait-style animations, we recommend the Euler A and Euler samplers.
61 |
62 | GitHub sometimes fails to display large GIFs; you can use the Download GIF links to view them locally.
63 |
64 | Prediction with the original PixArt checkpoint.
65 |
66 | | Base Models | Sampler | Seed | Resolution (h x w x f) | Prompt | GenerationResult | Download |
67 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
68 | | PixArt | DPM++ | 43 | 512x512x80 | A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff\'s precipices. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/1-cliff.gif) |
69 | | PixArt | DPM++ | 43 | 448x640x80 | The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/2-waterfall.gif) |
70 | | PixArt | DPM++ | 43 | 704x384x80 | A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/3-snowy.gif) |
71 | | PixArt | DPM++ | 43 | 448x640x64 | The vibrant beauty of a sunflower field. The sunflowers, with their bright yellow petals and dark brown centers, are in full bloom, creating a stunning contrast against the green leaves and stems. The sunflowers are arranged in neat rows, creating a sense of order and symmetry. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/4-sunflower.gif) |
72 | | PixArt | DPM++ | 43 | 384x704x48 | A tranquil Vermont autumn, with leaves in vibrant colors of orange and red fluttering down a mountain stream. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/5-autumn.gif) |
73 | | PixArt | DPM++ | 43 | 704x384x48 | A vibrant underwater scene. A group of blue fish, with yellow fins, are swimming around a coral reef. The coral reef is a mix of brown and green, providing a natural habitat for the fish. The water is a deep blue, indicating a depth of around 30 feet. The fish are swimming in a circular pattern around the coral reef, indicating a sense of motion and activity. The overall scene is a beautiful representation of marine life. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/6-underwater.gif) |
74 | | PixArt | DPM++ | 43 | 576x448x48 | Pacific coast, carmel by the blue sea ocean and peaceful waves |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/7-coast.gif) |
75 | | PixArt | DPM++ | 43 | 576x448x80 | A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/8-forest.gif) |
76 | | PixArt | DPM++ | 43 | 640x448x64 | The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements. | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/9-grasses.gif) |
77 | | PixArt | DPM++ | 43 | 704x384x80 | A serene night scene in a forested area. The first frame shows a tranquil lake reflecting the star-filled sky above. The second frame reveals a beautiful sunset, casting a warm glow over the landscape. The third frame showcases the night sky, filled with stars and a vibrant Milky Way galaxy. The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. The style of the video is naturalistic, emphasizing the beauty of the night sky and the peacefulness of the forest. | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/10-night.gif) |
78 | | PixArt | DPM++ | 43 | 640x448x80 | Sunset over the sea. |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/11-sunset.gif) |
79 |
80 | Prediction with the portrait checkpoint.
81 |
82 | | Base Models | Sampler | Seed | Resolution (h x w x f) | Prompt | GenerationResult | Download |
83 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
84 | | Portrait | Euler A | 43 | 448x576x80 | 1girl, 3d, black hair, brown eyes, earrings, grey background, jewelry, lips, long hair, looking at viewer, photo \\(medium\\), realistic, red lips, solo |  | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/1-check.gif) |
85 | | Portrait | Euler A | 43 | 448x576x80 | 1girl, bare shoulders, blurry, brown eyes, dirty, dirty face, freckles, lips, long hair, looking at viewer, realistic, sleeveless, solo, upper body | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/2-check.gif) |
86 | | Portrait | Euler A | 43 | 512x512x64 | 1girl, black hair, brown eyes, earrings, grey background, jewelry, lips, looking at viewer, mole, mole under eye, neck tattoo, nose, ponytail, realistic, shirt, simple background, solo, tattoo | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/3-check.gif) |
87 | | Portrait | Euler A | 43 | 576x448x64 | 1girl, black hair, lips, looking at viewer, mole, mole under eye, mole under mouth, realistic, solo | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/5-check.gif) |
88 |
89 | Prediction with the portrait LoRA.
90 |
91 | | Base Models | Sampler | Seed | Resolution (h x w x f) | Prompt | GenerationResult | Download |
92 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
93 | | Pixart + Lora | Euler A | 43 | 512x512x64 | 1girl, 3d, black hair, brown eyes, earrings, grey background, jewelry, lips, long hair, looking at viewer, photo \\(medium\\), realistic, red lips, solo | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/1-lora.gif) |
94 | | Pixart + Lora | Euler A | 43 | 512x512x64 | 1girl, bare shoulders, blurry, brown eyes, dirty, dirty face, freckles, lips, long hair, looking at viewer, mole, mole on breast, mole on neck, mole under eye, mole under mouth, realistic, sleeveless, solo, upper body | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/2-lora.gif) |
95 | | Pixart + Lora | Euler A | 43 | 512x512x64 | 1girl, black hair, lips, looking at viewer, mole, mole under eye, mole under mouth, realistic, solo | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/5-lora.gif) |
96 | | Pixart + Lora | Euler A | 43 | 512x512x80 | 1girl, bare shoulders, blurry, blurry background, blurry foreground, bokeh, brown eyes, christmas tree, closed mouth, collarbone, depth of field, earrings, jewelry, lips, long hair, looking at viewer, photo \\(medium\\), realistic, smile, solo | | [Download GIF](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/8-lora.gif) |
97 |
98 | # Quick Start
99 | ### 1. Cloud usage: AliyunDSW/Docker
100 | #### a. Via Aliyun DSW
101 | Coming soon.
102 |
103 | #### b. Via Docker
104 | If you use Docker, make sure the GPU driver and CUDA environment are correctly installed on the machine, then run the following commands in order:
105 | ```
106 | # Pull the image
107 | docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate
108 |
109 | # Start a container from the image
110 | docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate
111 |
112 | # Clone the code
113 | git clone https://github.com/aigc-apps/EasyAnimate.git
114 |
115 | # Enter the EasyAnimate folder
116 | cd EasyAnimate
117 |
118 | # Download the weights
119 | mkdir models/Diffusion_Transformer
120 | mkdir models/Motion_Module
121 | mkdir models/Personalized_Model
122 |
123 | wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Motion_Module/easyanimate_v1_mm.safetensors -O models/Motion_Module/easyanimate_v1_mm.safetensors
124 | wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait.safetensors -O models/Personalized_Model/easyanimate_portrait.safetensors
125 | wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Personalized_Model/easyanimate_portrait_lora.safetensors -O models/Personalized_Model/easyanimate_portrait_lora.safetensors
126 | wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Diffusion_Transformer/PixArt-XL-2-512x512.tar -O models/Diffusion_Transformer/PixArt-XL-2-512x512.tar
127 |
128 | cd models/Diffusion_Transformer/
129 | tar -xvf PixArt-XL-2-512x512.tar
130 | cd ../../
131 | ```
132 |
133 | ### 2. Local install: environment check/download/install
134 | #### a. Environment check
135 | We have verified that EasyAnimate runs in the following environments:
136 |
137 | Details for Linux:
138 | - OS: Ubuntu 20.04, CentOS
139 | - python: python3.10 & python3.11
140 | - pytorch: torch2.2.0
141 | - CUDA: 11.8
142 | - CUDNN: 8+
143 | - GPU: Nvidia-A10 24G & Nvidia-A100 40G & Nvidia-A100 80G
144 |
145 | About 60GB of free disk space is required; please check!
146 |
147 | #### b. Weight placement
148 | It is best to place the weights along the specified paths:
149 |
150 | ```
151 | 📦 models/
152 | ├── 📂 Diffusion_Transformer/
153 | │ └── 📂 PixArt-XL-2-512x512/
154 | ├── 📂 Motion_Module/
155 | │ └── 📄 easyanimate_v1_mm.safetensors
156 | ├── 📂 Personalized_Model/
157 | │ ├── 📄 easyanimate_portrait.safetensors
158 | │ └── 📄 easyanimate_portrait_lora.safetensors
159 | ```
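As a quick sanity check, a small, purely illustrative snippet like the one below can verify that the weights landed in the layout shown above; run it from the EasyAnimate project root, and note that the two portrait files are optional.

```python
# Minimal sketch: check the expected weight layout (run from the project root).
import os

expected = [
    "models/Diffusion_Transformer/PixArt-XL-2-512x512",                 # extracted folder
    "models/Motion_Module/easyanimate_v1_mm.safetensors",
    "models/Personalized_Model/easyanimate_portrait.safetensors",       # optional
    "models/Personalized_Model/easyanimate_portrait_lora.safetensors",  # optional
]
for path in expected:
    print(f"[{'ok' if os.path.exists(path) else 'missing'}] {path}")
```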
160 |
161 | # How to Use
162 | ### 1. Generation
163 | #### a. Video generation
164 | ##### i、Run the python file
165 | - Step 1: Download the corresponding weights and put them into the models folder.
166 | - Step 2: Modify prompt, neg_prompt, guidance_scale, and seed in predict_t2v.py.
167 | - Step 3: Run predict_t2v.py and wait for the result; it is saved in the samples/easyanimate-videos folder.
168 | - Step 4: If you want to combine other backbones or LoRAs you have trained yourself, modify the model path and lora_path in predict_t2v.py as needed.
169 |
170 | ##### ii、Via the UI
171 | - Step 1: Download the corresponding weights and put them into the models folder.
172 | - Step 2: Run app.py to open the gradio page.
173 | - Step 3: Select the generation model on the page, fill in prompt, neg_prompt, guidance_scale, seed, and so on, click Generate, and wait for the result; it is saved in the sample folder.
174 |
175 | ### 2. Model training
176 | #### a、Training the video generation model
177 | ##### i、With the webvid dataset
178 | If you train with the webvid dataset, you need to download it first.
179 |
180 | You need to arrange the webvid dataset in the following format.
181 | ```
182 | 📦 project/
183 | ├── 📂 datasets/
184 | │ ├── 📂 webvid/
185 | │ ├── 📂 videos/
186 | │ │ ├── 📄 00000001.mp4
187 | │ │ ├── 📄 00000002.mp4
188 | │ │ └── 📄 .....
189 | │ └── 📄 csv_of_webvid.csv
190 | ```
191 |
192 | Then edit the settings in scripts/train_t2v.sh.
193 | ```
194 | export DATASET_NAME="datasets/webvid/videos/"
195 | export DATASET_META_NAME="datasets/webvid/csv_of_webvid.csv"
196 |
197 | ...
198 |
199 | train_data_format="webvid"
200 | ```
201 |
202 | Finally, run scripts/train_t2v.sh.
203 | ```sh
204 | sh scripts/train_t2v.sh
205 | ```
206 |
207 | ##### ii、With a self-built dataset
208 | If you train with an internal dataset, you need to format the dataset first.
209 |
210 | You need to arrange the dataset in the following format.
211 | ```
212 | 📦 project/
213 | ├── 📂 datasets/
214 | │ ├── 📂 internal_datasets/
215 | │ ├── 📂 videos/
216 | │ │ ├── 📄 00000001.mp4
217 | │ │ ├── 📄 00000002.mp4
218 | │ │ └── 📄 .....
219 | │ └── 📄 json_of_internal_datasets.json
220 | ```
221 |
222 | json_of_internal_datasets.json is a standard json file, as shown below:
223 | ```json
224 | [
225 | {
226 | "file_path": "videos/00000001.mp4",
227 | "text": "A group of young men in suits and sunglasses are walking down a city street.",
228 | "type": "video"
229 | },
230 | {
231 | "file_path": "videos/00000002.mp4",
232 | "text": "A notepad with a drawing of a woman on it.",
233 | "type": "video"
234 | }
235 | .....
236 | ]
237 | ```
238 | The file_path in the json needs to be set to a relative path.
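If your captions live in per-video text files, a small helper along these lines can generate the metadata json in the format above. The folder layout and the one-caption-txt-per-video convention are assumptions for illustration, not something the training scripts require.

```python
# Minimal sketch: build json_of_internal_datasets.json from videos plus caption .txt files.
import json
import os

dataset_root = "datasets/internal_datasets"        # assumption: matches the layout above
video_dir = os.path.join(dataset_root, "videos")

entries = []
for name in sorted(os.listdir(video_dir)):
    if not name.endswith(".mp4"):
        continue
    # assumption: 00000001.mp4 has its caption in 00000001.txt next to it
    with open(os.path.join(video_dir, os.path.splitext(name)[0] + ".txt")) as f:
        text = f.read().strip()
    entries.append({
        "file_path": os.path.join("videos", name),  # relative path, as required
        "text": text,
        "type": "video",                            # use "image" for image entries (see below)
    })

with open(os.path.join(dataset_root, "json_of_internal_datasets.json"), "w") as f:
    json.dump(entries, f, indent=2)
```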
239 |
240 | Then edit the settings in scripts/train_t2v.sh.
241 | ```
242 | export DATASET_NAME="datasets/internal_datasets/"
243 | export DATASET_META_NAME="datasets/internal_datasets/json_of_internal_datasets.json"
244 |
245 | ...
246 |
247 | train_data_format="normal"
248 | ```
249 |
250 | Finally, run scripts/train_t2v.sh.
251 | ```sh
252 | sh scripts/train_t2v.sh
253 | ```
254 |
255 | #### b、Training the base text-to-image model
256 | ##### i、With the diffusers format
257 | The dataset can be organized in the diffusers format.
258 |
259 | ```
260 | 📦 project/
261 | ├── 📂 datasets/
262 | │ ├── 📂 diffusers_datasets/
263 | │ ├── 📂 train/
264 | │ │ ├── 📄 00000001.jpg
265 | │ │ ├── 📄 00000002.jpg
266 | │ │ └── 📄 .....
267 | │ └── 📄 metadata.jsonl
268 | ```
269 |
270 | Then edit the settings in scripts/train_t2i.sh.
271 | ```
272 | export DATASET_NAME="datasets/diffusers_datasets/"
273 |
274 | ...
275 |
276 | train_data_format="diffusers"
277 | ```
278 |
279 | Finally, run scripts/train_t2i.sh.
280 | ```sh
281 | sh scripts/train_t2i.sh
282 | ```
283 | ##### ii、With a self-built dataset
284 | If you train with a self-built dataset, you need to format the dataset first.
285 |
286 | You need to arrange the dataset in the following format.
287 | ```
288 | 📦 project/
289 | ├── 📂 datasets/
290 | │ ├── 📂 internal_datasets/
291 | │ ├── 📂 train/
292 | │ │ ├── 📄 00000001.jpg
293 | │ │ ├── 📄 00000002.jpg
294 | │ │ └── 📄 .....
295 | │ └── 📄 json_of_internal_datasets.json
296 | ```
297 |
298 | json_of_internal_datasets.json is a standard json file, as shown below:
299 | ```json
300 | [
301 | {
302 | "file_path": "train/00000001.jpg",
303 | "text": "A group of young men in suits and sunglasses are walking down a city street.",
304 | "type": "image"
305 | },
306 | {
307 | "file_path": "train/00000002.jpg",
308 | "text": "A notepad with a drawing of a woman on it.",
309 | "type": "image"
310 | }
311 | .....
312 | ]
313 | ```
314 | The file_path in the json needs to be set to a relative path.
315 |
316 | Then edit the settings in scripts/train_t2i.sh.
317 | ```
318 | export DATASET_NAME="datasets/internal_datasets/"
319 | export DATASET_META_NAME="datasets/internal_datasets/json_of_internal_datasets.json"
320 |
321 | ...
322 |
323 | train_data_format="normal"
324 | ```
325 |
326 | Finally, run scripts/train_t2i.sh.
327 | ```sh
328 | sh scripts/train_t2i.sh
329 | ```
330 |
331 | #### c、Training the LoRA model
332 | ##### i、With the diffusers format
333 | The dataset can be organized in the diffusers format.
334 | ```
335 | 📦 project/
336 | ├── 📂 datasets/
337 | │ ├── 📂 diffusers_datasets/
338 | │ ├── 📂 train/
339 | │ │ ├── 📄 00000001.jpg
340 | │ │ ├── 📄 00000002.jpg
341 | │ │ └── 📄 .....
342 | │ └── 📄 metadata.jsonl
343 | ```
344 |
345 | Then edit the settings in scripts/train_lora.sh.
346 | ```
347 | export DATASET_NAME="datasets/diffusers_datasets/"
348 |
349 | ...
350 |
351 | train_data_format="diffusers"
352 | ```
353 |
354 | Finally, run scripts/train_lora.sh.
355 | ```sh
356 | sh scripts/train_lora.sh
357 | ```
358 |
359 | ##### ii、With a self-built dataset
360 | If you train with a self-built dataset, you need to format the dataset first.
361 |
362 | You need to arrange the dataset in the following format.
363 | ```
364 | 📦 project/
365 | ├── 📂 datasets/
366 | │ ├── 📂 internal_datasets/
367 | │ ├── 📂 train/
368 | │ │ ├── 📄 00000001.jpg
369 | │ │ ├── 📄 00000002.jpg
370 | │ │ └── 📄 .....
371 | │ └── 📄 json_of_internal_datasets.json
372 | ```
373 |
374 | json_of_internal_datasets.json is a standard json file, as shown below:
375 | ```json
376 | [
377 | {
378 | "file_path": "train/00000001.jpg",
379 | "text": "A group of young men in suits and sunglasses are walking down a city street.",
380 | "type": "image"
381 | },
382 | {
383 | "file_path": "train/00000002.jpg",
384 | "text": "A notepad with a drawing of a woman on it.",
385 | "type": "image"
386 | }
387 | .....
388 | ]
389 | ```
390 | The file_path in the json needs to be set to a relative path.
391 |
392 | Then edit the settings in scripts/train_lora.sh.
393 | ```
394 | export DATASET_NAME="datasets/internal_datasets/"
395 | export DATASET_META_NAME="datasets/internal_datasets/json_of_internal_datasets.json"
396 |
397 | ...
398 |
399 | train_data_format="normal"
400 | ```
401 |
402 | Finally, run scripts/train_lora.sh.
403 | ```sh
404 | sh scripts/train_lora.sh
405 | ```
406 | # Algorithm Details
407 | We use [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) as the base model and introduce an additional motion module on top of it to extend the DiT model from 2D image generation to 3D video generation. The framework is shown below:
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 | The Motion Module is used to capture inter-frame relationships along the temporal dimension; its structure is shown below:
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 | We introduce an attention mechanism along the temporal dimension so that the model learns temporal information and can generate consecutive video frames. In addition, we use an extra Grid Reshape operation to enlarge the number of input tokens of the attention mechanism, making fuller use of the spatial information of the images for better generation quality. As a standalone module, the Motion Module can be applied to different DiT baseline models at inference time. Besides training the motion module, EasyAnimate also supports training the DiT base model and LoRA models, so that users can train a model in their own style and generate videos in any style they like.
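As a rough illustration of the temporal-attention idea (a sketch only, not the implementation in easyanimate/models/motion_module.py, and omitting the Grid Reshape step configured via block_size in config/easyanimate_video_motion_module_v1.yaml), the video tokens can be regrouped so that self-attention runs along the frame axis for each spatial position:

```python
# Minimal sketch of temporal self-attention over video tokens (illustrative only).
import torch
import torch.nn as nn

class TemporalSelfAttentionSketch(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, frames, spatial_tokens, dim)
        b, f, n, d = x.shape
        # group by spatial position so attention mixes information across frames
        x = x.permute(0, 2, 1, 3).reshape(b * n, f, d)
        h = self.norm(x)
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + h                                    # residual connection
        return x.reshape(b, n, f, d).permute(0, 2, 1, 3)

# usage: 2 videos, 16 frames, a 32x32 latent grid flattened to 1024 tokens of width 64
tokens = torch.randn(2, 16, 1024, 64)
print(TemporalSelfAttentionSketch(64)(tokens).shape)  # torch.Size([2, 16, 1024, 64])
```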
424 |
425 |
426 | # Algorithm Limitations
427 | - Limited by
428 |
429 | # References
430 | - magvit: https://github.com/google-research/magvit
431 | - PixArt: https://github.com/PixArt-alpha/PixArt-alpha
432 | - Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
433 | - Open-Sora: https://github.com/hpcaitech/Open-Sora
434 | - Animatediff: https://github.com/guoyww/AnimateDiff
435 |
436 | # License
437 | This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
438 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS']
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from easyanimate.ui.ui import ui
2 |
3 | if __name__ == "__main__":
4 | server_name = "0.0.0.0"
5 | demo = ui()
6 | demo.launch(server_name=server_name)
--------------------------------------------------------------------------------
/config/easyanimate_image_normal_v1.yaml:
--------------------------------------------------------------------------------
1 | noise_scheduler_kwargs:
2 | beta_start: 0.0001
3 | beta_end: 0.02
4 | beta_schedule: "linear"
5 | steps_offset: 1
6 |
7 | vae_kwargs:
8 | enable_magvit: false
--------------------------------------------------------------------------------
/config/easyanimate_video_long_sequence_v1.yaml:
--------------------------------------------------------------------------------
1 | transformer_additional_kwargs:
2 | patch_3d: false
3 | fake_3d: false
4 | basic_block_type: "selfattentiontemporal"
5 | time_position_encoding_before_transformer: true
6 |
7 | noise_scheduler_kwargs:
8 | beta_start: 0.0001
9 | beta_end: 0.02
10 | beta_schedule: "linear"
11 | steps_offset: 1
12 |
13 | vae_kwargs:
14 | enable_magvit: false
--------------------------------------------------------------------------------
/config/easyanimate_video_motion_module_v1.yaml:
--------------------------------------------------------------------------------
1 | transformer_additional_kwargs:
2 | patch_3d: false
3 | fake_3d: false
4 | basic_block_type: "motionmodule"
5 | time_position_encoding_before_transformer: false
6 | motion_module_type: "VanillaGrid"
7 |
8 | motion_module_kwargs:
9 | num_attention_heads: 8
10 | num_transformer_block: 1
11 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12 | temporal_position_encoding: true
13 | temporal_position_encoding_max_len: 4096
14 | temporal_attention_dim_div: 1
15 | block_size: 2
16 |
17 | noise_scheduler_kwargs:
18 | beta_start: 0.0001
19 | beta_end: 0.02
20 | beta_schedule: "linear"
21 | steps_offset: 1
22 |
23 | vae_kwargs:
24 | enable_magvit: false
--------------------------------------------------------------------------------
/datasets/put datasets here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-EasyAnimate/9bef69d1ceda9d300613488517af6cc66cf5c360/datasets/put datasets here.txt
--------------------------------------------------------------------------------
/easyanimate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-EasyAnimate/9bef69d1ceda9d300613488517af6cc66cf5c360/easyanimate/__init__.py
--------------------------------------------------------------------------------
/easyanimate/data/bucket_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 | from torch.utils.data import BatchSampler, Dataset, Sampler
8 |
9 | ASPECT_RATIO_512 = {
10 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
11 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
12 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
13 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
14 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
15 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
16 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
17 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
18 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
19 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
20 | }
21 | ASPECT_RATIO_RANDOM_CROP_512 = {
22 | '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
23 | '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
24 | '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
25 | '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
26 | '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
27 | }
28 | ASPECT_RATIO_RANDOM_CROP_PROB = [
29 | 1, 2,
30 | 4, 4, 4, 4,
31 | 8, 8, 8,
32 | 4, 4, 4, 4,
33 | 2, 1
34 | ]
35 | ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
36 |
37 | def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
38 | aspect_ratio = height / width
39 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
40 | return ratios[closest_ratio], float(closest_ratio)
41 |
42 | def get_image_size_without_loading(path):
43 | with Image.open(path) as img:
44 | return img.size # (width, height)
45 |
46 | class AspectRatioBatchImageSampler(BatchSampler):
47 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
48 |
49 | Args:
50 | sampler (Sampler): Base sampler.
51 | dataset (Dataset): Dataset providing data information.
52 | batch_size (int): Size of mini-batch.
53 | drop_last (bool): If ``True``, the sampler will drop the last batch if
54 | its size would be less than ``batch_size``.
55 | aspect_ratios (dict): The predefined aspect ratios.
56 | """
57 | def __init__(
58 | self,
59 | sampler: Sampler,
60 | dataset: Dataset,
61 | batch_size: int,
62 | train_folder: str = None,
63 | aspect_ratios: dict = ASPECT_RATIO_512,
64 | drop_last: bool = False,
65 | config=None,
66 | **kwargs
67 | ) -> None:
68 | if not isinstance(sampler, Sampler):
69 | raise TypeError('sampler should be an instance of ``Sampler``, '
70 | f'but got {sampler}')
71 | if not isinstance(batch_size, int) or batch_size <= 0:
72 | raise ValueError('batch_size should be a positive integer value, '
73 | f'but got batch_size={batch_size}')
74 | self.sampler = sampler
75 | self.dataset = dataset
76 | self.train_folder = train_folder
77 | self.batch_size = batch_size
78 | self.aspect_ratios = aspect_ratios
79 | self.drop_last = drop_last
80 | self.config = config
81 | # buckets for each aspect ratio
82 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
83 | # [str(k) for k, v in aspect_ratios]
84 | self.current_available_bucket_keys = list(aspect_ratios.keys())
85 |
86 | def __iter__(self):
87 | for idx in self.sampler:
88 | try:
89 | image_dict = self.dataset[idx]
90 |
91 | image_id, name = image_dict['file_path'], image_dict['text']
92 | if self.train_folder is None:
93 | image_dir = image_id
94 | else:
95 | image_dir = os.path.join(self.train_folder, image_id)
96 |
97 | width, height = get_image_size_without_loading(image_dir)
98 |
99 | ratio = height / width # self.dataset[idx]
100 | except Exception as e:
101 | print(e)
102 | continue
103 | # find the closest aspect ratio
104 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
105 | if closest_ratio not in self.current_available_bucket_keys:
106 | continue
107 | bucket = self._aspect_ratio_buckets[closest_ratio]
108 | bucket.append(idx)
109 | # yield a batch of indices in the same aspect ratio group
110 | if len(bucket) == self.batch_size:
111 | yield bucket[:]
112 | del bucket[:]
113 |
114 | class AspectRatioBatchSampler(BatchSampler):
115 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
116 |
117 | Args:
118 | sampler (Sampler): Base sampler.
119 | dataset (Dataset): Dataset providing data information.
120 | batch_size (int): Size of mini-batch.
121 | drop_last (bool): If ``True``, the sampler will drop the last batch if
122 | its size would be less than ``batch_size``.
123 | aspect_ratios (dict): The predefined aspect ratios.
124 | """
125 | def __init__(
126 | self,
127 | sampler: Sampler,
128 | dataset: Dataset,
129 | batch_size: int,
130 | video_folder: str = None,
131 | train_data_format: str = "webvid",
132 | aspect_ratios: dict = ASPECT_RATIO_512,
133 | drop_last: bool = False,
134 | config=None,
135 | **kwargs
136 | ) -> None:
137 | if not isinstance(sampler, Sampler):
138 | raise TypeError('sampler should be an instance of ``Sampler``, '
139 | f'but got {sampler}')
140 | if not isinstance(batch_size, int) or batch_size <= 0:
141 | raise ValueError('batch_size should be a positive integer value, '
142 | f'but got batch_size={batch_size}')
143 | self.sampler = sampler
144 | self.dataset = dataset
145 | self.video_folder = video_folder
146 | self.train_data_format = train_data_format
147 | self.batch_size = batch_size
148 | self.aspect_ratios = aspect_ratios
149 | self.drop_last = drop_last
150 | self.config = config
151 | # buckets for each aspect ratio
152 | self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
153 | # [str(k) for k, v in aspect_ratios]
154 | self.current_available_bucket_keys = list(aspect_ratios.keys())
155 |
156 | def __iter__(self):
157 | for idx in self.sampler:
158 | try:
159 | video_dict = self.dataset[idx]
160 | if self.train_data_format == "normal":
161 | video_id, name = video_dict['file_path'], video_dict['text']
162 | if self.video_folder is None:
163 | video_dir = video_id
164 | else:
165 | video_dir = os.path.join(self.video_folder, video_id)
166 | else:
167 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
168 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
169 | cap = cv2.VideoCapture(video_dir)
170 |
171 | # Get the video dimensions
172 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # float cast to int
173 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # float cast to int
174 |
175 | ratio = height / width # self.dataset[idx]
176 | except Exception as e:
177 | print(e)
178 | continue
179 | # find the closest aspect ratio
180 | closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
181 | if closest_ratio not in self.current_available_bucket_keys:
182 | continue
183 | bucket = self._aspect_ratio_buckets[closest_ratio]
184 | bucket.append(idx)
185 | # yield a batch of indices in the same aspect ratio group
186 | if len(bucket) == self.batch_size:
187 | yield bucket[:]
188 | del bucket[:]
--------------------------------------------------------------------------------
/easyanimate/data/dataset_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | from torch.utils.data.dataset import Dataset
10 |
11 |
12 | class CC15M(Dataset):
13 | def __init__(
14 | self,
15 | json_path,
16 | video_folder=None,
17 | resolution=512,
18 | enable_bucket=False,
19 | ):
20 | print(f"loading annotations from {json_path} ...")
21 | self.dataset = json.load(open(json_path, 'r'))
22 | self.length = len(self.dataset)
23 | print(f"data scale: {self.length}")
24 |
25 | self.enable_bucket = enable_bucket
26 | self.video_folder = video_folder
27 |
28 | resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29 | self.pixel_transforms = transforms.Compose([
30 | transforms.Resize(resolution[0]),
31 | transforms.CenterCrop(resolution),
32 | transforms.ToTensor(),
33 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34 | ])
35 |
36 | def get_batch(self, idx):
37 | video_dict = self.dataset[idx]
38 | video_id, name = video_dict['file_path'], video_dict['text']
39 |
40 | if self.video_folder is None:
41 | video_dir = video_id
42 | else:
43 | video_dir = os.path.join(self.video_folder, video_id)
44 |
45 | pixel_values = Image.open(video_dir).convert("RGB")
46 | return pixel_values, name
47 |
48 | def __len__(self):
49 | return self.length
50 |
51 | def __getitem__(self, idx):
52 | while True:
53 | try:
54 | pixel_values, name = self.get_batch(idx)
55 | break
56 | except Exception as e:
57 | print(e)
58 | idx = random.randint(0, self.length-1)
59 |
60 | if not self.enable_bucket:
61 | pixel_values = self.pixel_transforms(pixel_values)
62 | else:
63 | pixel_values = np.array(pixel_values)
64 |
65 | sample = dict(pixel_values=pixel_values, text=name)
66 | return sample
67 |
68 | if __name__ == "__main__":
69 | dataset = CC15M(
70 | json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71 | resolution=512,
72 | )
73 |
74 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75 | for idx, batch in enumerate(dataloader):
76 | print(batch["pixel_values"].shape, len(batch["text"]))
--------------------------------------------------------------------------------
/easyanimate/data/dataset_image_video.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import io
3 | import json
4 | import math
5 | import os
6 | import random
7 | from threading import Thread
8 |
9 | import albumentations
10 | import cv2
11 | import numpy as np
12 | import torch
13 | import torchvision.transforms as transforms
14 | from decord import VideoReader
15 | from PIL import Image
16 | from torch.utils.data import BatchSampler, Sampler
17 | from torch.utils.data.dataset import Dataset
18 |
19 |
20 | class ImageVideoSampler(BatchSampler):
21 | """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
22 |
23 | Args:
24 | sampler (Sampler): Base sampler.
25 | dataset (Dataset): Dataset providing data information.
26 | batch_size (int): Size of mini-batch.
27 | drop_last (bool): If ``True``, the sampler will drop the last batch if
28 | its size would be less than ``batch_size``.
29 | aspect_ratios (dict): The predefined aspect ratios.
30 | """
31 |
32 | def __init__(self,
33 | sampler: Sampler,
34 | dataset: Dataset,
35 | batch_size: int,
36 | drop_last: bool = False
37 | ) -> None:
38 | if not isinstance(sampler, Sampler):
39 | raise TypeError('sampler should be an instance of ``Sampler``, '
40 | f'but got {sampler}')
41 | if not isinstance(batch_size, int) or batch_size <= 0:
42 | raise ValueError('batch_size should be a positive integer value, '
43 | f'but got batch_size={batch_size}')
44 | self.sampler = sampler
45 | self.dataset = dataset
46 | self.batch_size = batch_size
47 | self.drop_last = drop_last
48 |
49 | # buckets for each aspect ratio
50 | self.bucket = {'image':[], 'video':[]}
51 |
52 | def __iter__(self):
53 | for idx in self.sampler:
54 | content_type = self.dataset.dataset[idx].get('type', 'image')
55 | self.bucket[content_type].append(idx)
56 |
57 | # yield a batch of indices in the same aspect ratio group
58 | if len(self.bucket['video']) == self.batch_size:
59 | bucket = self.bucket['video']
60 | yield bucket[:]
61 | del bucket[:]
62 | elif len(self.bucket['image']) == self.batch_size:
63 | bucket = self.bucket['image']
64 | yield bucket[:]
65 | del bucket[:]
66 |
67 | class ImageVideoDataset(Dataset):
68 | def __init__(
69 | self,
70 | ann_path, data_root=None,
71 | video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
72 | image_sample_size=512,
73 | # For Random Crop
74 | min_crop_f=0.9, max_crop_f=1,
75 | video_repeat=0,
76 | enable_bucket=False
77 | ):
78 | # Loading annotations from files
79 | print(f"loading annotations from {ann_path} ...")
80 | if ann_path.endswith('.csv'):
81 | with open(ann_path, 'r') as csvfile:
82 | dataset = list(csv.DictReader(csvfile))
83 | elif ann_path.endswith('.json'):
84 | dataset = json.load(open(ann_path))
85 |
86 | self.data_root = data_root
87 |
88 |         # Used to balance the number of images and videos.
89 | self.dataset = []
90 | for data in dataset:
91 | if data.get('data_type', 'image') != 'video' or data.get('type', 'image') != 'video':
92 | self.dataset.append(data)
93 | if video_repeat > 0:
94 | for _ in range(video_repeat):
95 | for data in dataset:
96 | if data.get('data_type', 'image') == 'video' or data.get('type', 'image') == 'video':
97 | self.dataset.append(data)
98 | del dataset
99 |
100 | self.length = len(self.dataset)
101 | print(f"data scale: {self.length}")
102 | self.min_crop_f = min_crop_f
103 | self.max_crop_f = max_crop_f
104 | # TODO: enable bucket training
105 | self.enable_bucket = enable_bucket
106 |
107 | # Video params
108 | self.video_sample_stride = video_sample_stride
109 | self.video_sample_n_frames = video_sample_n_frames
110 | self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
111 | self.video_rescaler = albumentations.SmallestMaxSize(max_size=min(self.video_sample_size), interpolation=cv2.INTER_AREA)
112 |
113 | # Image params
114 | self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
115 | self.image_transforms = transforms.Compose([
116 | transforms.RandomHorizontalFlip(),
117 | transforms.Resize(min(self.image_sample_size)),
118 | transforms.CenterCrop(self.image_sample_size),
119 | transforms.ToTensor(),
120 | transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
121 | ])
122 |
123 | def get_batch(self, idx):
124 | data_info = self.dataset[idx % len(self.dataset)]
125 |
126 | if data_info.get('data_type', 'image')=='video' or data_info.get('type', 'image')=='video':
127 | video_path, text = data_info['file_path'], data_info['text']
128 | # Get abs path of video
129 | if self.data_root is not None:
130 | video_path = os.path.join(self.data_root, video_path)
131 |
132 |             # Get the video resolution first
133 | video_reader = VideoReader(video_path, num_threads=2)
134 | h, w, c = video_reader[0].shape
135 | del video_reader
136 |
137 |             # Resize so the short side is 1.25x the target size before cropping
138 | t_h = int(self.video_sample_size[0] * 1.25 * h / min(h, w))
139 | t_w = int(self.video_sample_size[0] * 1.25 * w / min(h, w))
140 |
141 | # Get video pixels
142 | video_reader = VideoReader(video_path, width=t_w, height=t_h, num_threads=2)
143 | video_length = len(video_reader)
144 | clip_length = min(video_length, (self.video_sample_n_frames - 1) * self.video_sample_stride + 1)
145 | start_idx = random.randint(0, video_length - clip_length)
146 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_sample_n_frames, dtype=int)
147 | imgs = video_reader.get_batch(batch_index).asnumpy()
148 | del video_reader
149 | if imgs.shape[0] != self.video_sample_n_frames:
150 | raise ValueError('Video data Sampler Error')
151 |
152 |             # Random square crop of the resized frames
153 | min_side_len = min(imgs[0].shape[:2])
154 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
155 | crop_side_len = int(crop_side_len)
156 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
157 | imgs = np.transpose(imgs, (1, 2, 3, 0))
158 | imgs = self.cropper(image=imgs)["image"]
159 | imgs = np.transpose(imgs, (3, 0, 1, 2))
160 | out_imgs = []
161 |
162 | # Resize to video_sample_size
163 | for img in imgs:
164 | img = self.video_rescaler(image=img)["image"]
165 | out_imgs.append(img[None, :, :, :])
166 | imgs = np.concatenate(out_imgs).transpose(0, 3, 1, 2)
167 |
168 | # Normalize to -1~1
169 | imgs = ((imgs - 127.5) / 127.5).astype(np.float32)
170 | if imgs.shape[0] != self.video_sample_n_frames:
171 | raise ValueError('video data sampler error')
172 |
173 |             # Randomly drop the caption (empty text) 10% of the time
174 | if random.random() < 0.1:
175 | text = ''
176 | return torch.from_numpy(imgs), text, 'video'
177 | else:
178 | image_path, text = data_info['file_path'], data_info['text']
179 | if self.data_root is not None:
180 | image_path = os.path.join(self.data_root, image_path)
181 | image = Image.open(image_path).convert('RGB')
182 | image = self.image_transforms(image).unsqueeze(0)
183 | if random.random()<0.1:
184 | text = ''
185 |             return image, text, 'image'
186 |
187 | def __len__(self):
188 | return self.length
189 |
190 |     def __getitem__(self, idx):
191 |         while True:
192 |             sample = {}
193 |             def get_data(data_idx):
194 |                 try:
195 |                     pixel_values, name, data_type = self.get_batch(data_idx)
196 |                     sample["pixel_values"], sample["text"] = pixel_values, name
197 |                     sample["data_type"], sample["idx"] = data_type, data_idx
198 |                 except Exception as e:
199 |                     print(e, self.dataset[data_idx % len(self.dataset)])
200 |             # Fetch in a worker thread with a timeout so one bad sample cannot hang the loader.
201 |             t = Thread(target=get_data, args=(idx,))
202 |             t.start()
203 |             t.join(5)
204 |             if len(sample) > 0:
205 |                 break
206 |             # The fetch failed or timed out: retry with a different random index.
207 |             idx = random.randint(0, self.length - 1)
208 |         return sample
209 |
210 | if __name__ == "__main__":
211 | dataset = ImageVideoDataset(
212 | ann_path="test.json"
213 | )
214 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
215 | for idx, batch in enumerate(dataloader):
216 | print(batch["pixel_values"].shape, len(batch["text"]))
--------------------------------------------------------------------------------
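ImageVideoSampler above is intended to be passed to a DataLoader as batch_sampler, so that every batch contains only images or only videos and the default collate can stack pixel_values without padding. A minimal sketch, assuming an annotation file metadata.json (a hypothetical path) that follows the file_path / text / type schema read by get_batch:

from torch.utils.data import DataLoader, RandomSampler

from easyanimate.data.dataset_image_video import (ImageVideoDataset,
                                                  ImageVideoSampler)

# "metadata.json" is a placeholder; it must list dicts with file_path,
# text and type ("image" or "video") fields.
dataset = ImageVideoDataset(ann_path="metadata.json", video_repeat=1)
sampler = ImageVideoSampler(RandomSampler(dataset), dataset, batch_size=4)
dataloader = DataLoader(dataset, batch_sampler=sampler, num_workers=8)

for batch in dataloader:
    # Each yielded batch holds a single content type, so pixel_values is
    # (B, 1, C, H, W) for images or (B, F, C, H, W) for videos.
    print(batch["data_type"][0], batch["pixel_values"].shape)
    break
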
/easyanimate/data/dataset_video.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import io
3 | import json
4 | import math
5 | import os
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 | import torchvision.transforms as transforms
11 | from decord import VideoReader
12 | from einops import rearrange
13 | from torch.utils.data.dataset import Dataset
14 |
15 |
16 | def get_random_mask(shape):
17 | f, c, h, w = shape
18 |
19 | mask_index = np.random.randint(0, 4)
20 | mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
21 | if mask_index == 0:
22 | mask[1:, :, :, :] = 1
23 | elif mask_index == 1:
24 | mask_frame_index = 1
25 | mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
26 | elif mask_index == 2:
27 | center_x = torch.randint(0, w, (1,)).item()
28 | center_y = torch.randint(0, h, (1,)).item()
29 |         block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
30 |         block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
31 |
32 | start_x = max(center_x - block_size_x // 2, 0)
33 | end_x = min(center_x + block_size_x // 2, w)
34 | start_y = max(center_y - block_size_y // 2, 0)
35 | end_y = min(center_y + block_size_y // 2, h)
36 | mask[:, :, start_y:end_y, start_x:end_x] = 1
37 | elif mask_index == 3:
38 | center_x = torch.randint(0, w, (1,)).item()
39 | center_y = torch.randint(0, h, (1,)).item()
40 |         block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
41 |         block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
42 |
43 | start_x = max(center_x - block_size_x // 2, 0)
44 | end_x = min(center_x + block_size_x // 2, w)
45 | start_y = max(center_y - block_size_y // 2, 0)
46 | end_y = min(center_y + block_size_y // 2, h)
47 |
48 | mask_frame_before = np.random.randint(0, f // 2)
49 | mask_frame_after = np.random.randint(f // 2, f)
50 | mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
51 | else:
52 |         raise ValueError(f"The mask_index {mask_index} is not defined")
53 | return mask
54 |
55 |
56 | class WebVid10M(Dataset):
57 | def __init__(
58 | self,
59 | csv_path, video_folder,
60 | sample_size=256, sample_stride=4, sample_n_frames=16,
61 | enable_bucket=False, enable_inpaint=False, is_image=False,
62 | ):
63 | print(f"loading annotations from {csv_path} ...")
64 | with open(csv_path, 'r') as csvfile:
65 | self.dataset = list(csv.DictReader(csvfile))
66 | self.length = len(self.dataset)
67 | print(f"data scale: {self.length}")
68 |
69 | self.video_folder = video_folder
70 | self.sample_stride = sample_stride
71 | self.sample_n_frames = sample_n_frames
72 | self.enable_bucket = enable_bucket
73 | self.enable_inpaint = enable_inpaint
74 | self.is_image = is_image
75 |
76 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
77 | self.pixel_transforms = transforms.Compose([
78 | transforms.Resize(sample_size[0]),
79 | transforms.CenterCrop(sample_size),
80 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
81 | ])
82 |
83 | def get_batch(self, idx):
84 | video_dict = self.dataset[idx]
85 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
86 |
87 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
88 | video_reader = VideoReader(video_dir)
89 | video_length = len(video_reader)
90 |
91 | if not self.is_image:
92 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
93 | start_idx = random.randint(0, video_length - clip_length)
94 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
95 | else:
96 | batch_index = [random.randint(0, video_length - 1)]
97 |
98 | if not self.enable_bucket:
99 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
100 | pixel_values = pixel_values / 255.
101 | del video_reader
102 | else:
103 | pixel_values = video_reader.get_batch(batch_index).asnumpy()
104 |
105 | if self.is_image:
106 | pixel_values = pixel_values[0]
107 | return pixel_values, name
108 |
109 | def __len__(self):
110 | return self.length
111 |
112 | def __getitem__(self, idx):
113 | while True:
114 | try:
115 | pixel_values, name = self.get_batch(idx)
116 | break
117 |
118 | except Exception as e:
119 | print("Error info:", e)
120 | idx = random.randint(0, self.length-1)
121 |
122 | if not self.enable_bucket:
123 | pixel_values = self.pixel_transforms(pixel_values)
124 | if self.enable_inpaint:
125 | mask = get_random_mask(pixel_values.size())
126 | mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
127 | sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
128 | else:
129 | sample = dict(pixel_values=pixel_values, text=name)
130 | return sample
131 |
132 |
133 | class VideoDataset(Dataset):
134 | def __init__(
135 | self,
136 | json_path, video_folder=None,
137 | sample_size=256, sample_stride=4, sample_n_frames=16,
138 | enable_bucket=False, enable_inpaint=False
139 | ):
140 | print(f"loading annotations from {json_path} ...")
141 | self.dataset = json.load(open(json_path, 'r'))
142 | self.length = len(self.dataset)
143 | print(f"data scale: {self.length}")
144 |
145 | self.video_folder = video_folder
146 | self.sample_stride = sample_stride
147 | self.sample_n_frames = sample_n_frames
148 | self.enable_bucket = enable_bucket
149 | self.enable_inpaint = enable_inpaint
150 |
151 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
152 | self.pixel_transforms = transforms.Compose(
153 | [
154 | transforms.Resize(sample_size[0]),
155 | transforms.CenterCrop(sample_size),
156 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
157 | ]
158 | )
159 |
160 | def get_batch(self, idx):
161 | video_dict = self.dataset[idx]
162 | video_id, name = video_dict['file_path'], video_dict['text']
163 |
164 | if self.video_folder is None:
165 | video_dir = video_id
166 | else:
167 | video_dir = os.path.join(self.video_folder, video_id)
168 | video_reader = VideoReader(video_dir)
169 | video_length = len(video_reader)
170 |
171 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
172 | start_idx = random.randint(0, video_length - clip_length)
173 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
174 |
175 | if not self.enable_bucket:
176 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
177 | pixel_values = pixel_values / 255.
178 | del video_reader
179 | else:
180 | pixel_values = video_reader.get_batch(batch_index).asnumpy()
181 |
182 | return pixel_values, name
183 |
184 | def __len__(self):
185 | return self.length
186 |
187 | def __getitem__(self, idx):
188 | while True:
189 | try:
190 | pixel_values, name = self.get_batch(idx)
191 | break
192 |
193 | except Exception as e:
194 | print("Error info:", e)
195 | idx = random.randint(0, self.length-1)
196 |
197 | if not self.enable_bucket:
198 | pixel_values = self.pixel_transforms(pixel_values)
199 | if self.enable_inpaint:
200 | mask = get_random_mask(pixel_values.size())
201 | mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
202 | sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
203 | else:
204 | sample = dict(pixel_values=pixel_values, text=name)
205 | return sample
206 |
207 |
208 | if __name__ == "__main__":
209 | if 1:
210 | dataset = VideoDataset(
211 | json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
212 | sample_size=256,
213 | sample_stride=4, sample_n_frames=16,
214 | )
215 |
216 | if 0:
217 | dataset = WebVid10M(
218 | csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
219 | video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
220 | sample_size=256,
221 | sample_stride=4, sample_n_frames=16,
222 | is_image=False,
223 | )
224 |
225 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
226 | for idx, batch in enumerate(dataloader):
227 | print(batch["pixel_values"].shape, len(batch["text"]))
--------------------------------------------------------------------------------
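For the enable_inpaint branch above, get_random_mask draws one of four spatio-temporal mask patterns and the masked pixels are filled with -1 before training. A minimal self-contained check with a dummy clip; the tensor sizes are arbitrary:

import torch

from easyanimate.data.dataset_video import get_random_mask

# Dummy clip in the (frames, channels, height, width) layout used above,
# already scaled to the [-1, 1] range produced by pixel_transforms.
pixel_values = torch.rand(16, 3, 256, 256) * 2 - 1

mask = get_random_mask(pixel_values.size())  # (16, 1, 256, 256), values in {0, 1}
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask

print(mask.shape, mask.unique())
print(mask_pixel_values.shape)  # masked regions now hold the value -1
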
/easyanimate/models/motion_module.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | """
3 | import math
4 | from typing import Any, Callable, List, Optional, Tuple, Union
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from diffusers.models.attention import FeedForward
9 | from diffusers.utils.import_utils import is_xformers_available
10 | from einops import rearrange, repeat
11 | from torch import nn
12 |
13 | if is_xformers_available():
14 | import xformers
15 | import xformers.ops
16 | else:
17 | xformers = None
18 |
19 | class CrossAttention(nn.Module):
20 | r"""
21 | A cross attention layer.
22 |
23 | Parameters:
24 | query_dim (`int`): The number of channels in the query.
25 | cross_attention_dim (`int`, *optional*):
26 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
27 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
28 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
29 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
30 | bias (`bool`, *optional*, defaults to False):
31 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
32 | """
33 |
34 | def __init__(
35 | self,
36 | query_dim: int,
37 | cross_attention_dim: Optional[int] = None,
38 | heads: int = 8,
39 | dim_head: int = 64,
40 | dropout: float = 0.0,
41 | bias=False,
42 | upcast_attention: bool = False,
43 | upcast_softmax: bool = False,
44 | added_kv_proj_dim: Optional[int] = None,
45 | norm_num_groups: Optional[int] = None,
46 | ):
47 | super().__init__()
48 | inner_dim = dim_head * heads
49 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
50 | self.upcast_attention = upcast_attention
51 | self.upcast_softmax = upcast_softmax
52 |
53 | self.scale = dim_head**-0.5
54 |
55 | self.heads = heads
56 | # for slice_size > 0 the attention score computation
57 | # is split across the batch axis to save memory
58 | # You can set slice_size with `set_attention_slice`
59 | self.sliceable_head_dim = heads
60 | self._slice_size = None
61 | self._use_memory_efficient_attention_xformers = False
62 | self.added_kv_proj_dim = added_kv_proj_dim
63 |
64 | if norm_num_groups is not None:
65 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
66 | else:
67 | self.group_norm = None
68 |
69 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
70 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
71 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
72 |
73 | if self.added_kv_proj_dim is not None:
74 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
75 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
76 |
77 | self.to_out = nn.ModuleList([])
78 | self.to_out.append(nn.Linear(inner_dim, query_dim))
79 | self.to_out.append(nn.Dropout(dropout))
80 |
81 | def set_use_memory_efficient_attention_xformers(
82 | self, valid: bool, attention_op: Optional[Callable] = None
83 | ) -> None:
84 | self._use_memory_efficient_attention_xformers = valid
85 |
86 | def reshape_heads_to_batch_dim(self, tensor):
87 | batch_size, seq_len, dim = tensor.shape
88 | head_size = self.heads
89 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
90 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
91 | return tensor
92 |
93 | def reshape_batch_dim_to_heads(self, tensor):
94 | batch_size, seq_len, dim = tensor.shape
95 | head_size = self.heads
96 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
97 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
98 | return tensor
99 |
100 | def set_attention_slice(self, slice_size):
101 | if slice_size is not None and slice_size > self.sliceable_head_dim:
102 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
103 |
104 | self._slice_size = slice_size
105 |
106 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
107 | batch_size, sequence_length, _ = hidden_states.shape
108 |
109 | encoder_hidden_states = encoder_hidden_states
110 |
111 | if self.group_norm is not None:
112 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
113 |
114 | query = self.to_q(hidden_states)
115 | dim = query.shape[-1]
116 | query = self.reshape_heads_to_batch_dim(query)
117 |
118 | if self.added_kv_proj_dim is not None:
119 | key = self.to_k(hidden_states)
120 | value = self.to_v(hidden_states)
121 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
122 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
123 |
124 | key = self.reshape_heads_to_batch_dim(key)
125 | value = self.reshape_heads_to_batch_dim(value)
126 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
127 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
128 |
129 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
130 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
131 | else:
132 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
133 | key = self.to_k(encoder_hidden_states)
134 | value = self.to_v(encoder_hidden_states)
135 |
136 | key = self.reshape_heads_to_batch_dim(key)
137 | value = self.reshape_heads_to_batch_dim(value)
138 |
139 | if attention_mask is not None:
140 | if attention_mask.shape[-1] != query.shape[1]:
141 | target_length = query.shape[1]
142 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
143 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
144 |
145 | # attention, what we cannot get enough of
146 | if self._use_memory_efficient_attention_xformers:
147 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
148 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
149 | hidden_states = hidden_states.to(query.dtype)
150 | else:
151 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
152 | hidden_states = self._attention(query, key, value, attention_mask)
153 | else:
154 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
155 |
156 | # linear proj
157 | hidden_states = self.to_out[0](hidden_states)
158 |
159 | # dropout
160 | hidden_states = self.to_out[1](hidden_states)
161 | return hidden_states
162 |
163 | def _attention(self, query, key, value, attention_mask=None):
164 | if self.upcast_attention:
165 | query = query.float()
166 | key = key.float()
167 |
168 | attention_scores = torch.baddbmm(
169 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
170 | query,
171 | key.transpose(-1, -2),
172 | beta=0,
173 | alpha=self.scale,
174 | )
175 |
176 | if attention_mask is not None:
177 | attention_scores = attention_scores + attention_mask
178 |
179 | if self.upcast_softmax:
180 | attention_scores = attention_scores.float()
181 |
182 | attention_probs = attention_scores.softmax(dim=-1)
183 |
184 | # cast back to the original dtype
185 | attention_probs = attention_probs.to(value.dtype)
186 |
187 | # compute attention output
188 | hidden_states = torch.bmm(attention_probs, value)
189 |
190 | # reshape hidden_states
191 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
192 | return hidden_states
193 |
194 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
195 | batch_size_attention = query.shape[0]
196 | hidden_states = torch.zeros(
197 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
198 | )
199 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
200 | for i in range(hidden_states.shape[0] // slice_size):
201 | start_idx = i * slice_size
202 | end_idx = (i + 1) * slice_size
203 |
204 | query_slice = query[start_idx:end_idx]
205 | key_slice = key[start_idx:end_idx]
206 |
207 | if self.upcast_attention:
208 | query_slice = query_slice.float()
209 | key_slice = key_slice.float()
210 |
211 | attn_slice = torch.baddbmm(
212 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
213 | query_slice,
214 | key_slice.transpose(-1, -2),
215 | beta=0,
216 | alpha=self.scale,
217 | )
218 |
219 | if attention_mask is not None:
220 | attn_slice = attn_slice + attention_mask[start_idx:end_idx]
221 |
222 | if self.upcast_softmax:
223 | attn_slice = attn_slice.float()
224 |
225 | attn_slice = attn_slice.softmax(dim=-1)
226 |
227 | # cast back to the original dtype
228 | attn_slice = attn_slice.to(value.dtype)
229 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
230 |
231 | hidden_states[start_idx:end_idx] = attn_slice
232 |
233 | # reshape hidden_states
234 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
235 | return hidden_states
236 |
237 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
238 | # TODO attention_mask
239 | query = query.contiguous()
240 | key = key.contiguous()
241 | value = value.contiguous()
242 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
243 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
244 | return hidden_states
245 |
246 | def zero_module(module):
247 | # Zero out the parameters of a module and return it.
248 | for p in module.parameters():
249 | p.detach().zero_()
250 | return module
251 |
252 | def get_motion_module(
253 | in_channels,
254 | motion_module_type: str,
255 | motion_module_kwargs: dict,
256 | ):
257 | if motion_module_type == "Vanilla":
258 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
259 | elif motion_module_type == "VanillaGrid":
260 | return VanillaTemporalModule(in_channels=in_channels, grid=True, **motion_module_kwargs,)
261 | else:
262 | raise ValueError
263 |
264 | class VanillaTemporalModule(nn.Module):
265 | def __init__(
266 | self,
267 | in_channels,
268 | num_attention_heads = 8,
269 | num_transformer_block = 2,
270 | attention_block_types =( "Temporal_Self", "Temporal_Self" ),
271 | cross_frame_attention_mode = None,
272 | temporal_position_encoding = False,
273 | temporal_position_encoding_max_len = 4096,
274 | temporal_attention_dim_div = 1,
275 | zero_initialize = True,
276 | block_size = 1,
277 | grid = False,
278 | ):
279 | super().__init__()
280 |
281 | self.temporal_transformer = TemporalTransformer3DModel(
282 | in_channels=in_channels,
283 | num_attention_heads=num_attention_heads,
284 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
285 | num_layers=num_transformer_block,
286 | attention_block_types=attention_block_types,
287 | cross_frame_attention_mode=cross_frame_attention_mode,
288 | temporal_position_encoding=temporal_position_encoding,
289 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
290 | grid=grid,
291 | block_size=block_size,
292 | )
293 | if zero_initialize:
294 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
295 |
296 | def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
297 | hidden_states = input_tensor
298 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
299 |
300 | output = hidden_states
301 | return output
302 |
303 | class TemporalTransformer3DModel(nn.Module):
304 | def __init__(
305 | self,
306 | in_channels,
307 | num_attention_heads,
308 | attention_head_dim,
309 |
310 | num_layers,
311 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
312 | dropout = 0.0,
313 | norm_num_groups = 32,
314 | cross_attention_dim = 768,
315 | activation_fn = "geglu",
316 | attention_bias = False,
317 | upcast_attention = False,
318 |
319 | cross_frame_attention_mode = None,
320 | temporal_position_encoding = False,
321 | temporal_position_encoding_max_len = 4096,
322 | grid = False,
323 | block_size = 1,
324 | ):
325 | super().__init__()
326 |
327 | inner_dim = num_attention_heads * attention_head_dim
328 |
329 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
330 | self.proj_in = nn.Linear(in_channels, inner_dim)
331 |
332 | self.block_size = block_size
333 | self.transformer_blocks = nn.ModuleList(
334 | [
335 | TemporalTransformerBlock(
336 | dim=inner_dim,
337 | num_attention_heads=num_attention_heads,
338 | attention_head_dim=attention_head_dim,
339 | attention_block_types=attention_block_types,
340 | dropout=dropout,
341 | norm_num_groups=norm_num_groups,
342 | cross_attention_dim=cross_attention_dim,
343 | activation_fn=activation_fn,
344 | attention_bias=attention_bias,
345 | upcast_attention=upcast_attention,
346 | cross_frame_attention_mode=cross_frame_attention_mode,
347 | temporal_position_encoding=temporal_position_encoding,
348 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
349 | block_size=block_size,
350 | grid=grid,
351 | )
352 | for d in range(num_layers)
353 | ]
354 | )
355 | self.proj_out = nn.Linear(inner_dim, in_channels)
356 |
357 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
358 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
359 | video_length = hidden_states.shape[2]
360 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
361 |
362 | batch, channel, height, weight = hidden_states.shape
363 | residual = hidden_states
364 |
365 | hidden_states = self.norm(hidden_states)
366 | inner_dim = hidden_states.shape[1]
367 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
368 | hidden_states = self.proj_in(hidden_states)
369 |
370 | # Transformer Blocks
371 | for block in self.transformer_blocks:
372 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, height=height, weight=weight)
373 |
374 | # output
375 | hidden_states = self.proj_out(hidden_states)
376 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
377 |
378 | output = hidden_states + residual
379 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
380 |
381 | return output
382 |
383 | class TemporalTransformerBlock(nn.Module):
384 | def __init__(
385 | self,
386 | dim,
387 | num_attention_heads,
388 | attention_head_dim,
389 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
390 | dropout = 0.0,
391 | norm_num_groups = 32,
392 | cross_attention_dim = 768,
393 | activation_fn = "geglu",
394 | attention_bias = False,
395 | upcast_attention = False,
396 | cross_frame_attention_mode = None,
397 | temporal_position_encoding = False,
398 | temporal_position_encoding_max_len = 4096,
399 | block_size = 1,
400 | grid = False,
401 | ):
402 | super().__init__()
403 |
404 | attention_blocks = []
405 | norms = []
406 |
407 | for block_name in attention_block_types:
408 | attention_blocks.append(
409 | VersatileAttention(
410 | attention_mode=block_name.split("_")[0],
411 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
412 |
413 | query_dim=dim,
414 | heads=num_attention_heads,
415 | dim_head=attention_head_dim,
416 | dropout=dropout,
417 | bias=attention_bias,
418 | upcast_attention=upcast_attention,
419 |
420 | cross_frame_attention_mode=cross_frame_attention_mode,
421 | temporal_position_encoding=temporal_position_encoding,
422 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
423 | block_size=block_size,
424 | grid=grid,
425 | )
426 | )
427 | norms.append(nn.LayerNorm(dim))
428 |
429 | self.attention_blocks = nn.ModuleList(attention_blocks)
430 | self.norms = nn.ModuleList(norms)
431 |
432 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
433 | self.ff_norm = nn.LayerNorm(dim)
434 |
435 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
436 | for attention_block, norm in zip(self.attention_blocks, self.norms):
437 | norm_hidden_states = norm(hidden_states)
438 | hidden_states = attention_block(
439 | norm_hidden_states,
440 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
441 | video_length=video_length,
442 | height=height,
443 | weight=weight,
444 | ) + hidden_states
445 |
446 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
447 |
448 | output = hidden_states
449 | return output
450 |
451 | class PositionalEncoding(nn.Module):
452 | def __init__(
453 | self,
454 | d_model,
455 | dropout = 0.,
456 | max_len = 4096
457 | ):
458 | super().__init__()
459 | self.dropout = nn.Dropout(p=dropout)
460 | position = torch.arange(max_len).unsqueeze(1)
461 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
462 | pe = torch.zeros(1, max_len, d_model)
463 | pe[0, :, 0::2] = torch.sin(position * div_term)
464 | pe[0, :, 1::2] = torch.cos(position * div_term)
465 | self.register_buffer('pe', pe)
466 |
467 | def forward(self, x):
468 | x = x + self.pe[:, :x.size(1)]
469 | return self.dropout(x)
470 |
471 | class VersatileAttention(CrossAttention):
472 | def __init__(
473 | self,
474 | attention_mode = None,
475 | cross_frame_attention_mode = None,
476 | temporal_position_encoding = False,
477 | temporal_position_encoding_max_len = 4096,
478 | grid = False,
479 | block_size = 1,
480 | *args, **kwargs
481 | ):
482 | super().__init__(*args, **kwargs)
483 | assert attention_mode == "Temporal"
484 |
485 | self.attention_mode = attention_mode
486 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
487 |
488 | self.block_size = block_size
489 | self.grid = grid
490 | self.pos_encoder = PositionalEncoding(
491 | kwargs["query_dim"],
492 | dropout=0.,
493 | max_len=temporal_position_encoding_max_len
494 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None
495 |
496 | def extra_repr(self):
497 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
498 |
499 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
500 | batch_size, sequence_length, _ = hidden_states.shape
501 |
502 | if self.attention_mode == "Temporal":
503 |             # add temporal positional encoding before attention
504 | _, before_d, _c = hidden_states.size()
505 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
506 | if self.pos_encoder is not None:
507 | hidden_states = self.pos_encoder(hidden_states)
508 |
509 | if self.grid:
510 | hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
511 | hidden_states = rearrange(hidden_states, "b f (h w) c -> b f h w c", h=height, w=weight)
512 |
513 | hidden_states = rearrange(hidden_states, "b f (h n) (w m) c -> (b h w) (f n m) c", n=self.block_size, m=self.block_size)
514 | d = before_d // self.block_size // self.block_size
515 | else:
516 | d = before_d
517 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
518 | else:
519 | raise NotImplementedError
520 |
521 | encoder_hidden_states = encoder_hidden_states
522 |
523 | if self.group_norm is not None:
524 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
525 |
526 | query = self.to_q(hidden_states)
527 | dim = query.shape[-1]
528 | query = self.reshape_heads_to_batch_dim(query)
529 |
530 | if self.added_kv_proj_dim is not None:
531 | raise NotImplementedError
532 |
533 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
534 | key = self.to_k(encoder_hidden_states)
535 | value = self.to_v(encoder_hidden_states)
536 |
537 | key = self.reshape_heads_to_batch_dim(key)
538 | value = self.reshape_heads_to_batch_dim(value)
539 |
540 | if attention_mask is not None:
541 | if attention_mask.shape[-1] != query.shape[1]:
542 | target_length = query.shape[1]
543 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
544 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
545 |
546 | bs = 512
547 | new_hidden_states = []
548 | for i in range(0, query.shape[0], bs):
549 | # attention, what we cannot get enough of
550 | if self._use_memory_efficient_attention_xformers:
551 | hidden_states = self._memory_efficient_attention_xformers(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
552 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
553 | hidden_states = hidden_states.to(query.dtype)
554 | else:
555 | if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
556 | hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
557 | else:
558 | hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
559 | new_hidden_states.append(hidden_states)
560 | hidden_states = torch.cat(new_hidden_states, dim = 0)
561 |
562 | # linear proj
563 | hidden_states = self.to_out[0](hidden_states)
564 |
565 | # dropout
566 | hidden_states = self.to_out[1](hidden_states)
567 |
568 | if self.attention_mode == "Temporal":
569 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
570 | if self.grid:
571 | hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
572 | hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
573 | hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
574 |
575 | return hidden_states
--------------------------------------------------------------------------------
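A minimal smoke test for the motion module above: get_motion_module builds a VanillaTemporalModule whose temporal self-attention mixes information across the frame axis while preserving the input shape. The channel count and spatial size below are illustrative assumptions (chosen to be divisible by the default 32 GroupNorm groups and 8 heads), not values taken from the configs:

import torch

from easyanimate.models.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs=dict(
        num_attention_heads=8,
        num_transformer_block=1,
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=4096,
    ),
)

# (batch, channels, frames, height, width)
x = torch.randn(1, 320, 8, 16, 16)
with torch.no_grad():
    out = motion_module(x)
print(out.shape)  # torch.Size([1, 320, 8, 16, 16]); proj_out is zero-initialized, so out == x at init
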
/easyanimate/models/patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.nn.init as init
7 | from einops import rearrange
8 | from torch import nn
9 |
10 |
11 | def get_2d_sincos_pos_embed(
12 | embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
13 | ):
14 | """
15 |     grid_size: int of the grid height and width
16 |     return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
17 | """
18 | if isinstance(grid_size, int):
19 | grid_size = (grid_size, grid_size)
20 |
21 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
22 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
23 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
24 | grid = np.stack(grid, axis=0)
25 |
26 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
27 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
28 | if cls_token and extra_tokens > 0:
29 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
30 | return pos_embed
31 |
32 |
33 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
34 | if embed_dim % 2 != 0:
35 | raise ValueError("embed_dim must be divisible by 2")
36 |
37 | # use half of dimensions to encode grid_h
38 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
39 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
40 |
41 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
42 | return emb
43 |
44 |
45 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
46 | """
47 |     embed_dim: output dimension for each position; pos: a list of positions to be encoded, size (M,); out: (M, D)
48 | """
49 | if embed_dim % 2 != 0:
50 | raise ValueError("embed_dim must be divisible by 2")
51 |
52 | omega = np.arange(embed_dim // 2, dtype=np.float64)
53 | omega /= embed_dim / 2.0
54 | omega = 1.0 / 10000**omega # (D/2,)
55 |
56 | pos = pos.reshape(-1) # (M,)
57 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
58 |
59 | emb_sin = np.sin(out) # (M, D/2)
60 | emb_cos = np.cos(out) # (M, D/2)
61 |
62 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
63 | return emb
64 |
65 | class Patch1D(nn.Module):
66 | def __init__(
67 | self,
68 | channels: int,
69 | use_conv: bool = False,
70 | out_channels: Optional[int] = None,
71 | stride: int = 2,
72 | padding: int = 0,
73 | name: str = "conv",
74 | ):
75 | super().__init__()
76 | self.channels = channels
77 | self.out_channels = out_channels or channels
78 | self.use_conv = use_conv
79 | self.padding = padding
80 | self.name = name
81 |
82 | if use_conv:
83 | self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
84 | init.constant_(self.conv.weight, 0.0)
85 | with torch.no_grad():
86 | for i in range(len(self.conv.weight)): self.conv.weight[i, i] = 1 / stride
87 | init.constant_(self.conv.bias, 0.0)
88 | else:
89 | assert self.channels == self.out_channels
90 | self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
91 |
92 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
93 | assert inputs.shape[1] == self.channels
94 | return self.conv(inputs)
95 |
96 | class UnPatch1D(nn.Module):
97 | def __init__(
98 | self,
99 | channels: int,
100 | use_conv: bool = False,
101 | use_conv_transpose: bool = False,
102 | out_channels: Optional[int] = None,
103 | name: str = "conv",
104 | ):
105 | super().__init__()
106 | self.channels = channels
107 | self.out_channels = out_channels or channels
108 | self.use_conv = use_conv
109 | self.use_conv_transpose = use_conv_transpose
110 | self.name = name
111 |
112 | self.conv = None
113 | if use_conv_transpose:
114 | self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
115 | elif use_conv:
116 | self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
117 |
118 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
119 | assert inputs.shape[1] == self.channels
120 | if self.use_conv_transpose:
121 | return self.conv(inputs)
122 |
123 | outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
124 |
125 | if self.use_conv:
126 | outputs = self.conv(outputs)
127 |
128 | return outputs
129 |
130 | class PatchEmbed3D(nn.Module):
131 | """3D Image to Patch Embedding"""
132 |
133 | def __init__(
134 | self,
135 | height=224,
136 | width=224,
137 | patch_size=16,
138 | time_patch_size=4,
139 | in_channels=3,
140 | embed_dim=768,
141 | layer_norm=False,
142 | flatten=True,
143 | bias=True,
144 | interpolation_scale=1,
145 | ):
146 | super().__init__()
147 |
148 | num_patches = (height // patch_size) * (width // patch_size)
149 | self.flatten = flatten
150 | self.layer_norm = layer_norm
151 |
152 | self.proj = nn.Conv3d(
153 | in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias
154 | )
155 | if layer_norm:
156 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
157 | else:
158 | self.norm = None
159 |
160 | self.patch_size = patch_size
161 | # See:
162 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
163 | self.height, self.width = height // patch_size, width // patch_size
164 | self.base_size = height // patch_size
165 | self.interpolation_scale = interpolation_scale
166 | pos_embed = get_2d_sincos_pos_embed(
167 | embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
168 | )
169 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
170 |
171 | def forward(self, latent):
172 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
173 |
174 | latent = self.proj(latent)
175 | latent = rearrange(latent, "b c f h w -> (b f) c h w")
176 | if self.flatten:
177 | latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
178 | if self.layer_norm:
179 | latent = self.norm(latent)
180 |
181 | # Interpolate positional embeddings if needed.
182 | # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
183 | if self.height != height or self.width != width:
184 | pos_embed = get_2d_sincos_pos_embed(
185 | embed_dim=self.pos_embed.shape[-1],
186 | grid_size=(height, width),
187 | base_size=self.base_size,
188 | interpolation_scale=self.interpolation_scale,
189 | )
190 | pos_embed = torch.from_numpy(pos_embed)
191 | pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
192 | else:
193 | pos_embed = self.pos_embed
194 |
195 | return (latent + pos_embed).to(latent.dtype)
196 |
197 | class PatchEmbedF3D(nn.Module):
198 | """Fake 3D Image to Patch Embedding"""
199 |
200 | def __init__(
201 | self,
202 | height=224,
203 | width=224,
204 | patch_size=16,
205 | in_channels=3,
206 | embed_dim=768,
207 | layer_norm=False,
208 | flatten=True,
209 | bias=True,
210 | interpolation_scale=1,
211 | ):
212 | super().__init__()
213 |
214 | num_patches = (height // patch_size) * (width // patch_size)
215 | self.flatten = flatten
216 | self.layer_norm = layer_norm
217 |
218 | self.proj = nn.Conv2d(
219 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
220 | )
221 | self.proj_t = Patch1D(
222 | embed_dim, True, stride=patch_size
223 | )
224 | if layer_norm:
225 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
226 | else:
227 | self.norm = None
228 |
229 | self.patch_size = patch_size
230 | # See:
231 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
232 | self.height, self.width = height // patch_size, width // patch_size
233 | self.base_size = height // patch_size
234 | self.interpolation_scale = interpolation_scale
235 | pos_embed = get_2d_sincos_pos_embed(
236 | embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
237 | )
238 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
239 |
240 | def forward(self, latent):
241 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
242 | b, c, f, h, w = latent.size()
243 | latent = rearrange(latent, "b c f h w -> (b f) c h w")
244 | latent = self.proj(latent)
245 | latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)
246 |
247 | latent = rearrange(latent, "b c f h w -> (b h w) c f")
248 | latent = self.proj_t(latent)
249 | latent = rearrange(latent, "(b h w) c f -> b c f h w", h=h//2, w=w//2)
250 |
251 | latent = rearrange(latent, "b c f h w -> (b f) c h w")
252 | if self.flatten:
253 | latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
254 | if self.layer_norm:
255 | latent = self.norm(latent)
256 |
257 | # Interpolate positional embeddings if needed.
258 | # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
259 | if self.height != height or self.width != width:
260 | pos_embed = get_2d_sincos_pos_embed(
261 | embed_dim=self.pos_embed.shape[-1],
262 | grid_size=(height, width),
263 | base_size=self.base_size,
264 | interpolation_scale=self.interpolation_scale,
265 | )
266 | pos_embed = torch.from_numpy(pos_embed)
267 | pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
268 | else:
269 | pos_embed = self.pos_embed
270 |
271 | return (latent + pos_embed).to(latent.dtype)
--------------------------------------------------------------------------------
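A quick shape check for PatchEmbed3D above: the Conv3d patchifies space and time, frames are folded into the batch axis, and the 2D sin-cos positional embedding is added to the tokens of every frame. The sizes below are illustrative assumptions, not repository defaults:

import torch

from easyanimate.models.patch import PatchEmbed3D

embed = PatchEmbed3D(
    height=32, width=32,
    patch_size=2, time_patch_size=1,
    in_channels=4, embed_dim=768,
)

latent = torch.randn(1, 4, 8, 32, 32)  # (batch, channels, frames, height, width)
tokens = embed(latent)

# Output layout: (batch * frames / time_patch_size, num_patches, embed_dim)
print(tokens.shape)  # torch.Size([8, 256, 768])
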
/easyanimate/models/transformer2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import os
16 | from dataclasses import dataclass
17 | from typing import Any, Dict, Optional
18 |
19 | import numpy as np
20 | import torch
21 | import torch.nn.functional as F
22 | import torch.nn.init as init
23 | from diffusers.configuration_utils import ConfigMixin, register_to_config
24 | from diffusers.models.attention import BasicTransformerBlock
25 | from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
26 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27 | from diffusers.models.modeling_utils import ModelMixin
28 | from diffusers.models.normalization import AdaLayerNormSingle
29 | from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
30 | is_torch_version)
31 | from einops import rearrange
32 | from torch import nn
33 |
34 | try:
35 | from diffusers.models.embeddings import PixArtAlphaTextProjection
36 | except ImportError:
37 | from diffusers.models.embeddings import \
38 | CaptionProjection as PixArtAlphaTextProjection
39 |
40 |
41 | @dataclass
42 | class Transformer2DModelOutput(BaseOutput):
43 | """
44 | The output of [`Transformer2DModel`].
45 |
46 | Args:
47 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
48 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
49 | distributions for the unnoised latent pixels.
50 | """
51 |
52 | sample: torch.FloatTensor
53 |
54 |
55 | class Transformer2DModel(ModelMixin, ConfigMixin):
56 | """
57 | A 2D Transformer model for image-like data.
58 |
59 | Parameters:
60 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
61 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
62 | in_channels (`int`, *optional*):
63 | The number of channels in the input and output (specify if the input is **continuous**).
64 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
65 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
66 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
67 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
68 | This is fixed during training since it is used to learn a number of position embeddings.
69 | num_vector_embeds (`int`, *optional*):
70 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
71 | Includes the class for the masked latent pixel.
72 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
73 | num_embeds_ada_norm ( `int`, *optional*):
74 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is
75 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
76 | added to the hidden states.
77 |
78 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
79 | attention_bias (`bool`, *optional*):
80 | Configure if the `TransformerBlocks` attention should contain a bias parameter.
81 | """
82 | _supports_gradient_checkpointing = True
83 |
84 | @register_to_config
85 | def __init__(
86 | self,
87 | num_attention_heads: int = 16,
88 | attention_head_dim: int = 88,
89 | in_channels: Optional[int] = None,
90 | out_channels: Optional[int] = None,
91 | num_layers: int = 1,
92 | dropout: float = 0.0,
93 | norm_num_groups: int = 32,
94 | cross_attention_dim: Optional[int] = None,
95 | attention_bias: bool = False,
96 | sample_size: Optional[int] = None,
97 | num_vector_embeds: Optional[int] = None,
98 | patch_size: Optional[int] = None,
99 | activation_fn: str = "geglu",
100 | num_embeds_ada_norm: Optional[int] = None,
101 | use_linear_projection: bool = False,
102 | only_cross_attention: bool = False,
103 | double_self_attention: bool = False,
104 | upcast_attention: bool = False,
105 | norm_type: str = "layer_norm",
106 | norm_elementwise_affine: bool = True,
107 | norm_eps: float = 1e-5,
108 | attention_type: str = "default",
109 | caption_channels: int = None,
110 | ):
111 | super().__init__()
112 | self.use_linear_projection = use_linear_projection
113 | self.num_attention_heads = num_attention_heads
114 | self.attention_head_dim = attention_head_dim
115 | inner_dim = num_attention_heads * attention_head_dim
116 |
117 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
118 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
119 |
120 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
121 | # Define whether input is continuous or discrete depending on configuration
122 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
123 | self.is_input_vectorized = num_vector_embeds is not None
124 | self.is_input_patches = in_channels is not None and patch_size is not None
125 |
126 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
127 | deprecation_message = (
128 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
129 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
130 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
131 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
132 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
133 | )
134 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
135 | norm_type = "ada_norm"
136 |
137 | if self.is_input_continuous and self.is_input_vectorized:
138 | raise ValueError(
139 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
140 | " sure that either `in_channels` or `num_vector_embeds` is None."
141 | )
142 | elif self.is_input_vectorized and self.is_input_patches:
143 | raise ValueError(
144 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
145 | " sure that either `num_vector_embeds` or `num_patches` is None."
146 | )
147 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
148 | raise ValueError(
149 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
150 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
151 | )
152 |
153 | # 2. Define input layers
154 | if self.is_input_continuous:
155 | self.in_channels = in_channels
156 |
157 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
158 | if use_linear_projection:
159 | self.proj_in = linear_cls(in_channels, inner_dim)
160 | else:
161 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
162 | elif self.is_input_vectorized:
163 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
164 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
165 |
166 | self.height = sample_size
167 | self.width = sample_size
168 | self.num_vector_embeds = num_vector_embeds
169 | self.num_latent_pixels = self.height * self.width
170 |
171 | self.latent_image_embedding = ImagePositionalEmbeddings(
172 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
173 | )
174 | elif self.is_input_patches:
175 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
176 |
177 | self.height = sample_size
178 | self.width = sample_size
179 |
180 | self.patch_size = patch_size
181 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
182 | interpolation_scale = max(interpolation_scale, 1)
183 | self.pos_embed = PatchEmbed(
184 | height=sample_size,
185 | width=sample_size,
186 | patch_size=patch_size,
187 | in_channels=in_channels,
188 | embed_dim=inner_dim,
189 | interpolation_scale=interpolation_scale,
190 | )
191 |
192 | # 3. Define transformers blocks
193 | self.transformer_blocks = nn.ModuleList(
194 | [
195 | BasicTransformerBlock(
196 | inner_dim,
197 | num_attention_heads,
198 | attention_head_dim,
199 | dropout=dropout,
200 | cross_attention_dim=cross_attention_dim,
201 | activation_fn=activation_fn,
202 | num_embeds_ada_norm=num_embeds_ada_norm,
203 | attention_bias=attention_bias,
204 | only_cross_attention=only_cross_attention,
205 | double_self_attention=double_self_attention,
206 | upcast_attention=upcast_attention,
207 | norm_type=norm_type,
208 | norm_elementwise_affine=norm_elementwise_affine,
209 | norm_eps=norm_eps,
210 | attention_type=attention_type,
211 | )
212 | for d in range(num_layers)
213 | ]
214 | )
215 |
216 | # 4. Define output layers
217 | self.out_channels = in_channels if out_channels is None else out_channels
218 | if self.is_input_continuous:
219 | # TODO: should use out_channels for continuous projections
220 | if use_linear_projection:
221 | self.proj_out = linear_cls(inner_dim, in_channels)
222 | else:
223 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
224 | elif self.is_input_vectorized:
225 | self.norm_out = nn.LayerNorm(inner_dim)
226 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
227 | elif self.is_input_patches and norm_type != "ada_norm_single":
228 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
229 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
230 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
231 | elif self.is_input_patches and norm_type == "ada_norm_single":
232 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
233 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
234 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
235 |
236 | # 5. PixArt-Alpha blocks.
237 | self.adaln_single = None
238 | self.use_additional_conditions = False
239 | if norm_type == "ada_norm_single":
240 | self.use_additional_conditions = self.config.sample_size == 128
241 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
242 | # additional conditions until we find better name
243 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
244 |
245 | self.caption_projection = None
246 | if caption_channels is not None:
247 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
248 |
249 | self.gradient_checkpointing = False
250 |
251 | def _set_gradient_checkpointing(self, module, value=False):
252 | if hasattr(module, "gradient_checkpointing"):
253 | module.gradient_checkpointing = value
254 |
255 | def forward(
256 | self,
257 | hidden_states: torch.Tensor,
258 | encoder_hidden_states: Optional[torch.Tensor] = None,
259 | timestep: Optional[torch.LongTensor] = None,
260 | added_cond_kwargs: Dict[str, torch.Tensor] = None,
261 | class_labels: Optional[torch.LongTensor] = None,
262 | cross_attention_kwargs: Dict[str, Any] = None,
263 | attention_mask: Optional[torch.Tensor] = None,
264 | encoder_attention_mask: Optional[torch.Tensor] = None,
265 | return_dict: bool = True,
266 | ):
267 | """
268 | The [`Transformer2DModel`] forward method.
269 |
270 | Args:
271 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
272 | Input `hidden_states`.
273 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
274 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
275 | self-attention.
276 | timestep ( `torch.LongTensor`, *optional*):
277 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
278 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
279 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
280 | `AdaLayerNormZero`.
281 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
282 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
283 | `self.processor` in
284 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
285 | attention_mask ( `torch.Tensor`, *optional*):
286 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
287 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
288 | negative values to the attention scores corresponding to "discard" tokens.
289 | encoder_attention_mask ( `torch.Tensor`, *optional*):
290 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
291 |
292 | * Mask `(batch, sequence_length)` True = keep, False = discard.
293 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
294 |
295 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
296 | above. This bias will be added to the cross-attention scores.
297 | return_dict (`bool`, *optional*, defaults to `True`):
298 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
299 | tuple.
300 |
301 | Returns:
302 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
303 | `tuple` where the first element is the sample tensor.
304 | """
305 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
306 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
307 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
308 | # expects mask of shape:
309 | # [batch, key_tokens]
310 | # adds singleton query_tokens dimension:
311 | # [batch, 1, key_tokens]
312 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
313 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
314 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
315 | if attention_mask is not None and attention_mask.ndim == 2:
316 | # assume that mask is expressed as:
317 | # (1 = keep, 0 = discard)
318 | # convert mask into a bias that can be added to attention scores:
319 | # (keep = +0, discard = -10000.0)
320 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
321 | attention_mask = attention_mask.unsqueeze(1)
322 |
323 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
324 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
325 | encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
326 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
327 |
328 | # Retrieve lora scale.
329 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
330 |
331 | # 1. Input
332 | if self.is_input_continuous:
333 | batch, _, height, width = hidden_states.shape
334 | residual = hidden_states
335 |
336 | hidden_states = self.norm(hidden_states)
337 | if not self.use_linear_projection:
338 | hidden_states = (
339 | self.proj_in(hidden_states, scale=lora_scale)
340 | if not USE_PEFT_BACKEND
341 | else self.proj_in(hidden_states)
342 | )
343 | inner_dim = hidden_states.shape[1]
344 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
345 | else:
346 | inner_dim = hidden_states.shape[1]
347 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
348 | hidden_states = (
349 | self.proj_in(hidden_states, scale=lora_scale)
350 | if not USE_PEFT_BACKEND
351 | else self.proj_in(hidden_states)
352 | )
353 |
354 | elif self.is_input_vectorized:
355 | hidden_states = self.latent_image_embedding(hidden_states)
356 | elif self.is_input_patches:
357 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
358 | hidden_states = self.pos_embed(hidden_states)
359 |
360 | if self.adaln_single is not None:
361 | if self.use_additional_conditions and added_cond_kwargs is None:
362 | raise ValueError(
363 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
364 | )
365 | batch_size = hidden_states.shape[0]
366 | timestep, embedded_timestep = self.adaln_single(
367 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
368 | )
369 |
370 | # 2. Blocks
371 | if self.caption_projection is not None:
372 | batch_size = hidden_states.shape[0]
373 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
374 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
375 |
376 | for block in self.transformer_blocks:
377 | if self.training and self.gradient_checkpointing:
378 | hidden_states = torch.utils.checkpoint.checkpoint(
379 | block,
380 | hidden_states,
381 | attention_mask,
382 | encoder_hidden_states,
383 | encoder_attention_mask,
384 | timestep,
385 | cross_attention_kwargs,
386 | class_labels,
387 | use_reentrant=False,
388 | )
389 | else:
390 | hidden_states = block(
391 | hidden_states,
392 | attention_mask=attention_mask,
393 | encoder_hidden_states=encoder_hidden_states,
394 | encoder_attention_mask=encoder_attention_mask,
395 | timestep=timestep,
396 | cross_attention_kwargs=cross_attention_kwargs,
397 | class_labels=class_labels,
398 | )
399 |
400 | # 3. Output
401 | if self.is_input_continuous:
402 | if not self.use_linear_projection:
403 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
404 | hidden_states = (
405 | self.proj_out(hidden_states, scale=lora_scale)
406 | if not USE_PEFT_BACKEND
407 | else self.proj_out(hidden_states)
408 | )
409 | else:
410 | hidden_states = (
411 | self.proj_out(hidden_states, scale=lora_scale)
412 | if not USE_PEFT_BACKEND
413 | else self.proj_out(hidden_states)
414 | )
415 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
416 |
417 | output = hidden_states + residual
418 | elif self.is_input_vectorized:
419 | hidden_states = self.norm_out(hidden_states)
420 | logits = self.out(hidden_states)
421 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
422 | logits = logits.permute(0, 2, 1)
423 |
424 | # log(p(x_0))
425 | output = F.log_softmax(logits.double(), dim=1).float()
426 |
427 | if self.is_input_patches:
428 | if self.config.norm_type != "ada_norm_single":
429 | conditioning = self.transformer_blocks[0].norm1.emb(
430 | timestep, class_labels, hidden_dtype=hidden_states.dtype
431 | )
432 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
433 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
434 | hidden_states = self.proj_out_2(hidden_states)
435 | elif self.config.norm_type == "ada_norm_single":
436 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
437 | hidden_states = self.norm_out(hidden_states)
438 | # Modulation
439 | hidden_states = hidden_states * (1 + scale) + shift
440 | hidden_states = self.proj_out(hidden_states)
441 | hidden_states = hidden_states.squeeze(1)
442 |
443 | # unpatchify
444 | if self.adaln_single is None:
445 | height = width = int(hidden_states.shape[1] ** 0.5)
446 | hidden_states = hidden_states.reshape(
447 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
448 | )
449 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
450 | output = hidden_states.reshape(
451 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
452 | )
453 |
454 | if not return_dict:
455 | return (output,)
456 |
457 | return Transformer2DModelOutput(sample=output)
458 |
459 | @classmethod
460 | def from_pretrained(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
461 | if subfolder is not None:
462 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
463 | print(f"loading 2D transformer's pretrained weights from {pretrained_model_path} ...")
464 |
465 | config_file = os.path.join(pretrained_model_path, 'config.json')
466 | if not os.path.isfile(config_file):
467 | raise RuntimeError(f"{config_file} does not exist")
468 | with open(config_file, "r") as f:
469 | config = json.load(f)
470 |
471 | from diffusers.utils import WEIGHTS_NAME
472 | model = cls.from_config(config, **transformer_additional_kwargs)
473 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
474 | model_file_safetensors = model_file.replace(".bin", ".safetensors")
475 | if os.path.exists(model_file_safetensors):
476 | from safetensors.torch import load_file, safe_open
477 | state_dict = load_file(model_file_safetensors)
478 | else:
479 | if not os.path.isfile(model_file):
480 | raise RuntimeError(f"{model_file} does not exist")
481 | state_dict = torch.load(model_file, map_location="cpu")
482 |
483 | if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
484 | new_shape = model.state_dict()['pos_embed.proj.weight'].size()
485 | state_dict['pos_embed.proj.weight'] = torch.tile(state_dict['pos_embed.proj.weight'], [1, 2, 1, 1])  # repeat along in_channels to match the resized patch embedding
486 |
487 | if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
488 | new_shape = model.state_dict()['proj_out.weight'].size()
489 | state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
490 |
491 | if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
492 | new_shape = model.state_dict()['proj_out.bias'].size()
493 | state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
494 |
495 | tmp_state_dict = {}
496 | for key in state_dict:
497 | if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
498 | tmp_state_dict[key] = state_dict[key]
499 | else:
500 | print(key, "size mismatch, skipping")
501 | state_dict = tmp_state_dict
502 |
503 | m, u = model.load_state_dict(state_dict, strict=False)
504 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
505 |
506 | params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
507 | print(f"### Postion Parameters: {sum(params) / 1e6} M")
508 |
509 | return model
--------------------------------------------------------------------------------
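
A minimal usage sketch for the `from_pretrained` helper defined above, assuming a locally downloaded PixArt-style checkpoint with a `transformer/` subfolder and that the repository root is on sys.path (nodes.py inserts it); the path mirrors the commented-out default in nodes.py and is illustrative only:

from easyanimate.models.transformer2d import Transformer2DModel

# Loads config.json plus diffusion_pytorch_model.(bin|safetensors) from the subfolder
# and skips any checkpoint tensors whose shapes do not match the instantiated model.
model = Transformer2DModel.from_pretrained(
    "models/Diffusion_Transformer/PixArt-XL-2-512x512",  # illustrative local path
    subfolder="transformer",
)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
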
/easyanimate/ui/ui.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2 | """
3 | import gc
4 | import json
5 | import os
6 | import random
7 | from datetime import datetime
8 | from glob import glob
9 |
10 | import gradio as gr
11 | import torch
12 | from diffusers import (AutoencoderKL, DDIMScheduler,
13 | DPMSolverMultistepScheduler,
14 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
15 | PNDMScheduler)
16 | from diffusers.utils.import_utils import is_xformers_available
17 | from omegaconf import OmegaConf
18 | from safetensors import safe_open
19 | from transformers import T5EncoderModel, T5Tokenizer
20 |
21 | from easyanimate.models.transformer3d import Transformer3DModel
22 | from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
23 | from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
24 | from easyanimate.utils.utils import save_videos_grid
25 |
26 | sample_idx = 0
27 | scheduler_dict = {
28 | "Euler": EulerDiscreteScheduler,
29 | "Euler A": EulerAncestralDiscreteScheduler,
30 | "DPM++": DPMSolverMultistepScheduler,
31 | "PNDM": PNDMScheduler,
32 | "DDIM": DDIMScheduler,
33 | }
34 |
35 | css = """
36 | .toolbutton {
37 | margin-bottom: 0em;
38 | max-width: 2.5em;
39 | min-width: 2.5em !important;
40 | height: 2.5em;
41 | }
42 | """
43 |
44 | class EasyAnimateController:
45 | def __init__(self):
46 | # config dirs
47 | self.basedir = os.getcwd()
48 | self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
49 | self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
50 | self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
51 | self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
52 | self.savedir_sample = os.path.join(self.savedir, "sample")
53 | os.makedirs(self.savedir, exist_ok=True)
54 |
55 | self.diffusion_transformer_list = []
56 | self.motion_module_list = []
57 | self.personalized_model_list = []
58 |
59 | self.refresh_diffusion_transformer()
60 | self.refresh_motion_module()
61 | self.refresh_personalized_model()
62 |
63 | # config models
64 | self.tokenizer = None
65 | self.text_encoder = None
66 | self.vae = None
67 | self.transformer = None
68 | self.pipeline = None
69 | self.lora_model_path = "none"
70 |
71 | self.weight_dtype = torch.float16
72 | self.inference_config = OmegaConf.load("config/easyanimate_video_motion_module_v1.yaml")
73 |
74 | def refresh_diffusion_transformer(self):
75 | self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
76 |
77 | def refresh_motion_module(self):
78 | motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
79 | self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
80 |
81 | def refresh_personalized_model(self):
82 | personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
83 | self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
84 |
85 | def update_diffusion_transformer(self, diffusion_transformer_dropdown):
86 | print("Update diffusion transformer")
87 | if diffusion_transformer_dropdown == "none":
88 | return gr.Dropdown.update()
89 | self.vae = AutoencoderKL.from_pretrained(diffusion_transformer_dropdown, subfolder="vae", torch_dtype = self.weight_dtype)
90 | self.transformer = Transformer3DModel.from_pretrained_2d(
91 | diffusion_transformer_dropdown,
92 | subfolder="transformer",
93 | transformer_additional_kwargs=OmegaConf.to_container(self.inference_config.transformer_additional_kwargs)
94 | ).to(self.weight_dtype)
95 | self.tokenizer = T5Tokenizer.from_pretrained(diffusion_transformer_dropdown, subfolder="tokenizer")
96 | self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype = self.weight_dtype)
97 | print("Update diffusion transformer done")
98 | return gr.Dropdown.update()
99 |
100 | def update_motion_module(self, motion_module_dropdown):
101 | print("Update motion module")
102 | if motion_module_dropdown == "none":
103 | return gr.Dropdown.update()
104 | if self.transformer is None:
105 | gr.Info(f"Please select a pretrained model path.")
106 | return gr.Dropdown.update(value=None)
107 | else:
108 | motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
109 | if motion_module_dropdown.endswith(".safetensors"):
110 | from safetensors.torch import load_file, safe_open
111 | motion_module_state_dict = load_file(motion_module_dropdown)
112 | else:
113 | if not os.path.isfile(motion_module_dropdown):
114 | raise RuntimeError(f"{motion_module_dropdown} does not exist")
115 | motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
116 | missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
117 | assert len(unexpected) == 0
118 | print("Update motion module done")
119 | return gr.Dropdown.update()
120 |
121 | def update_base_model(self, base_model_dropdown):
122 | print("Update base model")
123 | if base_model_dropdown == "none":
124 | return gr.Dropdown.update()
125 | if self.transformer is None:
126 | gr.Info(f"Please select a pretrained model path.")
127 | return gr.Dropdown.update(value=None)
128 | else:
129 | base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
130 | base_model_state_dict = {}
131 | with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
132 | for key in f.keys():
133 | base_model_state_dict[key] = f.get_tensor(key)
134 | self.transformer.load_state_dict(base_model_state_dict, strict=False)
135 | print("Update base done")
136 | return gr.Dropdown.update()
137 |
138 | def update_lora_model(self, lora_model_dropdown):
139 | lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
140 | self.lora_model_path = lora_model_dropdown
141 | return gr.Dropdown.update()
142 |
143 | def generate(
144 | self,
145 | diffusion_transformer_dropdown,
146 | motion_module_dropdown,
147 | base_model_dropdown,
148 | lora_alpha_slider,
149 | prompt_textbox,
150 | negative_prompt_textbox,
151 | sampler_dropdown,
152 | sample_step_slider,
153 | width_slider,
154 | length_slider,
155 | height_slider,
156 | cfg_scale_slider,
157 | seed_textbox
158 | ):
159 | global sample_idx
160 | if self.transformer is None:
161 | raise gr.Error(f"Please select a pretrained model path.")
162 | if motion_module_dropdown in (None, "", "none"):
163 | raise gr.Error(f"Please select a motion module.")
164 |
165 | if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
166 |
167 | pipeline = EasyAnimatePipeline(
168 | vae=self.vae,
169 | text_encoder=self.text_encoder,
170 | tokenizer=self.tokenizer,
171 | transformer=self.transformer,
172 | scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
173 | )
174 | if self.lora_model_path != "none":
175 | # lora part
176 | pipeline = merge_lora(pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
177 |
178 | pipeline.to("cuda")
179 |
180 | if seed_textbox not in (-1, "-1", ""): torch.manual_seed(int(seed_textbox))
181 | else: torch.seed()
182 | seed = torch.initial_seed()
183 |
184 | try:
185 | sample = pipeline(
186 | prompt_textbox,
187 | negative_prompt = negative_prompt_textbox,
188 | num_inference_steps = sample_step_slider,
189 | guidance_scale = cfg_scale_slider,
190 | width = width_slider,
191 | height = height_slider,
192 | video_length = length_slider,
193 | ).videos
194 | except Exception as e:
195 | # lora part
196 | gc.collect()
197 | torch.cuda.empty_cache()
198 | torch.cuda.ipc_collect()
199 | if self.lora_model_path != "none":
200 | pipeline = unmerge_lora(pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
201 | return gr.Video.update()
202 |
203 | # lora part
204 | if self.lora_model_path != "none":
205 | pipeline = unmerge_lora(pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
206 |
207 | save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
208 | sample_idx += 1
209 | save_videos_grid(sample, save_sample_path, fps=12)
210 |
211 | sample_config = {
212 | "prompt": prompt_textbox,
213 | "n_prompt": negative_prompt_textbox,
214 | "sampler": sampler_dropdown,
215 | "num_inference_steps": sample_step_slider,
216 | "guidance_scale": cfg_scale_slider,
217 | "width": width_slider,
218 | "height": height_slider,
219 | "video_length": length_slider,
220 | "seed": seed
221 | }
222 | json_str = json.dumps(sample_config, indent=4)
223 | with open(os.path.join(self.savedir, "logs.json"), "a") as f:
224 | f.write(json_str)
225 | f.write("\n\n")
226 |
227 | return gr.Video.update(value=save_sample_path)
228 |
229 |
230 | def ui():
231 | controller = EasyAnimateController()
232 |
233 | with gr.Blocks(css=css) as demo:
234 | gr.Markdown(
235 | """
236 | # EasyAnimate: Generate your animation easily
237 | [Github](https://github.com/aigc-apps/EasyAnimate/)
238 | """
239 | )
240 | with gr.Column(variant="panel"):
241 | gr.Markdown(
242 | """
243 | ### 1. Model checkpoints (select pretrained model path first).
244 | """
245 | )
246 | with gr.Row():
247 | diffusion_transformer_dropdown = gr.Dropdown(
248 | label="Pretrained Model Path",
249 | choices=controller.diffusion_transformer_list,
250 | value="none",
251 | interactive=True,
252 | )
253 | diffusion_transformer_dropdown.change(
254 | fn=controller.update_diffusion_transformer,
255 | inputs=[diffusion_transformer_dropdown],
256 | outputs=[diffusion_transformer_dropdown]
257 | )
258 |
259 | diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
260 | def refresh_diffusion_transformer():
261 | controller.refresh_diffusion_transformer()
262 | return gr.Dropdown.update(choices=controller.diffusion_transformer_list)
263 | diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
264 |
265 | with gr.Row():
266 | motion_module_dropdown = gr.Dropdown(
267 | label="Select motion module",
268 | choices=controller.motion_module_list,
269 | value="none",
270 | interactive=True,
271 | )
272 | motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
273 |
274 | motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
275 | def update_motion_module():
276 | controller.refresh_motion_module()
277 | return gr.Dropdown.update(choices=controller.motion_module_list)
278 | motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
279 |
280 | base_model_dropdown = gr.Dropdown(
281 | label="Select base Dreambooth model (optional)",
282 | choices=controller.personalized_model_list,
283 | value="none",
284 | interactive=True,
285 | )
286 | base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
287 |
288 | lora_model_dropdown = gr.Dropdown(
289 | label="Select LoRA model (optional)",
290 | choices=["none"] + controller.personalized_model_list,
291 | value="none",
292 | interactive=True,
293 | )
294 | lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
295 |
296 | lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
297 |
298 | personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
299 | def update_personalized_model():
300 | controller.refresh_personalized_model()
301 | return [
302 | gr.Dropdown.update(choices=controller.personalized_model_list),
303 | gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
304 | ]
305 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
306 |
307 | with gr.Column(variant="panel"):
308 | gr.Markdown(
309 | """
310 | ### 2. Configs for Generation.
311 | """
312 | )
313 |
314 | prompt_textbox = gr.Textbox(label="Prompt", lines=2)
315 | negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly" )
316 |
317 | with gr.Row().style(equal_height=False):
318 | with gr.Column():
319 | with gr.Row():
320 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
321 | sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
322 |
323 | width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
324 | height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
325 | length_slider = gr.Slider(label="Animation length", value=80, minimum=16, maximum=96, step=1)
326 | cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
327 |
328 | with gr.Row():
329 | seed_textbox = gr.Textbox(label="Seed", value=-1)
330 | seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
331 | seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
332 |
333 | generate_button = gr.Button(value="Generate", variant='primary')
334 |
335 | result_video = gr.Video(label="Generated Animation", interactive=False)
336 |
337 | generate_button.click(
338 | fn=controller.generate,
339 | inputs=[
340 | diffusion_transformer_dropdown,
341 | motion_module_dropdown,
342 | base_model_dropdown,
343 | lora_alpha_slider,
344 | prompt_textbox,
345 | negative_prompt_textbox,
346 | sampler_dropdown,
347 | sample_step_slider,
348 | width_slider,
349 | length_slider,
350 | height_slider,
351 | cfg_scale_slider,
352 | seed_textbox,
353 | ],
354 | outputs=[result_video]
355 | )
356 | return demo
357 |
--------------------------------------------------------------------------------
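
The `ui()` builder above returns a Gradio Blocks app but does not launch it. A minimal launch sketch, assuming the Gradio 3.x API used throughout this file (host and port are illustrative; app.py presumably does the equivalent):

from easyanimate.ui.ui import ui

demo = ui()
demo.queue()  # optional: serialize requests so the single loaded pipeline is not shared concurrently
demo.launch(server_name="0.0.0.0", server_port=7860)
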
/easyanimate/utils/IDDIM.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | from . import gaussian_diffusion as gd
6 | from .respace import SpacedDiffusion, space_timesteps
7 |
8 |
9 | def IDDPM(
10 | timestep_respacing,
11 | noise_schedule="linear",
12 | use_kl=False,
13 | sigma_small=False,
14 | predict_xstart=False,
15 | learn_sigma=True,
16 | pred_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000,
19 | snr=False,
20 | return_startx=False,
21 | ):
22 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
23 | if use_kl:
24 | loss_type = gd.LossType.RESCALED_KL
25 | elif rescale_learned_sigmas:
26 | loss_type = gd.LossType.RESCALED_MSE
27 | else:
28 | loss_type = gd.LossType.MSE
29 | if timestep_respacing is None or timestep_respacing == "":
30 | timestep_respacing = [diffusion_steps]
31 | return SpacedDiffusion(
32 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
33 | betas=betas,
34 | model_mean_type=(
35 | gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
36 | ),
37 | model_var_type=(
38 | (gd.ModelVarType.LEARNED_RANGE if learn_sigma else (
39 | gd.ModelVarType.FIXED_LARGE
40 | if not sigma_small
41 | else gd.ModelVarType.FIXED_SMALL
42 | )
43 | )
44 | if pred_sigma
45 | else None
46 | ),
47 | loss_type=loss_type,
48 | snr=snr,
49 | return_startx=return_startx,
50 | # rescale_timesteps=rescale_timesteps,
51 | )
--------------------------------------------------------------------------------
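
The IDDPM factory above bundles the beta schedule, mean/variance parameterization and loss type into a SpacedDiffusion. A construction sketch, assuming the underlying GaussianDiffusion exposes num_timesteps as in the OpenAI code it is modified from:

from easyanimate.utils.IDDIM import IDDPM

# 1000 training steps respaced down to 100; "ddim100" would request DDIM striding instead.
diffusion = IDDPM(timestep_respacing="100", noise_schedule="linear", learn_sigma=True)
print(diffusion.num_timesteps)  # expected: 100 after respacing
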
/easyanimate/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-EasyAnimate/9bef69d1ceda9d300613488517af6cc66cf5c360/easyanimate/utils/__init__.py
--------------------------------------------------------------------------------
/easyanimate/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = next(
17 | (
18 | obj
19 | for obj in (mean1, logvar1, mean2, logvar2)
20 | if isinstance(obj, th.Tensor)
21 | ),
22 | None,
23 | )
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a continuous Gaussian distribution.
53 | :param x: the targets
54 | :param means: the Gaussian mean Tensor.
55 | :param log_scales: the Gaussian log stddev Tensor.
56 | :return: a tensor like x of log probabilities (in nats).
57 | """
58 | centered_x = x - means
59 | inv_stdv = th.exp(-log_scales)
60 | normalized_x = centered_x * inv_stdv
61 | return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
62 | normalized_x
63 | )
64 |
65 |
66 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
67 | """
68 | Compute the log-likelihood of a Gaussian distribution discretizing to a
69 | given image.
70 | :param x: the target images. It is assumed that this was uint8 values,
71 | rescaled to the range [-1, 1].
72 | :param means: the Gaussian mean Tensor.
73 | :param log_scales: the Gaussian log stddev Tensor.
74 | :return: a tensor like x of log probabilities (in nats).
75 | """
76 | assert x.shape == means.shape == log_scales.shape
77 | centered_x = x - means
78 | inv_stdv = th.exp(-log_scales)
79 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
80 | cdf_plus = approx_standard_normal_cdf(plus_in)
81 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
82 | cdf_min = approx_standard_normal_cdf(min_in)
83 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
84 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
85 | cdf_delta = cdf_plus - cdf_min
86 | log_probs = th.where(
87 | x < -0.999,
88 | log_cdf_plus,
89 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
90 | )
91 | assert log_probs.shape == x.shape
92 | return log_probs
--------------------------------------------------------------------------------
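
Since normal_kl implements the closed-form KL divergence between two diagonal Gaussians, it can be sanity-checked against torch.distributions; a self-contained sketch with arbitrary test values:

import torch as th
from easyanimate.utils.diffusion_utils import normal_kl

mean1, logvar1 = th.zeros(4), th.zeros(4)                        # N(0, variance 1)
mean2, logvar2 = th.full((4,), 0.5), th.log(th.full((4,), 2.0))  # N(0.5, variance 2)

kl_manual = normal_kl(mean1, logvar1, mean2, logvar2)
kl_ref = th.distributions.kl_divergence(
    th.distributions.Normal(mean1, th.exp(0.5 * logvar1)),
    th.distributions.Normal(mean2, th.exp(0.5 * logvar2)),
)
print(th.allclose(kl_manual, kl_ref, atol=1e-6))  # expected: True
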
/easyanimate/utils/lora_utils.py:
--------------------------------------------------------------------------------
1 | # LoRA network module
2 | # reference:
3 | # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4 | # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5 | # https://github.com/bmaltais/kohya_ss
6 |
7 | import hashlib
8 | import math
9 | import os
10 | from collections import defaultdict
11 | from io import BytesIO
12 | from typing import List, Optional, Type, Union
13 |
14 | import safetensors.torch
15 | import torch
16 | import torch.utils.checkpoint
17 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18 | from safetensors.torch import load_file
19 | from transformers import T5EncoderModel
20 |
21 |
22 | class LoRAModule(torch.nn.Module):
23 | """
24 | replaces forward method of the original Linear, instead of replacing the original Linear module.
25 | """
26 |
27 | def __init__(
28 | self,
29 | lora_name,
30 | org_module: torch.nn.Module,
31 | multiplier=1.0,
32 | lora_dim=4,
33 | alpha=1,
34 | dropout=None,
35 | rank_dropout=None,
36 | module_dropout=None,
37 | ):
38 | """if alpha == 0 or None, alpha is rank (no scaling)."""
39 | super().__init__()
40 | self.lora_name = lora_name
41 |
42 | if org_module.__class__.__name__ == "Conv2d":
43 | in_dim = org_module.in_channels
44 | out_dim = org_module.out_channels
45 | else:
46 | in_dim = org_module.in_features
47 | out_dim = org_module.out_features
48 |
49 | self.lora_dim = lora_dim
50 | if org_module.__class__.__name__ == "Conv2d":
51 | kernel_size = org_module.kernel_size
52 | stride = org_module.stride
53 | padding = org_module.padding
54 | self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
55 | self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
56 | else:
57 | self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
58 | self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
59 |
60 | if type(alpha) == torch.Tensor:
61 | alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
62 | alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
63 | self.scale = alpha / self.lora_dim
64 | self.register_buffer("alpha", torch.tensor(alpha))
65 |
66 | # same as microsoft's
67 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
68 | torch.nn.init.zeros_(self.lora_up.weight)
69 |
70 | self.multiplier = multiplier
71 | self.org_module = org_module # remove in applying
72 | self.dropout = dropout
73 | self.rank_dropout = rank_dropout
74 | self.module_dropout = module_dropout
75 |
76 | def apply_to(self):
77 | self.org_forward = self.org_module.forward
78 | self.org_module.forward = self.forward
79 | del self.org_module
80 |
81 | def forward(self, x, *args, **kwargs):
82 | weight_dtype = x.dtype
83 | org_forwarded = self.org_forward(x)
84 |
85 | # module dropout
86 | if self.module_dropout is not None and self.training:
87 | if torch.rand(1) < self.module_dropout:
88 | return org_forwarded
89 |
90 | lx = self.lora_down(x.to(self.lora_down.weight.dtype))
91 |
92 | # normal dropout
93 | if self.dropout is not None and self.training:
94 | lx = torch.nn.functional.dropout(lx, p=self.dropout)
95 |
96 | # rank dropout
97 | if self.rank_dropout is not None and self.training:
98 | mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
99 | if len(lx.size()) == 3:
100 | mask = mask.unsqueeze(1) # for Text Encoder
101 | elif len(lx.size()) == 4:
102 | mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
103 | lx = lx * mask
104 |
105 | # scaling for rank dropout: treat as if the rank is changed
106 | scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
107 | else:
108 | scale = self.scale
109 |
110 | lx = self.lora_up(lx)
111 |
112 | return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113 |
114 |
115 | def addnet_hash_legacy(b):
116 | """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
117 | m = hashlib.sha256()
118 |
119 | b.seek(0x100000)
120 | m.update(b.read(0x10000))
121 | return m.hexdigest()[0:8]
122 |
123 |
124 | def addnet_hash_safetensors(b):
125 | """New model hash used by sd-webui-additional-networks for .safetensors format files"""
126 | hash_sha256 = hashlib.sha256()
127 | blksize = 1024 * 1024
128 |
129 | b.seek(0)
130 | header = b.read(8)
131 | n = int.from_bytes(header, "little")
132 |
133 | offset = n + 8
134 | b.seek(offset)
135 | for chunk in iter(lambda: b.read(blksize), b""):
136 | hash_sha256.update(chunk)
137 |
138 | return hash_sha256.hexdigest()
139 |
140 |
141 | def precalculate_safetensors_hashes(tensors, metadata):
142 | """Precalculate the model hashes needed by sd-webui-additional-networks to
143 | save time on indexing the model later."""
144 |
145 | # Because writing user metadata to the file can change the result of
146 | # sd_models.model_hash(), only retain the training metadata for purposes of
147 | # calculating the hash, as they are meant to be immutable
148 | metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
149 |
150 | bytes = safetensors.torch.save(tensors, metadata)
151 | b = BytesIO(bytes)
152 |
153 | model_hash = addnet_hash_safetensors(b)
154 | legacy_hash = addnet_hash_legacy(b)
155 | return model_hash, legacy_hash
156 |
157 |
158 | class LoRANetwork(torch.nn.Module):
159 | TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
160 | TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
161 | LORA_PREFIX_TRANSFORMER = "lora_unet"
162 | LORA_PREFIX_TEXT_ENCODER = "lora_te"
163 | def __init__(
164 | self,
165 | text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
166 | unet,
167 | multiplier: float = 1.0,
168 | lora_dim: int = 4,
169 | alpha: float = 1,
170 | dropout: Optional[float] = None,
171 | module_class: Type[object] = LoRAModule,
172 | add_lora_in_attn_temporal: bool = False,
173 | verbose: Optional[bool] = False,
174 | ) -> None:
175 | super().__init__()
176 | self.multiplier = multiplier
177 |
178 | self.lora_dim = lora_dim
179 | self.alpha = alpha
180 | self.dropout = dropout
181 |
182 | print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
183 | print(f"neuron dropout: p={self.dropout}")
184 |
185 | # create module instances
186 | def create_modules(
187 | is_unet: bool,
188 | root_module: torch.nn.Module,
189 | target_replace_modules: List[torch.nn.Module],
190 | ) -> List[LoRAModule]:
191 | prefix = (
192 | self.LORA_PREFIX_TRANSFORMER
193 | if is_unet
194 | else self.LORA_PREFIX_TEXT_ENCODER
195 | )
196 | loras = []
197 | skipped = []
198 | for name, module in root_module.named_modules():
199 | if module.__class__.__name__ in target_replace_modules:
200 | for child_name, child_module in module.named_modules():
201 | is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
202 | is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
203 | is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
204 |
205 | if not add_lora_in_attn_temporal:
206 | if "attn_temporal" in child_name:
207 | continue
208 |
209 | if is_linear or is_conv2d:
210 | lora_name = prefix + "." + name + "." + child_name
211 | lora_name = lora_name.replace(".", "_")
212 |
213 | dim = None
214 | alpha = None
215 |
216 | if is_linear or is_conv2d_1x1:
217 | dim = self.lora_dim
218 | alpha = self.alpha
219 |
220 | if dim is None or dim == 0:
221 | if is_linear or is_conv2d_1x1:
222 | skipped.append(lora_name)
223 | continue
224 |
225 | lora = module_class(
226 | lora_name,
227 | child_module,
228 | self.multiplier,
229 | dim,
230 | alpha,
231 | dropout=dropout,
232 | )
233 | loras.append(lora)
234 | return loras, skipped
235 |
236 | text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
237 |
238 | self.text_encoder_loras = []
239 | skipped_te = []
240 | for i, text_encoder in enumerate(text_encoders):
241 | text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
242 | self.text_encoder_loras.extend(text_encoder_loras)
243 | skipped_te += skipped
244 | print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
245 |
246 | self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
247 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
248 |
249 | # assertion
250 | names = set()
251 | for lora in self.text_encoder_loras + self.unet_loras:
252 | assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
253 | names.add(lora.lora_name)
254 |
255 | def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
256 | if apply_text_encoder:
257 | print("enable LoRA for text encoder")
258 | else:
259 | self.text_encoder_loras = []
260 |
261 | if apply_unet:
262 | print("enable LoRA for U-Net")
263 | else:
264 | self.unet_loras = []
265 |
266 | for lora in self.text_encoder_loras + self.unet_loras:
267 | lora.apply_to()
268 | self.add_module(lora.lora_name, lora)
269 |
270 | def set_multiplier(self, multiplier):
271 | self.multiplier = multiplier
272 | for lora in self.text_encoder_loras + self.unet_loras:
273 | lora.multiplier = self.multiplier
274 |
275 | def load_weights(self, file):
276 | if os.path.splitext(file)[1] == ".safetensors":
277 | from safetensors.torch import load_file
278 |
279 | weights_sd = load_file(file)
280 | else:
281 | weights_sd = torch.load(file, map_location="cpu")
282 | info = self.load_state_dict(weights_sd, False)
283 | return info
284 |
285 | def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
286 | self.requires_grad_(True)
287 | all_params = []
288 |
289 | def enumerate_params(loras):
290 | params = []
291 | for lora in loras:
292 | params.extend(lora.parameters())
293 | return params
294 |
295 | if self.text_encoder_loras:
296 | param_data = {"params": enumerate_params(self.text_encoder_loras)}
297 | if text_encoder_lr is not None:
298 | param_data["lr"] = text_encoder_lr
299 | all_params.append(param_data)
300 |
301 | if self.unet_loras:
302 | param_data = {"params": enumerate_params(self.unet_loras)}
303 | if unet_lr is not None:
304 | param_data["lr"] = unet_lr
305 | all_params.append(param_data)
306 |
307 | return all_params
308 |
309 | def enable_gradient_checkpointing(self):
310 | pass
311 |
312 | def get_trainable_params(self):
313 | return self.parameters()
314 |
315 | def save_weights(self, file, dtype, metadata):
316 | if metadata is not None and len(metadata) == 0:
317 | metadata = None
318 |
319 | state_dict = self.state_dict()
320 |
321 | if dtype is not None:
322 | for key in list(state_dict.keys()):
323 | v = state_dict[key]
324 | v = v.detach().clone().to("cpu").to(dtype)
325 | state_dict[key] = v
326 |
327 | if os.path.splitext(file)[1] == ".safetensors":
328 | from safetensors.torch import save_file
329 |
330 | # Precalculate model hashes to save time on indexing
331 | if metadata is None:
332 | metadata = {}
333 | model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
334 | metadata["sshs_model_hash"] = model_hash
335 | metadata["sshs_legacy_hash"] = legacy_hash
336 |
337 | save_file(state_dict, file, metadata)
338 | else:
339 | torch.save(state_dict, file)
340 |
341 | def create_network(
342 | multiplier: float,
343 | network_dim: Optional[int],
344 | network_alpha: Optional[float],
345 | text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
346 | transformer,
347 | neuron_dropout: Optional[float] = None,
348 | add_lora_in_attn_temporal: bool = False,
349 | **kwargs,
350 | ):
351 | if network_dim is None:
352 | network_dim = 4 # default
353 | if network_alpha is None:
354 | network_alpha = 1.0
355 |
356 | network = LoRANetwork(
357 | text_encoder,
358 | transformer,
359 | multiplier=multiplier,
360 | lora_dim=network_dim,
361 | alpha=network_alpha,
362 | dropout=neuron_dropout,
363 | add_lora_in_attn_temporal=add_lora_in_attn_temporal,
364 | verbose=True,
365 | )
366 | return network
367 |
368 | def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
369 | LORA_PREFIX_TRANSFORMER = "lora_unet"
370 | LORA_PREFIX_TEXT_ENCODER = "lora_te"
371 | if state_dict is None:
372 | state_dict = load_file(lora_path, device=device)
373 | else:
374 | state_dict = state_dict
375 | updates = defaultdict(dict)
376 | for key, value in state_dict.items():
377 | layer, elem = key.split('.', 1)
378 | updates[layer][elem] = value
379 |
380 | for layer, elems in updates.items():
381 |
382 | if "lora_te" in layer:
383 | if transformer_only:
384 | continue
385 | else:
386 | layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
387 | curr_layer = pipeline.text_encoder
388 | else:
389 | layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
390 | curr_layer = pipeline.transformer
391 |
392 | temp_name = layer_infos.pop(0)
393 | while True:  # walk down the attribute names until the target layer is resolved (exits via break)
394 | try:
395 | curr_layer = curr_layer.__getattr__(temp_name)
396 | if len(layer_infos) > 0:
397 | temp_name = layer_infos.pop(0)
398 | elif len(layer_infos) == 0:
399 | break
400 | except Exception:
401 | if len(layer_infos) == 0:
402 | print('Error loading layer')
403 | if len(temp_name) > 0:
404 | temp_name += "_" + layer_infos.pop(0)
405 | else:
406 | temp_name = layer_infos.pop(0)
407 |
408 | weight_up = elems['lora_up.weight'].to(dtype)
409 | weight_down = elems['lora_down.weight'].to(dtype)
410 | if 'alpha' in elems.keys():
411 | alpha = elems['alpha'].item() / weight_up.shape[1]
412 | else:
413 | alpha = 1.0
414 |
415 | curr_layer.weight.data = curr_layer.weight.data.to(device)
416 | if len(weight_up.shape) == 4:
417 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
418 | weight_down.squeeze(3).squeeze(2)).unsqueeze(
419 | 2).unsqueeze(3)
420 | else:
421 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
422 |
423 | return pipeline
424 |
425 | # TODO: Refactor with merge_lora.
426 | def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
427 | """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
428 | LORA_PREFIX_UNET = "lora_unet"
429 | LORA_PREFIX_TEXT_ENCODER = "lora_te"
430 | state_dict = load_file(lora_path, device=device)
431 |
432 | updates = defaultdict(dict)
433 | for key, value in state_dict.items():
434 | layer, elem = key.split('.', 1)
435 | updates[layer][elem] = value
436 |
437 | for layer, elems in updates.items():
438 |
439 | if "lora_te" in layer:
440 | layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
441 | curr_layer = pipeline.text_encoder
442 | else:
443 | layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
444 | curr_layer = pipeline.transformer
445 |
446 | temp_name = layer_infos.pop(0)
447 | while True:  # walk down the attribute names until the target layer is resolved (exits via break)
448 | try:
449 | curr_layer = curr_layer.__getattr__(temp_name)
450 | if len(layer_infos) > 0:
451 | temp_name = layer_infos.pop(0)
452 | elif len(layer_infos) == 0:
453 | break
454 | except Exception:
455 | if len(layer_infos) == 0:
456 | print('Error loading layer')
457 | if len(temp_name) > 0:
458 | temp_name += "_" + layer_infos.pop(0)
459 | else:
460 | temp_name = layer_infos.pop(0)
461 |
462 | weight_up = elems['lora_up.weight'].to(dtype)
463 | weight_down = elems['lora_down.weight'].to(dtype)
464 | if 'alpha' in elems.keys():
465 | alpha = elems['alpha'].item() / weight_up.shape[1]
466 | else:
467 | alpha = 1.0
468 |
469 | curr_layer.weight.data = curr_layer.weight.data.to(device)
470 | if len(weight_up.shape) == 4:
471 | curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
472 | weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
473 | else:
474 | curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
475 |
476 | return pipeline
--------------------------------------------------------------------------------
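
A self-contained sketch of the forward-hook mechanism LoRAModule.apply_to implements, using a plain torch.nn.Linear; the names are local to this example, and real training instead goes through create_network with a text encoder and transformer:

import torch
from easyanimate.utils.lora_utils import LoRAModule

base = torch.nn.Linear(16, 32)
lora = LoRAModule("example_lora", base, multiplier=1.0, lora_dim=4, alpha=1)
lora.apply_to()  # swaps base.forward for the LoRA-augmented forward

x = torch.randn(2, 16)
out = base(x)                                    # now routed through LoRAModule.forward
print(out.shape)                                 # torch.Size([2, 32])
print(torch.allclose(out, lora.org_forward(x)))  # True: lora_up starts zero-initialized

merge_lora and unmerge_lora take the other route: instead of hooking forward, they add or subtract multiplier * alpha * (up @ down) directly onto the pipeline's weight tensors.
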
/easyanimate/utils/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
52 | cur_idx = 0.0
53 | taken_steps = []
54 | for _ in range(section_count):
55 | taken_steps.append(start_idx + round(cur_idx))
56 | cur_idx += frac_stride
57 | all_steps += taken_steps
58 | start_idx += size
59 | return set(all_steps)
60 |
61 |
62 | class SpacedDiffusion(GaussianDiffusion):
63 | """
64 | A diffusion process which can skip steps in a base diffusion process.
65 | :param use_timesteps: a collection (sequence or set) of timesteps from the
66 | original diffusion process to retain.
67 | :param kwargs: the kwargs to create the base diffusion process.
68 | """
69 |
70 | def __init__(self, use_timesteps, **kwargs):
71 | self.use_timesteps = set(use_timesteps)
72 | self.timestep_map = []
73 | self.original_num_steps = len(kwargs["betas"])
74 |
75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
76 | last_alpha_cumprod = 1.0
77 | new_betas = []
78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
79 | if i in self.use_timesteps:
80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
81 | last_alpha_cumprod = alpha_cumprod
82 | self.timestep_map.append(i)
83 | kwargs["betas"] = np.array(new_betas)
84 | super().__init__(**kwargs)
85 |
86 | def p_mean_variance(
87 | self, model, *args, **kwargs
88 | ): # pylint: disable=signature-differs
89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
90 |
91 | def training_losses(
92 | self, model, *args, **kwargs
93 | ): # pylint: disable=signature-differs
94 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
95 |
96 | def training_losses_diffusers(
97 | self, model, *args, **kwargs
98 | ): # pylint: disable=signature-differs
99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
100 |
101 | def condition_mean(self, cond_fn, *args, **kwargs):
102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def condition_score(self, cond_fn, *args, **kwargs):
105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
106 |
107 | def _wrap_model(self, model):
108 | if isinstance(model, _WrappedModel):
109 | return model
110 | return _WrappedModel(
111 | model, self.timestep_map, self.original_num_steps
112 | )
113 |
114 | def _scale_timesteps(self, t):
115 | # Scaling is done by the wrapped model.
116 | return t
117 |
118 |
119 | class _WrappedModel:
120 | def __init__(self, model, timestep_map, original_num_steps):
121 | self.model = model
122 | self.timestep_map = timestep_map
123 | # self.rescale_timesteps = rescale_timesteps
124 | self.original_num_steps = original_num_steps
125 |
126 | def __call__(self, x, timestep, **kwargs):
127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
128 | new_ts = map_tensor[timestep]
129 | # if self.rescale_timesteps:
130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
131 | return self.model(x, timestep=new_ts, **kwargs)
--------------------------------------------------------------------------------
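
A quick illustration of space_timesteps, mirroring the examples in its own docstring (a pure function, so no model weights are needed):

from easyanimate.utils.respace import space_timesteps

print(len(space_timesteps(300, [10, 15, 20])))  # 45: 10 + 15 + 20 steps across three sections
print(sorted(space_timesteps(1000, "ddim50")))  # 50 evenly strided steps: 0, 20, 40, ..., 980
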
/easyanimate/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import imageio
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from einops import rearrange
8 | from PIL import Image
9 |
10 |
11 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True):
12 | videos = rearrange(videos, "b c t h w -> t b c h w")
13 | outputs = []
14 | for x in videos:
15 | x = torchvision.utils.make_grid(x, nrow=n_rows)
16 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
17 | if rescale:
18 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
19 | x = (x * 255).numpy().astype(np.uint8)
20 | outputs.append(Image.fromarray(x))
21 |
22 | os.makedirs(os.path.dirname(path), exist_ok=True)
23 | if imageio_backend:
24 | if path.endswith("mp4"):
25 | imageio.mimsave(path, outputs, fps=fps)
26 | else:
27 | imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
28 | else:
29 | if path.endswith("mp4"):
30 | path = path.replace('.mp4', '.gif')
31 |         outputs[0].save(path, format='GIF', append_images=outputs[1:], save_all=True, duration=int(1000 / fps), loop=0)
32 |
--------------------------------------------------------------------------------
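save_videos_grid expects a video tensor laid out as (batch, channels, frames, height, width), with values in [0, 1] (or [-1, 1] when rescale=True), as the rearrange pattern above implies. A minimal usage sketch with a dummy clip:

    # Minimal sketch: write a random 16-frame clip as a GIF through the imageio backend.
    import torch
    from easyanimate.utils.utils import save_videos_grid

    videos = torch.rand(1, 3, 16, 64, 64)   # (b, c, t, h, w), values in [0, 1]
    save_videos_grid(videos, "samples/demo/demo.gif", rescale=False, n_rows=1, fps=12)
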
/models/put models here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-EasyAnimate/9bef69d1ceda9d300613488517af6cc66cf5c360/models/put models here.txt
--------------------------------------------------------------------------------
/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import folder_paths
4 | comfy_path = os.path.dirname(folder_paths.__file__)
5 |
6 | import sys
7 | easyanimate_path=f'{comfy_path}/custom_nodes/ComfyUI-EasyAnimate'
8 | sys.path.insert(0,easyanimate_path)
9 |
10 | import torch
11 | from diffusers import (AutoencoderKL, DDIMScheduler,
12 | DPMSolverMultistepScheduler,
13 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
14 | PNDMScheduler)
15 | from omegaconf import OmegaConf
16 |
17 | from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
18 | from easyanimate.models.transformer3d import Transformer3DModel
19 | from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
20 | from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
21 | from easyanimate.utils.utils import save_videos_grid
22 | from einops import rearrange
23 |
24 | checkpoints=['None']
25 | checkpoints.extend(folder_paths.get_filename_list("checkpoints"))
26 | vaes=['None']
27 | vaes.extend(folder_paths.get_filename_list("vae"))
28 |
29 | class EasyAnimateLoader:
30 | @classmethod
31 | def INPUT_TYPES(cls):
32 | return {
33 | "required": {
34 | "pixart_path": (os.listdir(folder_paths.get_folder_paths("diffusers")[0]), {"default": "PixArt-XL-2-512x512"}),
35 | "motion_ckpt": (folder_paths.get_filename_list("checkpoints"), {"default": "easyanimate_v1_mm.safetensors"}),
36 | "sampler_name": (["Euler","Euler A","DPM++","PNDM","DDIM"],{"default":"DPM++"}),
37 | "device":(["cuda","cpu"],{"default":"cuda"}),
38 | },
39 | "optional": {
40 | "transformer_ckpt": (checkpoints, {"default": 'None'}),
41 | "lora_ckpt": (checkpoints, {"default": 'None'}),
42 | "vae_ckpt": (vaes, {"default": 'None'}),
43 | "lora_weight": ("FLOAT", {"default": 0.55, "min": 0, "max": 1, "step": 0.01}),
44 | }
45 | }
46 |
47 | RETURN_TYPES = ("EasyAnimateModel",)
48 | FUNCTION = "run"
49 | CATEGORY = "EasyAnimate"
50 |
51 | def run(self,pixart_path,motion_ckpt,sampler_name,device,transformer_ckpt='None',lora_ckpt='None',vae_ckpt='None',lora_weight=0.55):
52 | pixart_path=os.path.join(folder_paths.get_folder_paths("diffusers")[0],pixart_path)
53 | # Config and model path
54 | config_path = f"{easyanimate_path}/config/easyanimate_video_motion_module_v1.yaml"
55 | model_name = pixart_path
56 | #model_name = "models/Diffusion_Transformer/PixArt-XL-2-512x512"
57 |
58 |         # The sampler ("Euler", "Euler A", "DPM++", "PNDM" or "DDIM") comes from the
59 |         # sampler_name node input; do not hard-code it here, or the UI selection is ignored.
60 | 
61 |         # Load pretrained model weights if needed
62 | transformer_path = None
63 | if transformer_ckpt!='None':
64 | transformer_path = folder_paths.get_full_path("checkpoints", transformer_ckpt)
65 | motion_module_path = folder_paths.get_full_path("checkpoints", motion_ckpt)
66 | #motion_module_path = "models/Motion_Module/easyanimate_v1_mm.safetensors"
67 | vae_path = None
68 | if vae_ckpt!='None':
69 | vae_path = folder_paths.get_full_path("vae", vae_ckpt)
70 | lora_path = None
71 | if lora_ckpt!='None':
72 | lora_path = folder_paths.get_full_path("checkpoints", lora_ckpt)
73 |
74 | weight_dtype = torch.float16
75 | guidance_scale = 6.0
76 | seed = 43
77 | num_inference_steps = 30
78 | #lora_weight = 0.55
79 |
80 | config = OmegaConf.load(config_path)
81 |
82 | # Get Transformer
83 | transformer = Transformer3DModel.from_pretrained_2d(
84 | model_name,
85 | subfolder="transformer",
86 | transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
87 | ).to(weight_dtype)
88 |
89 | if transformer_path is not None:
90 | print(f"From checkpoint: {transformer_path}")
91 | if transformer_path.endswith("safetensors"):
92 | from safetensors.torch import load_file, safe_open
93 | state_dict = load_file(transformer_path)
94 | else:
95 | state_dict = torch.load(transformer_path, map_location="cpu")
96 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
97 |
98 | m, u = transformer.load_state_dict(state_dict, strict=False)
99 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
100 |
101 | if motion_module_path is not None:
102 | print(f"From Motion Module: {motion_module_path}")
103 | if motion_module_path.endswith("safetensors"):
104 | from safetensors.torch import load_file, safe_open
105 | state_dict = load_file(motion_module_path)
106 | else:
107 | state_dict = torch.load(motion_module_path, map_location="cpu")
108 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
109 |
110 | m, u = transformer.load_state_dict(state_dict, strict=False)
111 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}, {u}")
112 |
113 | # Get Vae
114 | if OmegaConf.to_container(config['vae_kwargs'])['enable_magvit']:
115 | Choosen_AutoencoderKL = AutoencoderKLMagvit
116 | else:
117 | Choosen_AutoencoderKL = AutoencoderKL
118 | vae = Choosen_AutoencoderKL.from_pretrained(
119 | model_name,
120 | subfolder="vae",
121 | torch_dtype=weight_dtype
122 | )
123 |
124 | if vae_path is not None:
125 | print(f"From checkpoint: {vae_path}")
126 | if vae_path.endswith("safetensors"):
127 | from safetensors.torch import load_file, safe_open
128 | state_dict = load_file(vae_path)
129 | else:
130 | state_dict = torch.load(vae_path, map_location="cpu")
131 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
132 |
133 | m, u = vae.load_state_dict(state_dict, strict=False)
134 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
135 |
136 | # Get Scheduler
137 | Choosen_Scheduler = scheduler_dict = {
138 | "Euler": EulerDiscreteScheduler,
139 | "Euler A": EulerAncestralDiscreteScheduler,
140 | "DPM++": DPMSolverMultistepScheduler,
141 | "PNDM": PNDMScheduler,
142 | "DDIM": DDIMScheduler,
143 | }[sampler_name]
144 | scheduler = Choosen_Scheduler(**OmegaConf.to_container(config['noise_scheduler_kwargs']))
145 |
146 | pipeline = EasyAnimatePipeline.from_pretrained(
147 | model_name,
148 | vae=vae,
149 | transformer=transformer,
150 | scheduler=scheduler,
151 | torch_dtype=weight_dtype
152 | )
153 | #pipeline.to(device)
154 | pipeline.enable_model_cpu_offload()
155 |
156 | pipeline.transformer.to(device)
157 | pipeline.text_encoder.to('cpu')
158 | pipeline.vae.to('cpu')
159 |
160 | if lora_path is not None:
161 | pipeline = merge_lora(pipeline, lora_path, lora_weight)
162 | return (pipeline,)
163 |
164 | class EasyAnimateRun:
165 | @classmethod
166 | def INPUT_TYPES(cls):
167 | return {
168 | "required": {
169 | "model":("EasyAnimateModel",),
170 | "prompt":("STRING",{"multiline": True, "default":"A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road."}),
171 | "negative_prompt":("STRING",{"multiline": True, "default":"Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly"}),
172 | "video_length":("INT",{"default":80}),
173 | "num_inference_steps":("INT",{"default":30}),
174 | "width":("INT",{"default":512}),
175 | "height":("INT",{"default":512}),
176 | "guidance_scale":("FLOAT",{"default":6.0}),
177 | "seed":("INT",{"default":1234}),
178 | },
179 | }
180 |
181 | RETURN_TYPES = ("IMAGE",)
182 | FUNCTION = "run"
183 | CATEGORY = "EasyAnimate"
184 |
185 | def run(self,model,prompt,negative_prompt,video_length,num_inference_steps,width,height,guidance_scale,seed):
186 | generator = torch.Generator(device='cuda').manual_seed(seed)
187 |
188 | with torch.no_grad():
189 | videos = model(
190 | prompt,
191 | video_length = video_length,
192 | negative_prompt = negative_prompt,
193 | height = height,
194 | width = width,
195 | generator = generator,
196 | guidance_scale = guidance_scale,
197 | num_inference_steps = num_inference_steps,
198 | ).videos
199 |
200 |         # Merge batch and frame dimensions so ComfyUI receives a standard (frames, h, w, c) IMAGE batch
201 |         videos = rearrange(videos, "b c t h w -> (b t) h w c")
202 |         return (videos,)
203 |
204 | NODE_CLASS_MAPPINGS = {
205 | "EasyAnimateLoader":EasyAnimateLoader,
206 | "EasyAnimateRun":EasyAnimateRun,
207 | }
--------------------------------------------------------------------------------
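For reference, EasyAnimateLoader resolves its weights through the ComfyUI folder_paths helpers used above; the sketch below shows where the defaults from INPUT_TYPES are looked up (the VAE file name is a hypothetical placeholder, and the actual directories depend on the local ComfyUI installation):

    # Sketch of the lookups performed by EasyAnimateLoader.run.
    import os
    import folder_paths

    diffusers_dir = folder_paths.get_folder_paths("diffusers")[0]
    print(os.path.join(diffusers_dir, "PixArt-XL-2-512x512"))                          # pixart_path
    print(folder_paths.get_full_path("checkpoints", "easyanimate_v1_mm.safetensors"))  # motion_ckpt
    print(folder_paths.get_full_path("vae", "my_vae.safetensors"))                     # hypothetical vae_ckpt
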
/predict_t2i.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from diffusers import (AutoencoderKL, DDIMScheduler,
5 | DPMSolverMultistepScheduler,
6 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
7 | PNDMScheduler)
8 | from omegaconf import OmegaConf
9 |
10 | from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
11 | from easyanimate.models.transformer2d import Transformer2DModel
12 | from easyanimate.pipeline.pipeline_pixart_magvit import PixArtAlphaMagvitPipeline
13 | from easyanimate.utils.lora_utils import merge_lora
14 |
15 | # Config and model path
16 | config_path = "config/easyanimate_image_normal_v1.yaml"
17 | model_name = "models/Diffusion_Transformer/PixArt-XL-2-512x512"
18 | # Choose the sampler from "Euler", "Euler A", "DPM++", "PNDM" and "DDIM"
19 | sampler_name = "DPM++"
20 | 
21 | # Load pretrained model weights if needed
22 | transformer_path = None
23 | vae_path = None
24 | lora_path = None
25 |
26 | # Other params
27 | sample_size = [512, 512]
28 | weight_dtype = torch.float16
29 | prompt = "1girl, bangs, blue eyes, blunt bangs, blurry, blurry background, bob cut, depth of field, lips, looking at viewer, motion blur, nose, realistic, red lips, shirt, short hair, solo, white shirt."
30 | negative_prompt = "bad detailed"
31 | guidance_scale = 6.0
32 | seed = 43
33 | lora_weight = 0.55
34 | save_path = "samples/easyanimate-images"
35 |
36 | config = OmegaConf.load(config_path)
37 |
38 | # Get Transformer
39 | transformer = Transformer2DModel.from_pretrained(
40 | model_name,
41 | subfolder="transformer"
42 | ).to(weight_dtype)
43 |
44 | if transformer_path is not None:
45 | print(f"From checkpoint: {transformer_path}")
46 | if transformer_path.endswith("safetensors"):
47 | from safetensors.torch import load_file, safe_open
48 | state_dict = load_file(transformer_path)
49 | else:
50 | state_dict = torch.load(transformer_path, map_location="cpu")
51 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
52 |
53 | m, u = transformer.load_state_dict(state_dict, strict=False)
54 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
55 |
56 | # Get Vae
57 | if OmegaConf.to_container(config['vae_kwargs'])['enable_magvit']:
58 | Choosen_AutoencoderKL = AutoencoderKLMagvit
59 | else:
60 | Choosen_AutoencoderKL = AutoencoderKL
61 | vae = Choosen_AutoencoderKL.from_pretrained(
62 | model_name,
63 | subfolder="vae",
64 | torch_dtype=weight_dtype
65 | )
66 |
67 | if vae_path is not None:
68 | print(f"From checkpoint: {vae_path}")
69 | if vae_path.endswith("safetensors"):
70 | from safetensors.torch import load_file, safe_open
71 | state_dict = load_file(vae_path)
72 | else:
73 | state_dict = torch.load(vae_path, map_location="cpu")
74 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
75 |
76 | m, u = vae.load_state_dict(state_dict, strict=False)
77 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
78 | assert len(u) == 0
79 |
80 | # Get Scheduler
81 | Choosen_Scheduler = scheduler_dict = {
82 | "Euler": EulerDiscreteScheduler,
83 | "Euler A": EulerAncestralDiscreteScheduler,
84 | "DPM++": DPMSolverMultistepScheduler,
85 | "PNDM": PNDMScheduler,
86 | "DDIM": DDIMScheduler,
87 | }[sampler_name]
88 | scheduler = Choosen_Scheduler(**OmegaConf.to_container(config['noise_scheduler_kwargs']))
89 |
90 | # PixArtAlphaMagvitPipeline is compatible with PixArtAlphaPipeline
91 | pipeline = PixArtAlphaMagvitPipeline.from_pretrained(
92 | model_name,
93 | vae=vae,
94 | transformer=transformer,
95 | scheduler=scheduler,
96 | torch_dtype=weight_dtype
97 | )
98 | pipeline.to("cuda")
99 | pipeline.enable_model_cpu_offload()
100 |
101 | if lora_path is not None:
102 | pipeline = merge_lora(pipeline, lora_path, lora_weight)
103 |
104 | generator = torch.Generator(device="cuda").manual_seed(seed)
105 |
106 | with torch.no_grad():
107 | sample = pipeline(
108 | prompt = prompt,
109 | negative_prompt = negative_prompt,
110 | guidance_scale = guidance_scale,
111 | height = sample_size[0],
112 | width = sample_size[1],
113 | generator = generator,
114 | ).images[0]
115 |
116 | if not os.path.exists(save_path):
117 | os.makedirs(save_path, exist_ok=True)
118 |
119 | index = len([path for path in os.listdir(save_path)]) + 1
120 | prefix = str(index).zfill(8)
121 | image_path = os.path.join(save_path, prefix + ".png")
122 | sample.save(image_path)
--------------------------------------------------------------------------------
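To sample with a LoRA trained by scripts/train_t2i_lora.sh, point lora_path near the top of this script at a checkpoint written under that run's --output_dir before executing it; the file name below is only a hypothetical placeholder, since the actual name depends on the training run:

    # Hypothetical placeholder path; substitute the checkpoint your run actually wrote.
    lora_path = "output_dir_lora/your_lora_checkpoint.safetensors"
    lora_weight = 0.55   # merge strength passed to merge_lora later in the script
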
/predict_t2v.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from diffusers import (AutoencoderKL, DDIMScheduler,
5 | DPMSolverMultistepScheduler,
6 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
7 | PNDMScheduler)
8 | from omegaconf import OmegaConf
9 |
10 | from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
11 | from easyanimate.models.transformer3d import Transformer3DModel
12 | from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
13 | from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
14 | from easyanimate.utils.utils import save_videos_grid
15 |
16 | # Config and model path
17 | config_path = "config/easyanimate_video_motion_module_v1.yaml"
18 | model_name = "models/Diffusion_Transformer/PixArt-XL-2-512x512"
19 |
20 | # Choose the sampler from "Euler", "Euler A", "DPM++", "PNDM" and "DDIM"
21 | sampler_name = "DPM++"
22 | 
23 | # Load pretrained model weights if needed
24 | transformer_path = None
25 | motion_module_path = "models/Motion_Module/easyanimate_v1_mm.safetensors"
26 | vae_path = None
27 | lora_path = None
28 |
29 | # other params
30 | sample_size = [512, 512]
31 | video_length = 80
32 | fps = 12
33 |
34 | weight_dtype = torch.float16
35 | prompt = "A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road."
36 | negative_prompt = "Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly"
37 | guidance_scale = 6.0
38 | seed = 43
39 | num_inference_steps = 30
40 | lora_weight = 0.55
41 | save_path = "samples/easyanimate-videos"
42 |
43 | config = OmegaConf.load(config_path)
44 |
45 | # Get Transformer
46 | transformer = Transformer3DModel.from_pretrained_2d(
47 | model_name,
48 | subfolder="transformer",
49 | transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
50 | ).to(weight_dtype)
51 |
52 | if transformer_path is not None:
53 | print(f"From checkpoint: {transformer_path}")
54 | if transformer_path.endswith("safetensors"):
55 | from safetensors.torch import load_file, safe_open
56 | state_dict = load_file(transformer_path)
57 | else:
58 | state_dict = torch.load(transformer_path, map_location="cpu")
59 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
60 |
61 | m, u = transformer.load_state_dict(state_dict, strict=False)
62 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
63 |
64 | if motion_module_path is not None:
65 | print(f"From Motion Module: {motion_module_path}")
66 | if motion_module_path.endswith("safetensors"):
67 | from safetensors.torch import load_file, safe_open
68 | state_dict = load_file(motion_module_path)
69 | else:
70 | state_dict = torch.load(motion_module_path, map_location="cpu")
71 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
72 |
73 | m, u = transformer.load_state_dict(state_dict, strict=False)
74 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}, {u}")
75 |
76 | # Get Vae
77 | if OmegaConf.to_container(config['vae_kwargs'])['enable_magvit']:
78 | Choosen_AutoencoderKL = AutoencoderKLMagvit
79 | else:
80 | Choosen_AutoencoderKL = AutoencoderKL
81 | vae = Choosen_AutoencoderKL.from_pretrained(
82 | model_name,
83 | subfolder="vae",
84 | torch_dtype=weight_dtype
85 | )
86 |
87 | if vae_path is not None:
88 | print(f"From checkpoint: {vae_path}")
89 | if vae_path.endswith("safetensors"):
90 | from safetensors.torch import load_file, safe_open
91 | state_dict = load_file(vae_path)
92 | else:
93 | state_dict = torch.load(vae_path, map_location="cpu")
94 | state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
95 |
96 | m, u = vae.load_state_dict(state_dict, strict=False)
97 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
98 |
99 | # Get Scheduler
100 | Choosen_Scheduler = scheduler_dict = {
101 | "Euler": EulerDiscreteScheduler,
102 | "Euler A": EulerAncestralDiscreteScheduler,
103 | "DPM++": DPMSolverMultistepScheduler,
104 | "PNDM": PNDMScheduler,
105 | "DDIM": DDIMScheduler,
106 | }[sampler_name]
107 | scheduler = Choosen_Scheduler(**OmegaConf.to_container(config['noise_scheduler_kwargs']))
108 |
109 | pipeline = EasyAnimatePipeline.from_pretrained(
110 | model_name,
111 | vae=vae,
112 | transformer=transformer,
113 | scheduler=scheduler,
114 | torch_dtype=weight_dtype
115 | )
116 | pipeline.to("cuda")
117 | pipeline.enable_model_cpu_offload()
118 |
119 | generator = torch.Generator(device="cuda").manual_seed(seed)
120 |
121 | if lora_path is not None:
122 | pipeline = merge_lora(pipeline, lora_path, lora_weight)
123 |
124 | with torch.no_grad():
125 | sample = pipeline(
126 | prompt,
127 | video_length = video_length,
128 | negative_prompt = negative_prompt,
129 | height = sample_size[0],
130 | width = sample_size[1],
131 | generator = generator,
132 | guidance_scale = guidance_scale,
133 | num_inference_steps = num_inference_steps,
134 | ).videos
135 |
136 | if not os.path.exists(save_path):
137 | os.makedirs(save_path, exist_ok=True)
138 |
139 | index = len([path for path in os.listdir(save_path)]) + 1
140 | prefix = str(index).zfill(8)
141 | video_path = os.path.join(save_path, prefix + ".gif")
142 | save_videos_grid(sample, video_path, fps=fps)
--------------------------------------------------------------------------------
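save_videos_grid (easyanimate/utils/utils.py) also writes MP4 when the path ends in "mp4". A hedged variant of the final save step above, assuming an MP4-capable imageio backend such as imageio-ffmpeg is installed:

    # Variant of the save step: write an .mp4 instead of a .gif.
    video_path = os.path.join(save_path, prefix + ".mp4")
    save_videos_grid(sample, video_path, fps=fps)
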
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow
2 | einops
3 | safetensors
4 | timm
5 | tomesd
6 | xformers
7 | decord
8 | datasets
9 | numpy
10 | scikit-image
11 | opencv-python
12 | omegaconf
13 | diffusers
14 | transformers
--------------------------------------------------------------------------------
/scripts/extra_motion_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors.torch import load_file, safe_open, save_file
3 |
4 | original_safetensor_path = 'diffusion_pytorch_model.safetensors'
5 | new_safetensor_path = 'easyanimate_v1_mm.safetensors'
6 |
7 | original_weights = load_file(original_safetensor_path)
8 | temporal_weights = {}
9 | for name, weight in original_weights.items():
10 | if 'temporal' in name:
11 | temporal_weights[name] = weight
12 | save_file(temporal_weights, new_safetensor_path, None)
13 | print(f'Saved weights containing "temporal" to {new_safetensor_path}')
--------------------------------------------------------------------------------
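A small sanity check after extraction: the new file should contain only tensors whose names include "temporal", and it should load into the transformer the same way the motion-module block in predict_t2v.py does. A minimal sketch:

    # Sanity-check sketch for the extracted motion module.
    from safetensors.torch import load_file

    state_dict = load_file('easyanimate_v1_mm.safetensors')
    assert all('temporal' in name for name in state_dict)
    print(f"{len(state_dict)} temporal tensors extracted")
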
/scripts/train_t2i.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="models/Diffusion_Transformer/PixArt-XL-2-512x512"
2 | export DATASET_NAME="datasets/internal_datasets/"
3 | export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
4 |
5 | # When training the model on multiple machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
6 | # vae_mode can be chosen from "normal" and "magvit"
7 | # transformer_mode can be chosen from "normal" and "kvcompress"
8 | accelerate launch --mixed_precision="bf16" scripts/train_t2i.py \
9 | --pretrained_model_name_or_path=$MODEL_NAME \
10 | --train_data_dir=$DATASET_NAME \
11 | --train_data_meta=$DATASET_META_NAME \
12 | --config_path "config/easyanimate_image_normal_v1.yaml" \
13 | --train_data_format="normal" \
14 | --caption_column="text" \
15 | --resolution=512 \
16 | --train_batch_size=2 \
17 | --gradient_accumulation_steps=1 \
18 | --dataloader_num_workers=8 \
19 | --num_train_epochs=50 \
20 | --checkpointing_steps=500 \
21 | --validation_prompts="1girl, bangs, blue eyes, blunt bangs, blurry, blurry background, bob cut, depth of field, lips, looking at viewer, motion blur, nose, realistic, red lips, shirt, short hair, solo, white shirt." \
22 | --validation_epochs=1 \
23 | --validation_steps=100 \
24 | --learning_rate=1e-05 \
25 | --lr_scheduler="constant_with_warmup" \
26 | --lr_warmup_steps=50 \
27 | --seed=42 \
28 | --max_grad_norm=1 \
29 | --output_dir="output_dir_t2i" \
30 | --enable_xformers_memory_efficient_attention \
31 | --gradient_checkpointing \
32 | --mixed_precision='bf16' \
33 | --use_ema \
34 | --trainable_modules "."
--------------------------------------------------------------------------------
/scripts/train_t2i_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="models/Diffusion_Transformer/PixArt-XL-2-512x512"
2 | export DATASET_NAME="datasets/internal_datasets/"
3 | export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
4 |
5 | # When training the model on multiple machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
6 | # vae_mode can be chosen from "normal" and "magvit"
7 | # transformer_mode can be chosen from "normal" and "kvcompress"
8 | accelerate launch --mixed_precision="bf16" scripts/train_t2i_lora.py \
9 | --pretrained_model_name_or_path=$MODEL_NAME \
10 | --train_data_dir=$DATASET_NAME \
11 | --train_data_meta=$DATASET_META_NAME \
12 | --config_path "config/easyanimate_image_normal_v1.yaml" \
13 | --train_data_format="normal" \
14 | --caption_column="text" \
15 | --resolution=512 \
16 | --train_text_encoder \
17 | --train_batch_size=2 \
18 | --gradient_accumulation_steps=1 \
19 | --dataloader_num_workers=8 \
20 | --max_train_steps=2500 \
21 | --checkpointing_steps=500 \
22 | --validation_prompts="1girl, bangs, blue eyes, blunt bangs, blurry, blurry background, bob cut, depth of field, lips, looking at viewer, motion blur, nose, realistic, red lips, shirt, short hair, solo, white shirt." \
23 | --validation_steps=100 \
24 | --learning_rate=1e-04 \
25 | --seed=42 \
26 | --output_dir="output_dir_lora" \
27 | --enable_xformers_memory_efficient_attention \
28 | --gradient_checkpointing \
29 | --mixed_precision='bf16'
30 |
--------------------------------------------------------------------------------
/scripts/train_t2iv.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="models/Diffusion_Transformer/PixArt-XL-2-512x512"
2 | export DATASET_NAME="datasets/internal_datasets/"
3 | export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
4 | export NCCL_IB_DISABLE=1
5 | export NCCL_P2P_DISABLE=1
6 | NCCL_DEBUG=INFO
7 |
8 | # When training the model on multiple machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
9 | # vae_mode can be chosen from "normal" and "magvit"
10 | # transformer_mode can be chosen from "normal" and "kvcompress"
11 | accelerate launch --mixed_precision="bf16" scripts/train_t2iv.py \
12 | --pretrained_model_name_or_path=$MODEL_NAME \
13 | --train_data_dir=$DATASET_NAME \
14 | --train_data_meta=$DATASET_META_NAME \
15 | --config_path "config/easyanimate_video_motion_module_v1.yaml" \
16 | --image_sample_size=512 \
17 | --video_sample_size=512 \
18 | --video_sample_stride=2 \
19 | --video_sample_n_frames=16 \
20 | --train_batch_size=2 \
21 | --video_repeat=1 \
22 | --image_repeat_in_forward=4 \
23 | --gradient_accumulation_steps=1 \
24 | --dataloader_num_workers=8 \
25 | --num_train_epochs=100 \
26 | --checkpointing_steps=500 \
27 | --validation_prompts="A girl with delicate face is smiling." \
28 | --validation_epochs=1 \
29 | --validation_steps=100 \
30 | --learning_rate=2e-05 \
31 | --lr_scheduler="constant_with_warmup" \
32 | --lr_warmup_steps=100 \
33 | --seed=42 \
34 | --output_dir="output_dir" \
35 | --enable_xformers_memory_efficient_attention \
36 | --gradient_checkpointing \
37 | --mixed_precision="bf16" \
38 | --adam_weight_decay=3e-2 \
39 | --adam_epsilon=1e-10 \
40 | --max_grad_norm=1 \
41 | --vae_mini_batch=16 \
42 | --use_ema \
43 | --trainable_modules "attn_temporal"
--------------------------------------------------------------------------------
/scripts/train_t2v.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="models/Diffusion_Transformer/PixArt-XL-2-512x512"
2 | export DATASET_NAME="datasets/internal_datasets/"
3 | export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
4 | export NCCL_IB_DISABLE=1
5 | export NCCL_P2P_DISABLE=1
6 | NCCL_DEBUG=INFO
7 |
8 | # When training the model on multiple machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
9 | accelerate launch --mixed_precision="bf16" scripts/train_t2v.py \
10 | --pretrained_model_name_or_path=$MODEL_NAME \
11 | --train_data_dir=$DATASET_NAME \
12 | --train_data_meta=$DATASET_META_NAME \
13 | --config_path "config/easyanimate_video_motion_module_v1.yaml" \
14 | --train_data_format="normal" \
15 | --train_mode="normal" \
16 | --sample_size=512 \
17 | --sample_stride=2 \
18 | --sample_n_frames=16 \
19 | --train_batch_size=2 \
20 | --gradient_accumulation_steps=1 \
21 | --dataloader_num_workers=8 \
22 | --num_train_epochs=100 \
23 | --checkpointing_steps=500 \
24 | --validation_prompts="A girl with delicate face is smiling." \
25 | --validation_epochs=1 \
26 | --validation_steps=100 \
27 | --learning_rate=2e-05 \
28 | --lr_scheduler="constant_with_warmup" \
29 | --lr_warmup_steps=100 \
30 | --seed=42 \
31 | --output_dir="output_dir" \
32 | --enable_xformers_memory_efficient_attention \
33 | --gradient_checkpointing \
34 | --mixed_precision="bf16" \
35 | --adam_weight_decay=3e-2 \
36 | --adam_epsilon=1e-10 \
37 | --max_grad_norm=1 \
38 | --vae_mini_batch=16 \
39 | --use_ema \
40 | --trainable_modules "attn_temporal"
--------------------------------------------------------------------------------
/scripts/train_t2v_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="models/Diffusion_Transformer/PixArt-XL-2-512x512"
2 | export DATASET_NAME="datasets/internal_datasets/"
3 | export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
4 | export NCCL_IB_DISABLE=1
5 | export NCCL_P2P_DISABLE=1
6 | NCCL_DEBUG=INFO
7 |
8 | # When training the model on multiple machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
9 | accelerate launch --mixed_precision="bf16" scripts/train_t2v_lora.py \
10 | --pretrained_model_name_or_path=$MODEL_NAME \
11 | --transformer_path="models/Motion_Module/easyanimate_v1_mm.safetensors" \
12 | --train_data_dir=$DATASET_NAME \
13 | --train_data_meta=$DATASET_META_NAME \
14 | --config_path "config/easyanimate_video_motion_module_v1.yaml" \
15 | --train_data_format="normal" \
16 | --train_mode="normal" \
17 | --sample_size=512 \
18 | --sample_stride=2 \
19 | --sample_n_frames=16 \
20 | --train_batch_size=2 \
21 | --gradient_accumulation_steps=1 \
22 | --dataloader_num_workers=8 \
23 | --num_train_epochs=100 \
24 | --checkpointing_steps=500 \
25 | --validation_prompts="A girl with delicate face is smiling." \
26 | --validation_epochs=1 \
27 | --validation_steps=100 \
28 | --learning_rate=2e-05 \
29 | --seed=42 \
30 | --output_dir="output_dir" \
31 | --enable_xformers_memory_efficient_attention \
32 | --gradient_checkpointing \
33 | --mixed_precision="bf16" \
34 | --adam_weight_decay=3e-2 \
35 | --adam_epsilon=1e-10 \
36 | --max_grad_norm=1 \
37 | --vae_mini_batch=16
38 |
--------------------------------------------------------------------------------
/wf.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 8,
3 | "last_link_id": 7,
4 | "nodes": [
5 | {
6 | "id": 2,
7 | "type": "EasyAnimateRun",
8 | "pos": [
9 | 459,
10 | -4
11 | ],
12 | "size": {
13 | "0": 387.611083984375,
14 | "1": 427.48992919921875
15 | },
16 | "flags": {},
17 | "order": 1,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "model",
22 | "type": "EasyAnimateModel",
23 | "link": 7
24 | }
25 | ],
26 | "outputs": [
27 | {
28 | "name": "IMAGE",
29 | "type": "IMAGE",
30 | "links": [
31 | 4
32 | ],
33 | "shape": 3,
34 | "slot_index": 0
35 | }
36 | ],
37 | "properties": {
38 | "Node name for S&R": "EasyAnimateRun"
39 | },
40 | "widgets_values": [
41 | "A serene night scene in a forested area. The first frame shows a tranquil lake reflecting the star-filled sky above. The second frame reveals a beautiful sunset, casting a warm glow over the landscape. The third frame showcases the night sky, filled with stars and a vibrant Milky Way galaxy. The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. The style of the video is naturalistic, emphasizing the beauty of the night sky and the peacefulness of the forest.",
42 | "Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly",
43 | 80,
44 | 30,
45 | 512,
46 | 512,
47 | 6,
48 | 1829,
49 | "fixed"
50 | ]
51 | },
52 | {
53 | "id": 8,
54 | "type": "EasyAnimateLoader",
55 | "pos": [
56 | -31,
57 | -3
58 | ],
59 | "size": {
60 | "0": 315,
61 | "1": 226
62 | },
63 | "flags": {},
64 | "order": 0,
65 | "mode": 0,
66 | "outputs": [
67 | {
68 | "name": "EasyAnimateModel",
69 | "type": "EasyAnimateModel",
70 | "links": [
71 | 7
72 | ],
73 | "shape": 3,
74 | "slot_index": 0
75 | }
76 | ],
77 | "properties": {
78 | "Node name for S&R": "EasyAnimateLoader"
79 | },
80 | "widgets_values": [
81 | "PixArt-XL-2-512x512",
82 | "easyanimate_v1_mm.safetensors",
83 | "DPM++",
84 | "cuda",
85 | "PixArt-Sigma-XL-2-1024-MS.pth",
86 | "None",
87 | "None",
88 | 0.55
89 | ]
90 | },
91 | {
92 | "id": 5,
93 | "type": "VHS_VideoCombine",
94 | "pos": [
95 | 1018,
96 | -13
97 | ],
98 | "size": [
99 | 315,
100 | 599
101 | ],
102 | "flags": {},
103 | "order": 2,
104 | "mode": 0,
105 | "inputs": [
106 | {
107 | "name": "images",
108 | "type": "IMAGE",
109 | "link": 4
110 | },
111 | {
112 | "name": "audio",
113 | "type": "VHS_AUDIO",
114 | "link": null
115 | },
116 | {
117 | "name": "batch_manager",
118 | "type": "VHS_BatchManager",
119 | "link": null
120 | }
121 | ],
122 | "outputs": [
123 | {
124 | "name": "Filenames",
125 | "type": "VHS_FILENAMES",
126 | "links": null,
127 | "shape": 3
128 | }
129 | ],
130 | "properties": {
131 | "Node name for S&R": "VHS_VideoCombine"
132 | },
133 | "widgets_values": {
134 | "frame_rate": 8,
135 | "loop_count": 0,
136 | "filename_prefix": "AnimateDiff",
137 | "format": "video/h264-mp4",
138 | "pix_fmt": "yuv420p",
139 | "crf": 19,
140 | "save_metadata": true,
141 | "pingpong": false,
142 | "save_output": true,
143 | "videopreview": {
144 | "hidden": false,
145 | "paused": false,
146 | "params": {
147 | "filename": "AnimateDiff_00461.mp4",
148 | "subfolder": "",
149 | "type": "output",
150 | "format": "video/h264-mp4"
151 | }
152 | }
153 | }
154 | }
155 | ],
156 | "links": [
157 | [
158 | 4,
159 | 2,
160 | 0,
161 | 5,
162 | 0,
163 | "IMAGE"
164 | ],
165 | [
166 | 7,
167 | 8,
168 | 0,
169 | 2,
170 | 0,
171 | "EasyAnimateModel"
172 | ]
173 | ],
174 | "groups": [],
175 | "config": {},
176 | "extra": {},
177 | "version": 0.4
178 | }
--------------------------------------------------------------------------------
/wf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-EasyAnimate/9bef69d1ceda9d300613488517af6cc66cf5c360/wf.png
--------------------------------------------------------------------------------