├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── __assets__
│   └── image_animation
│       ├── loop
│       │   ├── labrador
│       │   │   ├── 1.gif
│       │   │   ├── 2.gif
│       │   │   └── 3.gif
│       │   └── lighthouse
│       │       ├── 1.gif
│       │       ├── 2.gif
│       │       └── 3.gif
│       ├── magnitude
│       │   ├── bear
│       │   │   ├── 1.gif
│       │   │   ├── 2.gif
│       │   │   └── 3.gif
│       │   ├── genshin
│       │   │   ├── 1.gif
│       │   │   ├── 2.gif
│       │   │   └── 3.gif
│       │   └── labrador
│       │       ├── 1.gif
│       │       ├── 2.gif
│       │       └── 3.gif
│       ├── majic
│       │   ├── 1.gif
│       │   ├── 2.gif
│       │   └── 3.gif
│       ├── rcnz
│       │   ├── 1.gif
│       │   ├── 2.gif
│       │   └── 3.gif
│       ├── real
│       │   ├── 1.gif
│       │   ├── 2.gif
│       │   └── 3.gif
│       ├── style_transfer
│       │   ├── anya
│       │   │   ├── 1.gif
│       │   │   ├── 2.gif
│       │   │   └── 3.gif
│       │   ├── bear
│       │   │   ├── 1.gif
│       │   │   ├── 2.gif
│       │   │   └── 3.gif
│       │   └── concert
│       │       ├── 1.gif
│       │       ├── 2.gif
│       │       ├── 3.gif
│       │       ├── 4.gif
│       │       ├── 5.gif
│       │       └── 6.gif
│       └── teaser
│           └── teaser.gif
├── animatediff
│   ├── data
│   │   ├── dataset.py
│   │   └── video_transformer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── resnet.py
│   │   ├── unet.py
│   │   └── unet_blocks.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── i2v_pipeline.py
│   │   ├── pipeline_animation.py
│   │   └── validation_pipeline.py
│   └── utils
│       ├── convert_from_ckpt.py
│       ├── convert_lora_safetensor_to_diffusers.py
│       └── util.py
├── app.py
├── cog.yaml
├── download_bashscripts
│   ├── 1-RealisticVision.sh
│   ├── 2-RcnzCartoon.sh
│   └── 3-MajicMix.sh
├── environment-pt2.yaml
├── environment.yaml
├── example
│   ├── config
│   │   ├── anya.yaml
│   │   ├── base.yaml
│   │   ├── bear.yaml
│   │   ├── concert.yaml
│   │   ├── genshin.yaml
│   │   ├── harry.yaml
│   │   ├── labrador.yaml
│   │   ├── lighthouse.yaml
│   │   ├── majic_girl.yaml
│   │   └── train.yaml
│   ├── img
│   │   ├── anya.jpg
│   │   ├── bear.jpg
│   │   ├── concert.png
│   │   ├── genshin.jpg
│   │   ├── harry.png
│   │   ├── labrador.png
│   │   ├── lighthouse.jpg
│   │   └── majic_girl.jpg
│   ├── openxlab
│   │   ├── 1-realistic.yaml
│   │   └── 3-3d.yaml
│   └── replicate
│       ├── 1-realistic.yaml
│       └── 3-3d.yaml
├── inference.py
├── models
│   ├── DreamBooth_LoRA
│   │   └── Put personalized T2I checkpoints here.txt
│   ├── IP_Adapter
│   │   └── Put IP-Adapter checkpoints here.txt
│   ├── Motion_Module
│   │   └── Put motion module checkpoints here.txt
│   └── VAE
│       └── Put VAE checkpoints here.txt
├── pia.png
├── pia.yml
├── predict.py
├── pyproject.toml
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pkl
2 | *.pt
3 | *.mov
4 | *.pth
5 | *.json
6 | *.mov
7 | *.npz
8 | *.npy
9 | *.obj
10 | *.onnx
11 | *.tar
12 | *.bin
13 | cache*
14 | batch*
15 | *.jpg
16 | *.png
17 | *.mp4
18 | *.gif
19 | *.ckpt
20 | *.safetensors
21 | *.zip
22 | *.csv
23 | *.log
24 |
25 | **/__pycache__/
26 | samples/
27 | wandb/
28 | outputs/
29 | example/result
30 | models/StableDiffusion
31 | models/PIA
32 |
33 | !pia.png
34 | .DS_Store
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.3.5
5 | hooks:
6 | # Run the linter.
7 | - id: ruff
8 | args: [ --fix ]
9 | # Run the formatter.
10 | - id: ruff-format
11 | - repo: https://github.com/codespell-project/codespell
12 | rev: v2.2.1
13 | hooks:
14 | - id: codespell
15 | - repo: https://github.com/pre-commit/pre-commit-hooks
16 | rev: v4.3.0
17 | hooks:
18 | - id: trailing-whitespace
19 | - id: check-yaml
20 | - id: end-of-file-fixer
21 | - id: requirements-txt-fixer
22 | - id: fix-encoding-pragma
23 | args: ["--remove"]
24 | - id: mixed-line-ending
25 | args: ["--fix=lf"]
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CVPR 2024 | PIA: Personalized Image Animator
2 |
3 | [**PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models**](https://arxiv.org/abs/2312.13964)
4 |
5 | [Yiming Zhang*](https://github.com/ymzhang0319), [Zhening Xing*](https://github.com/LeoXing1996/), [Yanhong Zeng†](https://zengyh1900.github.io/), [Youqing Fang](https://github.com/FangYouqing), [Kai Chen†](https://chenkai.site/)
6 |
7 | (*equal contribution, †corresponding author)
8 |
9 |
10 | [arXiv](https://arxiv.org/abs/2312.13964)
11 | [Project Page](https://pi-animator.github.io)
12 | [OpenXLab Demo](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia)
13 | [Open in Colab](https://colab.research.google.com/github/camenduru/PIA-colab/blob/main/PIA_colab.ipynb)
14 | [Hugging Face Model](https://huggingface.co/Leoxing/PIA)
15 |
16 |
17 |
18 | [Replicate Demo & API](https://replicate.com/cjwbw/pia)
19 |
20 |
21 | PIA is a personalized image animation method that can generate videos with **high motion controllability** and **strong text and image alignment**.
22 |
23 | If you find our project helpful, please give it a star :star: or [cite](#bibtex) it; we would be very grateful :sparkling_heart:.
24 |
25 |
26 |
27 |
28 | ## What's New
29 | - [x] `2024/01/03` [Replicate Demo & API](https://replicate.com/cjwbw/pia) support!
30 | - [x] `2024/01/03` [Colab](https://github.com/camenduru/PIA-colab) support from [camenduru](https://github.com/camenduru)!
31 | - [x] `2023/12/28` Support `scaled_dot_product_attention` for 1024x1024 images with just 16GB of GPU memory.
32 | - [x] `2023/12/25` HuggingFace demo is available now! [🤗 Hub](https://huggingface.co/spaces/Leoxing/PIA/)
33 | - [x] `2023/12/22` Release the demo of PIA on [OpenXLab](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia) and checkpoints on [Google Drive](https://drive.google.com/file/d/1RL3Fp0Q6pMD8PbGPULYUnvjqyRQXGHwN/view?usp=drive_link) or [OpenXLab](https://openxlab.org.cn/models/detail/zhangyiming/PIA)
34 |
35 | ## Setup
36 | ### Prepare Environment
37 |
38 | Use the following commands to create a conda environment for PIA from scratch:
39 |
40 | ```
41 | conda env create -f pia.yml
42 | conda activate pia
43 | ```
44 | You can also install the dependencies on top of an existing environment: use `environment-pt2.yaml` for PyTorch 2.0.0, or `environment.yaml` if you need a lower PyTorch version (e.g. 1.13.1):
45 |
46 | ```
47 | conda env create -f environment.yaml
48 | conda activate pia
49 | ```
50 |
51 | We strongly recommend PyTorch 2.0.0, which supports `scaled_dot_product_attention` for memory-efficient image animation.
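
A quick way to confirm that the memory-efficient attention path is available in your environment (a minimal check that only assumes PyTorch is installed; it is not part of the original scripts):

```python
import torch
import torch.nn.functional as F

# PyTorch >= 2.0 exposes scaled_dot_product_attention; older versions do not,
# and PIA then falls back to the standard attention implementation.
print("PyTorch version:", torch.__version__)
print("scaled_dot_product_attention available:",
      hasattr(F, "scaled_dot_product_attention"))
```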
52 |
53 | ### Download checkpoints
54 | Download Stable Diffusion v1-5
55 |
56 | ```
57 | conda install git-lfs
58 | git lfs install
59 | git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/
60 | ```
61 |
62 | Download PIA
63 |
64 | ```
65 | git clone https://huggingface.co/Leoxing/PIA models/PIA/
66 | ```
67 |
68 | Download Personalized Models
69 |
70 | ```
71 | bash download_bashscripts/1-RealisticVision.sh
72 | bash download_bashscripts/2-RcnzCartoon.sh
73 | bash download_bashscripts/3-MajicMix.sh
74 | ```
75 |
76 |
77 | You can also download *pia.ckpt* manually from [Google Drive](https://drive.google.com/file/d/1RL3Fp0Q6pMD8PbGPULYUnvjqyRQXGHwN/view?usp=drive_link)
78 | or [HuggingFace](https://huggingface.co/Leoxing/PIA).
79 |
80 | Put checkpoints as follows:
81 | ```
82 | └── models
83 | ├── DreamBooth_LoRA
84 | │ ├── ...
85 | ├── PIA
86 | │ ├── pia.ckpt
87 | └── StableDiffusion
88 | ├── vae
89 | ├── unet
90 | └── ...
91 | ```
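
A small, optional sanity check that the checkpoints landed in the expected places (an illustrative helper, not part of the repository; the paths follow the layout above):

```python
from pathlib import Path

# Expected locations, matching the layout above.
# Adjust if you store checkpoints elsewhere.
required = [
    Path("models/PIA/pia.ckpt"),
    Path("models/StableDiffusion/vae"),
    Path("models/StableDiffusion/unet"),
    Path("models/DreamBooth_LoRA"),
]

for path in required:
    status = "ok" if path.exists() else "MISSING"
    print(f"{status:7s} {path}")
```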
92 |
93 | ## Inference
94 | ### Image Animation
95 | Image-to-video results can be obtained by:
96 | ```
97 | python inference.py --config=example/config/lighthouse.yaml
98 | python inference.py --config=example/config/harry.yaml
99 | python inference.py --config=example/config/majic_girl.yaml
100 | ```
101 | After running the commands above, you can find the results in `example/result`:
102 |
103 | Result previews (input image plus three animations per row):
104 |
105 | - Lighthouse: "lightning, lighthouse" / "sun rising, lighthouse" / "fireworks, lighthouse"
106 | - Harry: "1boy smiling" / "1boy playing the magic fire" / "1boy is waving hands"
107 | - Girl: "1girl is smiling" / "1girl is crying" / "1girl, snowing"
151 |
152 | ### Motion Magnitude
153 | You can control the motion magnitude through the parameter **magnitude**:
154 | ```sh
155 | python inference.py --config=example/config/xxx.yaml --magnitude=0 # Small Motion
156 | python inference.py --config=example/config/xxx.yaml --magnitude=1 # Moderate Motion
157 | python inference.py --config=example/config/xxx.yaml --magnitude=2 # Large Motion
158 | ```
159 | Examples:
160 |
161 | ```sh
162 | python inference.py --config=example/config/labrador.yaml
163 | python inference.py --config=example/config/bear.yaml
164 | python inference.py --config=example/config/genshin.yaml
165 | ```
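
If you want to render all three magnitude levels for one config in a single run, a small driver script like the following works (a hypothetical convenience wrapper around the documented CLI, not part of the repo):

```python
import subprocess

CONFIG = "example/config/labrador.yaml"  # any config from example/config

# Sweep the documented magnitude levels: 0 (small), 1 (moderate), 2 (large).
for magnitude in (0, 1, 2):
    subprocess.run(
        ["python", "inference.py", f"--config={CONFIG}", f"--magnitude={magnitude}"],
        check=True,
    )
```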
166 |
167 |
193 |
194 | ### Style Transfer
195 | To achieve style transfer, run the following command (*please don't forget to set the base model in xxx.yaml*):
196 |
197 | Examples:
198 |
199 | ```sh
200 | python inference.py --config example/config/concert.yaml --style_transfer
201 | python inference.py --config example/config/anya.yaml --style_transfer
202 | ```
203 |
235 |
236 | ### Loop Video
237 |
238 | You can generate a looping video by using the parameter `--loop`:
239 |
240 | ```sh
241 | python inference.py --config=example/config/xxx.yaml --loop
242 | ```
243 |
244 | Examples:
245 | ```sh
246 | python inference.py --config=example/config/lighthouse.yaml --loop
247 | python inference.py --config=example/config/labrador.yaml --loop
248 | ```
249 |
250 |
251 |
252 | Result previews (input image plus three animations per row):
253 |
254 | - Lighthouse: "lightning, lighthouse" / "sun rising, lighthouse" / "fireworks, lighthouse"
255 | - Labrador: "labrador jumping" / "labrador walking" / "labrador running"
274 |
275 |
276 |
277 |
278 | ## Training
279 |
280 | We provide a [training script](train.py) for PIA. It borrows heavily from [AnimateDiff](https://github.com/guoyww/AnimateDiff/tree/main), so please prepare the dataset and configuration files according to the [guideline](https://github.com/guoyww/AnimateDiff/blob/main/__assets__/docs/animatediff.md#steps-for-training).
281 |
282 | After preparation, you can train the model by running the following command using torchrun:
283 |
284 | ```shell
285 | torchrun --nnodes=1 --nproc_per_node=1 train.py --config example/config/train.yaml
286 | ```
287 |
288 | or with Slurm:
289 | ```shell
290 | srun --quotatype=reserved --job-name=pia --gres=gpu:8 --ntasks-per-node=8 --ntasks=8 --cpus-per-task=4 --kill-on-bad-exit=1 python train.py --config example/config/train.yaml
291 | ```
292 |
293 |
294 | ## AnimateBench
295 | We have open-sourced AnimateBench on [HuggingFace](https://huggingface.co/datasets/ymzhang319/AnimateBench), which includes the images, prompts and configs used to evaluate PIA and other image animation methods.
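
To pull the benchmark locally, one option is the `huggingface_hub` snapshot API (a sketch; only the dataset id from the link above is taken from the README, and the local directory name is arbitrary):

```python
from huggingface_hub import snapshot_download

# Downloads the AnimateBench images, prompts and config files into ./AnimateBench.
local_path = snapshot_download(
    repo_id="ymzhang319/AnimateBench",
    repo_type="dataset",
    local_dir="AnimateBench",
)
print("AnimateBench downloaded to:", local_path)
```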
296 |
297 |
298 | ## BibTeX
299 | ```
300 | @inproceedings{zhang2024pia,
301 | title={Pia: Your personalized image animator via plug-and-play modules in text-to-image models},
302 | author={Zhang, Yiming and Xing, Zhening and Zeng, Yanhong and Fang, Youqing and Chen, Kai},
303 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
304 | pages={7747--7756},
305 | year={2024}
306 | }
307 | ```
308 |
309 |
310 |
311 |
312 | ## Contact Us
313 | **Yiming Zhang**: zhangyiming@pjlab.org.cn
314 |
315 | **Zhening Xing**: xingzhening@pjlab.org.cn
316 |
317 | **Yanhong Zeng**: zengyh1900@gmail.com
318 |
319 | ## Acknowledgements
320 | The code is built upon [AnimateDiff](https://github.com/guoyww/AnimateDiff), [Tune-a-Video](https://github.com/showlab/Tune-A-Video) and [PySceneDetect](https://github.com/Breakthrough/PySceneDetect).
321 |
322 | You may also want to try other projects from our team:
323 |
324 |
325 |
326 |
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/labrador/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/labrador/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/labrador/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/lighthouse/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/lighthouse/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/loop/lighthouse/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/bear/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/bear/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/bear/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/genshin/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/genshin/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/genshin/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/labrador/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/labrador/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/magnitude/labrador/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/majic/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/majic/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/majic/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/rcnz/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/rcnz/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/rcnz/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/real/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/real/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/real/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/anya/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/anya/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/anya/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/bear/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/bear/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/bear/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/1.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/2.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/3.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/4.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/5.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/style_transfer/concert/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/6.gif
--------------------------------------------------------------------------------
/__assets__/image_animation/teaser/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/teaser/teaser.gif
--------------------------------------------------------------------------------
/animatediff/data/dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import io
3 | import os
4 | import random
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torchvision.transforms as transforms
10 | from decord import VideoReader
11 | from torch.utils.data.dataset import Dataset
12 |
13 | import animatediff.data.video_transformer as video_transforms
14 | from animatediff.utils.util import detect_edges, zero_rank_print
15 |
16 |
17 | try:
18 | from petrel_client.client import Client
19 | except ImportError as e:
20 | print(e)
21 |
22 |
23 | def get_score(video_data, cond_frame_idx, weight=[1.0, 1.0, 1.0, 1.0], use_edge=True):
24 |     """Compute per-frame difference scores against the condition frame.
25 |
26 |     Similar to get_score under utils/util.py (cf. detect_edges).
27 |
28 |     video_data is an np.ndarray of shape (f, h, w, c).
29 |     """
30 | h, w = video_data.shape[1], video_data.shape[2]
31 |
32 | cond_frame = video_data[cond_frame_idx]
33 | cond_hsv_list = list(cv2.split(cv2.cvtColor(cond_frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
34 |
35 | if use_edge:
36 | cond_frame_lum = cond_hsv_list[-1]
37 | cond_frame_edge = detect_edges(cond_frame_lum.astype(np.uint8))
38 | cond_hsv_list.append(cond_frame_edge)
39 |
40 | score_sum = []
41 |
42 | for frame_idx in range(video_data.shape[0]):
43 | frame = video_data[frame_idx]
44 | hsv_list = list(cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
45 |
46 | if use_edge:
47 | frame_img_lum = hsv_list[-1]
48 | frame_img_edge = detect_edges(lum=frame_img_lum.astype(np.uint8))
49 | hsv_list.append(frame_img_edge)
50 |
51 | hsv_diff = [np.abs(hsv_list[c] - cond_hsv_list[c]) for c in range(len(weight))]
52 | hsv_mse = [np.sum(hsv_diff[c]) * weight[c] for c in range(len(weight))]
53 | score_sum.append(sum(hsv_mse) / (h * w) / (sum(weight)))
54 |
55 | return score_sum
56 |
57 |
58 | class WebVid10M(Dataset):
59 | def __init__(
60 | self,
61 | csv_path,
62 | video_folder,
63 | sample_size=256,
64 | sample_stride=4,
65 | sample_n_frames=16,
66 | is_image=False,
67 | use_petreloss=False,
68 | conf_path=None,
69 | ):
70 | if use_petreloss:
71 | self._client = Client(conf_path=conf_path, enable_mc=True)
72 | else:
73 | self._client = None
74 | self.video_folder = video_folder
75 | self.sample_stride = sample_stride
76 | self.sample_n_frames = sample_n_frames
77 | self.is_image = is_image
78 | self.temporal_sampler = video_transforms.TemporalRandomCrop(sample_n_frames * sample_stride)
79 |
80 | zero_rank_print(f"loading annotations from {csv_path} ...")
81 | with open(csv_path, "r") as csvfile:
82 | self.dataset = list(csv.DictReader(csvfile))
83 | self.length = len(self.dataset)
84 | zero_rank_print(f"data scale: {self.length}")
85 |
86 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
87 | self.pixel_transforms = transforms.Compose(
88 | [
89 | transforms.RandomHorizontalFlip(),
90 | transforms.Resize(sample_size[0]),
91 | transforms.CenterCrop(sample_size),
92 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
93 | ]
94 | )
95 |
96 | def get_batch(self, idx):
97 | video_dict = self.dataset[idx]
98 | videoid, name, page_dir = video_dict["videoid"], video_dict["name"], video_dict["page_dir"]
99 |
100 | if self._client is not None:
101 | video_dir = os.path.join(self.video_folder, page_dir, f"{videoid}.mp4")
102 | video_bytes = self._client.Get(video_dir)
103 | video_bytes = io.BytesIO(video_bytes)
104 | # ensure not reading zero byte
105 | assert video_bytes.getbuffer().nbytes != 0
106 | video_reader = VideoReader(video_bytes)
107 | else:
108 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
109 | video_reader = VideoReader(video_dir)
110 |
111 | total_frames = len(video_reader)
112 | if not self.is_image:
113 | start_frame_ind, end_frame_ind = self.temporal_sampler(total_frames)
114 | assert end_frame_ind - start_frame_ind >= self.sample_n_frames
115 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
116 | else:
117 | frame_indice = [random.randint(0, total_frames - 1)]
118 |
119 | pixel_values_np = video_reader.get_batch(frame_indice).asnumpy()
120 | cond_frames = random.randint(0, self.sample_n_frames - 1)
121 |
122 | # f h w c -> f c h w
123 | pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
124 | pixel_values = pixel_values / 255.0
125 | del video_reader
126 |
127 | if self.is_image:
128 | pixel_values = pixel_values[0]
129 |
130 | return pixel_values, name, cond_frames, videoid
131 |
132 | def __len__(self):
133 | return self.length
134 |
135 | def __getitem__(self, idx):
136 | while True:
137 | try:
138 | video, name, cond_frames, videoid = self.get_batch(idx)
139 | break
140 |
141 | except Exception:
142 | zero_rank_print("Error loading video, retrying...")
143 | idx = random.randint(0, self.length - 1)
144 |
145 | video = self.pixel_transforms(video)
146 | video_ = video.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5
147 | video_ = video_ * 255
148 | score = get_score(video_, cond_frame_idx=cond_frames)
149 | del video_
150 | sample = {"pixel_values": video, "text": name, "score": score, "cond_frames": cond_frames, "vid": videoid}
151 | return sample
152 |
--------------------------------------------------------------------------------
/animatediff/data/video_transformer.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 |
4 | import torch
5 |
6 |
7 | def _is_tensor_video_clip(clip):
8 | if not torch.is_tensor(clip):
9 | raise TypeError("clip should be Tensor. Got %s" % type(clip))
10 |
11 | if not clip.ndimension() == 4:
12 | raise ValueError("clip should be 4D. Got %dD" % clip.dim())
13 |
14 | return True
15 |
16 |
17 | def crop(clip, i, j, h, w):
18 | """
19 | Args:
20 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
21 | """
22 | if len(clip.size()) != 4:
23 | raise ValueError("clip should be a 4D tensor")
24 | return clip[..., i : i + h, j : j + w]
25 |
26 |
27 | def resize(clip, target_size, interpolation_mode):
28 | if len(target_size) != 2:
29 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
30 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
31 |
32 |
33 | def resize_scale(clip, target_size, interpolation_mode):
34 | if len(target_size) != 2:
35 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
36 | _, _, H, W = clip.shape
37 | scale_ = target_size[0] / min(H, W)
38 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
39 |
40 |
41 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
42 | """
43 | Do spatial cropping and resizing to the video clip
44 | Args:
45 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
46 | i (int): i in (i,j) i.e coordinates of the upper left corner.
47 | j (int): j in (i,j) i.e coordinates of the upper left corner.
48 | h (int): Height of the cropped region.
49 | w (int): Width of the cropped region.
50 | size (tuple(int, int)): height and width of resized clip
51 | Returns:
52 | clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
53 | """
54 | if not _is_tensor_video_clip(clip):
55 | raise ValueError("clip should be a 4D torch.tensor")
56 | clip = crop(clip, i, j, h, w)
57 | clip = resize(clip, size, interpolation_mode)
58 | return clip
59 |
60 |
61 | def center_crop(clip, crop_size):
62 | if not _is_tensor_video_clip(clip):
63 | raise ValueError("clip should be a 4D torch.tensor")
64 | h, w = clip.size(-2), clip.size(-1)
65 | th, tw = crop_size
66 | if h < th or w < tw:
67 | raise ValueError("height and width must be no smaller than crop_size")
68 |
69 | i = int(round((h - th) / 2.0))
70 | j = int(round((w - tw) / 2.0))
71 | return crop(clip, i, j, th, tw)
72 |
73 |
74 | def random_shift_crop(clip):
75 | """
76 | Slide along the long edge, with the short edge as crop size
77 | """
78 | if not _is_tensor_video_clip(clip):
79 | raise ValueError("clip should be a 4D torch.tensor")
80 | h, w = clip.size(-2), clip.size(-1)
81 |
82 | if h <= w:
83 | # long_edge = w
84 | short_edge = h
85 | else:
86 | # long_edge = h
87 | short_edge = w
88 |
89 | th, tw = short_edge, short_edge
90 |
91 | i = torch.randint(0, h - th + 1, size=(1,)).item()
92 | j = torch.randint(0, w - tw + 1, size=(1,)).item()
93 | return crop(clip, i, j, th, tw)
94 |
95 |
96 | def to_tensor(clip):
97 | """
98 | Convert tensor data type from uint8 to float, divide value by 255.0 and
99 | permute the dimensions of clip tensor
100 | Args:
101 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
102 | Return:
103 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
104 | """
105 | _is_tensor_video_clip(clip)
106 | if not clip.dtype == torch.uint8:
107 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
108 | # return clip.float().permute(3, 0, 1, 2) / 255.0
109 | return clip.float() / 255.0
110 |
111 |
112 | def normalize(clip, mean, std, inplace=False):
113 | """
114 | Args:
115 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
116 | mean (tuple): pixel RGB mean. Size is (3)
117 | std (tuple): pixel standard deviation. Size is (3)
118 | Returns:
119 | normalized clip (torch.tensor): Size is (T, C, H, W)
120 | """
121 | if not _is_tensor_video_clip(clip):
122 | raise ValueError("clip should be a 4D torch.tensor")
123 | if not inplace:
124 | clip = clip.clone()
125 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
127 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
128 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
129 | return clip
130 |
131 |
132 | def hflip(clip):
133 | """
134 | Args:
135 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
136 | Returns:
137 | flipped clip (torch.tensor): Size is (T, C, H, W)
138 | """
139 | if not _is_tensor_video_clip(clip):
140 | raise ValueError("clip should be a 4D torch.tensor")
141 | return clip.flip(-1)
142 |
143 |
144 | class RandomCropVideo:
145 | def __init__(self, size):
146 | if isinstance(size, numbers.Number):
147 | self.size = (int(size), int(size))
148 | else:
149 | self.size = size
150 |
151 | def __call__(self, clip):
152 | """
153 | Args:
154 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
155 | Returns:
156 | torch.tensor: randomly cropped video clip.
157 | size is (T, C, OH, OW)
158 | """
159 | i, j, h, w = self.get_params(clip)
160 | return crop(clip, i, j, h, w)
161 |
162 | def get_params(self, clip):
163 | h, w = clip.shape[-2:]
164 | th, tw = self.size
165 |
166 | if h < th or w < tw:
167 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
168 |
169 | if w == tw and h == th:
170 | return 0, 0, h, w
171 |
172 | i = torch.randint(0, h - th + 1, size=(1,)).item()
173 | j = torch.randint(0, w - tw + 1, size=(1,)).item()
174 |
175 | return i, j, th, tw
176 |
177 | def __repr__(self) -> str:
178 | return f"{self.__class__.__name__}(size={self.size})"
179 |
180 |
181 | class UCFCenterCropVideo:
182 | def __init__(
183 | self,
184 | size,
185 | interpolation_mode="bilinear",
186 | ):
187 | if isinstance(size, tuple):
188 | if len(size) != 2:
189 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
190 | self.size = size
191 | else:
192 | self.size = (size, size)
193 |
194 | self.interpolation_mode = interpolation_mode
195 |
196 | def __call__(self, clip):
197 | """
198 | Args:
199 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
200 | Returns:
201 | torch.tensor: scale resized / center cropped video clip.
202 | size is (T, C, crop_size, crop_size)
203 | """
204 | clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
205 | clip_center_crop = center_crop(clip_resize, self.size)
206 | return clip_center_crop
207 |
208 | def __repr__(self) -> str:
209 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
210 |
211 |
212 | class KineticsRandomCropResizeVideo:
213 | """
214 | Slide along the long edge, with the short edge as crop size. And resize to the desired size.
215 | """
216 |
217 | def __init__(
218 | self,
219 | size,
220 | interpolation_mode="bilinear",
221 | ):
222 | if isinstance(size, tuple):
223 | if len(size) != 2:
224 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
225 | self.size = size
226 | else:
227 | self.size = (size, size)
228 |
229 | self.interpolation_mode = interpolation_mode
230 |
231 | def __call__(self, clip):
232 | clip_random_crop = random_shift_crop(clip)
233 | clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
234 | return clip_resize
235 |
236 |
237 | class CenterCropVideo:
238 | def __init__(
239 | self,
240 | size,
241 | interpolation_mode="bilinear",
242 | ):
243 | if isinstance(size, tuple):
244 | if len(size) != 2:
245 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
246 | self.size = size
247 | else:
248 | self.size = (size, size)
249 |
250 | self.interpolation_mode = interpolation_mode
251 |
252 | def __call__(self, clip):
253 | """
254 | Args:
255 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
256 | Returns:
257 | torch.tensor: center cropped video clip.
258 | size is (T, C, crop_size, crop_size)
259 | """
260 | clip_center_crop = center_crop(clip, self.size)
261 | return clip_center_crop
262 |
263 | def __repr__(self) -> str:
264 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
265 |
266 |
267 | class NormalizeVideo:
268 | """
269 | Normalize the video clip by mean subtraction and division by standard deviation
270 | Args:
271 | mean (3-tuple): pixel RGB mean
272 | std (3-tuple): pixel RGB standard deviation
273 | inplace (boolean): whether do in-place normalization
274 | """
275 |
276 | def __init__(self, mean, std, inplace=False):
277 | self.mean = mean
278 | self.std = std
279 | self.inplace = inplace
280 |
281 | def __call__(self, clip):
282 | """
283 | Args:
284 | clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
285 | """
286 | return normalize(clip, self.mean, self.std, self.inplace)
287 |
288 | def __repr__(self) -> str:
289 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
290 |
291 |
292 | class ToTensorVideo:
293 | """
294 | Convert tensor data type from uint8 to float, divide value by 255.0 and
295 | permute the dimensions of clip tensor
296 | """
297 |
298 | def __init__(self):
299 | pass
300 |
301 | def __call__(self, clip):
302 | """
303 | Args:
304 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
305 | Return:
306 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
307 | """
308 | return to_tensor(clip)
309 |
310 | def __repr__(self) -> str:
311 | return self.__class__.__name__
312 |
313 |
314 | class RandomHorizontalFlipVideo:
315 | """
316 | Flip the video clip along the horizontal direction with a given probability
317 | Args:
318 | p (float): probability of the clip being flipped. Default value is 0.5
319 | """
320 |
321 | def __init__(self, p=0.5):
322 | self.p = p
323 |
324 | def __call__(self, clip):
325 | """
326 | Args:
327 | clip (torch.tensor): Size is (T, C, H, W)
328 | Return:
329 | clip (torch.tensor): Size is (T, C, H, W)
330 | """
331 | if random.random() < self.p:
332 | clip = hflip(clip)
333 | return clip
334 |
335 | def __repr__(self) -> str:
336 | return f"{self.__class__.__name__}(p={self.p})"
337 |
338 |
339 | # ------------------------------------------------------------
340 | # --------------------- Sampling ---------------------------
341 | # ------------------------------------------------------------
342 | class TemporalRandomCrop(object):
343 | """Temporally crop the given frame indices at a random location.
344 |
345 | Args:
346 | size (int): Desired length of frames will be seen in the model.
347 | """
348 |
349 | def __init__(self, size):
350 | self.size = size
351 |
352 | def __call__(self, total_frames):
353 | rand_end = max(0, total_frames - self.size - 1)
354 | begin_index = random.randint(0, rand_end)
355 | end_index = min(begin_index + self.size, total_frames)
356 | return begin_index, end_index
357 |
358 |
359 | if __name__ == "__main__":
360 | import os
361 |
362 | import numpy as np
363 | import torchvision.io as io
364 | from torchvision import transforms
365 | from torchvision.utils import save_image
366 |
367 | vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
368 |
369 | trans = transforms.Compose(
370 | [
371 | ToTensorVideo(),
372 | RandomHorizontalFlipVideo(),
373 | UCFCenterCropVideo(512),
374 | # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
375 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
376 | ]
377 | )
378 |
379 | target_video_len = 32
380 | frame_interval = 1
381 | total_frames = len(vframes)
382 | print(total_frames)
383 |
384 | temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
385 |
386 | # Sampling video frames
387 | start_frame_ind, end_frame_ind = temporal_sample(total_frames)
388 | # print(start_frame_ind)
389 | # print(end_frame_ind)
390 | assert end_frame_ind - start_frame_ind >= target_video_len
391 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
392 |
393 | select_vframes = vframes[frame_indice]
394 |
395 | select_vframes_trans = trans(select_vframes)
396 |
397 | select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
398 |
399 | io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
400 |
401 | for i in range(target_video_len):
402 | save_image(
403 | select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
404 | )
405 |
--------------------------------------------------------------------------------
/animatediff/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/animatediff/models/__init__.py
--------------------------------------------------------------------------------
/animatediff/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from einops import rearrange, repeat
9 | from torch import nn
10 |
11 | from diffusers.configuration_utils import ConfigMixin, register_to_config
12 | from diffusers.models import ModelMixin
13 | from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
14 | from diffusers.utils import BaseOutput
15 | from diffusers.utils.import_utils import is_xformers_available
16 |
17 |
18 | @dataclass
19 | class Transformer3DModelOutput(BaseOutput):
20 | sample: torch.FloatTensor
21 |
22 |
23 | if is_xformers_available():
24 | import xformers
25 | import xformers.ops
26 | else:
27 | xformers = None
28 |
29 |
30 | class Transformer3DModel(ModelMixin, ConfigMixin):
31 | @register_to_config
32 | def __init__(
33 | self,
34 | num_attention_heads: int = 16,
35 | attention_head_dim: int = 88,
36 | in_channels: Optional[int] = None,
37 | num_layers: int = 1,
38 | dropout: float = 0.0,
39 | norm_num_groups: int = 32,
40 | cross_attention_dim: Optional[int] = None,
41 | attention_bias: bool = False,
42 | activation_fn: str = "geglu",
43 | num_embeds_ada_norm: Optional[int] = None,
44 | use_linear_projection: bool = False,
45 | only_cross_attention: bool = False,
46 | upcast_attention: bool = False,
47 | unet_use_cross_frame_attention=None,
48 | unet_use_temporal_attention=None,
49 | ):
50 | super().__init__()
51 | self.use_linear_projection = use_linear_projection
52 | self.num_attention_heads = num_attention_heads
53 | self.attention_head_dim = attention_head_dim
54 | inner_dim = num_attention_heads * attention_head_dim
55 |
56 | # Define input layers
57 | self.in_channels = in_channels
58 |
59 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
60 | if use_linear_projection:
61 | self.proj_in = nn.Linear(in_channels, inner_dim)
62 | else:
63 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
64 |
65 |         # Define transformer blocks
66 | self.transformer_blocks = nn.ModuleList(
67 | [
68 | BasicTransformerBlock(
69 | inner_dim,
70 | num_attention_heads,
71 | attention_head_dim,
72 | dropout=dropout,
73 | cross_attention_dim=cross_attention_dim,
74 | activation_fn=activation_fn,
75 | num_embeds_ada_norm=num_embeds_ada_norm,
76 | attention_bias=attention_bias,
77 | only_cross_attention=only_cross_attention,
78 | upcast_attention=upcast_attention,
79 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
80 | unet_use_temporal_attention=unet_use_temporal_attention,
81 | )
82 | for d in range(num_layers)
83 | ]
84 | )
85 |
86 |         # Define output layers
87 | if use_linear_projection:
88 |             self.proj_out = nn.Linear(inner_dim, in_channels)  # project back from inner_dim to the input channels
89 | else:
90 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
91 |
92 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
93 | # Input
94 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
95 | video_length = hidden_states.shape[2]
96 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
97 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
98 |
99 |         batch, channel, height, width = hidden_states.shape
100 | residual = hidden_states
101 |
102 | hidden_states = self.norm(hidden_states)
103 | if not self.use_linear_projection:
104 | hidden_states = self.proj_in(hidden_states)
105 | inner_dim = hidden_states.shape[1]
106 |             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
107 | else:
108 | inner_dim = hidden_states.shape[1]
109 |             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
110 | hidden_states = self.proj_in(hidden_states)
111 |
112 | # Blocks
113 | for block in self.transformer_blocks:
114 | hidden_states = block(
115 | hidden_states,
116 | encoder_hidden_states=encoder_hidden_states,
117 | timestep=timestep,
118 | video_length=video_length,
119 | )
120 |
121 | # Output
122 | if not self.use_linear_projection:
123 |             hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
124 | hidden_states = self.proj_out(hidden_states)
125 | else:
126 | hidden_states = self.proj_out(hidden_states)
127 |             hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
128 |
129 | output = hidden_states + residual
130 |
131 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
132 | if not return_dict:
133 | return (output,)
134 |
135 | return Transformer3DModelOutput(sample=output)
136 |
137 |
138 | class BasicTransformerBlock(nn.Module):
139 | def __init__(
140 | self,
141 | dim: int,
142 | num_attention_heads: int,
143 | attention_head_dim: int,
144 | dropout=0.0,
145 | cross_attention_dim: Optional[int] = None,
146 | activation_fn: str = "geglu",
147 | num_embeds_ada_norm: Optional[int] = None,
148 | attention_bias: bool = False,
149 | only_cross_attention: bool = False,
150 | upcast_attention: bool = False,
151 | unet_use_cross_frame_attention=None,
152 | unet_use_temporal_attention=None,
153 | ):
154 | super().__init__()
155 | self.only_cross_attention = only_cross_attention
156 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
157 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
158 | self.unet_use_temporal_attention = unet_use_temporal_attention
159 |
160 | # SC-Attn
161 | assert unet_use_cross_frame_attention is not None
162 | if unet_use_cross_frame_attention:
163 | self.attn1 = SparseCausalAttention(
164 | query_dim=dim,
165 | heads=num_attention_heads,
166 | dim_head=attention_head_dim,
167 | dropout=dropout,
168 | bias=attention_bias,
169 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
170 | upcast_attention=upcast_attention,
171 | )
172 | else:
173 | self.attn1 = Attention(
174 | query_dim=dim,
175 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
176 | heads=num_attention_heads,
177 | dim_head=attention_head_dim,
178 | dropout=dropout,
179 | bias=attention_bias,
180 | upcast_attention=upcast_attention,
181 | )
182 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
183 |
184 | # Cross-Attn
185 | if cross_attention_dim is not None:
186 | self.attn2 = Attention(
187 | query_dim=dim,
188 | cross_attention_dim=cross_attention_dim,
189 | heads=num_attention_heads,
190 | dim_head=attention_head_dim,
191 | dropout=dropout,
192 | bias=attention_bias,
193 | upcast_attention=upcast_attention,
194 | )
195 | else:
196 | self.attn2 = None
197 |
198 | if cross_attention_dim is not None:
199 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
200 | else:
201 | self.norm2 = None
202 |
203 | # Feed-forward
204 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
205 | self.norm3 = nn.LayerNorm(dim)
206 |
207 | # Temp-Attn
208 | assert unet_use_temporal_attention is not None
209 | if unet_use_temporal_attention:
210 | self.attn_temp = Attention(
211 | query_dim=dim,
212 | heads=num_attention_heads,
213 | dim_head=attention_head_dim,
214 | dropout=dropout,
215 | bias=attention_bias,
216 | upcast_attention=upcast_attention,
217 | )
218 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
219 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
220 |
221 | def forward(
222 | self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
223 | ):
224 | # SparseCausal-Attention
225 | norm_hidden_states = (
226 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
227 | )
228 |
229 | # if self.only_cross_attention:
230 | # hidden_states = (
231 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
232 | # )
233 | # else:
234 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
235 |
236 | if self.unet_use_cross_frame_attention:
237 | hidden_states = (
238 | self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
239 | + hidden_states
240 | )
241 | else:
242 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
243 |
244 | if self.attn2 is not None:
245 | # Cross-Attention
246 | norm_hidden_states = (
247 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
248 | )
249 | hidden_states = (
250 | self.attn2(
251 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
252 | )
253 | + hidden_states
254 | )
255 |
256 | # Feed-forward
257 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
258 |
259 | # Temporal-Attention
260 | if self.unet_use_temporal_attention:
261 | d = hidden_states.shape[1]
262 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
263 | norm_hidden_states = (
264 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
265 | )
266 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
267 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
268 |
269 | return hidden_states
270 |
271 |
272 | class CrossAttention(nn.Module):
273 | r"""
274 | A cross attention layer.
275 |
276 | Parameters:
277 | query_dim (`int`): The number of channels in the query.
278 | cross_attention_dim (`int`, *optional*):
279 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
280 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
281 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
282 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
283 | bias (`bool`, *optional*, defaults to False):
284 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
285 | """
286 |
287 | def __init__(
288 | self,
289 | query_dim: int,
290 | cross_attention_dim: Optional[int] = None,
291 | heads: int = 8,
292 | dim_head: int = 64,
293 | dropout: float = 0.0,
294 | bias=False,
295 | upcast_attention: bool = False,
296 | upcast_softmax: bool = False,
297 | added_kv_proj_dim: Optional[int] = None,
298 | norm_num_groups: Optional[int] = None,
299 | ):
300 | super().__init__()
301 | inner_dim = dim_head * heads
302 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
303 | self.upcast_attention = upcast_attention
304 | self.upcast_softmax = upcast_softmax
305 |
306 | self.scale = dim_head**-0.5
307 |
308 | self.heads = heads
309 | # for slice_size > 0 the attention score computation
310 | # is split across the batch axis to save memory
311 | # You can set slice_size with `set_attention_slice`
312 | self.sliceable_head_dim = heads
313 | self._slice_size = None
314 | self._use_memory_efficient_attention_xformers = False
315 | self.added_kv_proj_dim = added_kv_proj_dim
316 |
317 | if norm_num_groups is not None:
318 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
319 | else:
320 | self.group_norm = None
321 |
322 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
323 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
324 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
325 |
326 | if self.added_kv_proj_dim is not None:
327 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
328 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
329 |
330 | self.to_out = nn.ModuleList([])
331 | self.to_out.append(nn.Linear(inner_dim, query_dim))
332 | self.to_out.append(nn.Dropout(dropout))
333 |
334 | def reshape_heads_to_batch_dim(self, tensor):
335 | batch_size, seq_len, dim = tensor.shape
336 | head_size = self.heads
337 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
338 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
339 | return tensor
340 |
341 | def reshape_batch_dim_to_heads(self, tensor):
342 | batch_size, seq_len, dim = tensor.shape
343 | head_size = self.heads
344 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
345 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
346 | return tensor
347 |
348 | def set_attention_slice(self, slice_size):
349 | if slice_size is not None and slice_size > self.sliceable_head_dim:
350 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
351 |
352 | self._slice_size = slice_size
353 |
354 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
355 | batch_size, sequence_length, _ = hidden_states.shape
356 |
357 |         # encoder_hidden_states may be None; self-attention is used as the fallback below
358 |
359 | if self.group_norm is not None:
360 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
361 |
362 | query = self.to_q(hidden_states)
363 | dim = query.shape[-1]
364 | query = self.reshape_heads_to_batch_dim(query)
365 |
366 | if self.added_kv_proj_dim is not None:
367 | key = self.to_k(hidden_states)
368 | value = self.to_v(hidden_states)
369 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
370 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
371 |
372 | key = self.reshape_heads_to_batch_dim(key)
373 | value = self.reshape_heads_to_batch_dim(value)
374 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
375 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
376 |
377 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
378 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
379 | else:
380 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
381 | key = self.to_k(encoder_hidden_states)
382 | value = self.to_v(encoder_hidden_states)
383 |
384 | key = self.reshape_heads_to_batch_dim(key)
385 | value = self.reshape_heads_to_batch_dim(value)
386 |
387 | if attention_mask is not None:
388 | if attention_mask.shape[-1] != query.shape[1]:
389 | target_length = query.shape[1]
390 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
391 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
392 |
393 | # attention, what we cannot get enough of
394 | if self._use_memory_efficient_attention_xformers:
395 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
396 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
397 | hidden_states = hidden_states.to(query.dtype)
398 | else:
399 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
400 | hidden_states = self._attention(query, key, value, attention_mask)
401 | else:
402 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
403 |
404 | # linear proj
405 | hidden_states = self.to_out[0](hidden_states)
406 |
407 | # dropout
408 | hidden_states = self.to_out[1](hidden_states)
409 | return hidden_states
410 |
411 | def _attention(self, query, key, value, attention_mask=None):
412 | if self.upcast_attention:
413 | query = query.float()
414 | key = key.float()
415 |
416 | attention_scores = torch.baddbmm(
417 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
418 | query,
419 | key.transpose(-1, -2),
420 | beta=0,
421 | alpha=self.scale,
422 | )
423 |
424 | if attention_mask is not None:
425 | attention_scores = attention_scores + attention_mask
426 |
427 | if self.upcast_softmax:
428 | attention_scores = attention_scores.float()
429 |
430 | attention_probs = attention_scores.softmax(dim=-1)
431 |
432 | # cast back to the original dtype
433 | attention_probs = attention_probs.to(value.dtype)
434 |
435 | # compute attention output
436 | hidden_states = torch.bmm(attention_probs, value)
437 |
438 | # reshape hidden_states
439 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
440 | return hidden_states
441 |
442 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
443 | batch_size_attention = query.shape[0]
444 | hidden_states = torch.zeros(
445 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
446 | )
447 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
448 | for i in range(hidden_states.shape[0] // slice_size):
449 | start_idx = i * slice_size
450 | end_idx = (i + 1) * slice_size
451 |
452 | query_slice = query[start_idx:end_idx]
453 | key_slice = key[start_idx:end_idx]
454 |
455 | if self.upcast_attention:
456 | query_slice = query_slice.float()
457 | key_slice = key_slice.float()
458 |
459 | attn_slice = torch.baddbmm(
460 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
461 | query_slice,
462 | key_slice.transpose(-1, -2),
463 | beta=0,
464 | alpha=self.scale,
465 | )
466 |
467 | if attention_mask is not None:
468 | attn_slice = attn_slice + attention_mask[start_idx:end_idx]
469 |
470 | if self.upcast_softmax:
471 | attn_slice = attn_slice.float()
472 |
473 | attn_slice = attn_slice.softmax(dim=-1)
474 |
475 | # cast back to the original dtype
476 | attn_slice = attn_slice.to(value.dtype)
477 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
478 |
479 | hidden_states[start_idx:end_idx] = attn_slice
480 |
481 | # reshape hidden_states
482 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
483 | return hidden_states
484 |
485 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
486 | # TODO attention_mask
487 | query = query.contiguous()
488 | key = key.contiguous()
489 | value = value.contiguous()
490 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
491 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
492 | return hidden_states
493 |
494 |
495 | class SparseCausalAttention(CrossAttention):
496 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
497 | batch_size, sequence_length, _ = hidden_states.shape
498 |
499 |         # encoder_hidden_states may be None; self-attention is used as the fallback below
500 |
501 | if self.group_norm is not None:
502 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
503 |
504 | query = self.to_q(hidden_states)
505 | dim = query.shape[-1]
506 | query = self.reshape_heads_to_batch_dim(query)
507 |
508 | if self.added_kv_proj_dim is not None:
509 | raise NotImplementedError
510 |
511 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
512 | key = self.to_k(encoder_hidden_states)
513 | value = self.to_v(encoder_hidden_states)
514 |
515 | former_frame_index = torch.arange(video_length) - 1
516 | former_frame_index[0] = 0
517 |
518 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
519 | # key = torch.cat([key[:, [0] * video_length], key[:, [0] * video_length]], dim=2)
520 | key = key[:, [0] * video_length]
521 | key = rearrange(key, "b f d c -> (b f) d c")
522 |
523 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
524 | # value = torch.cat([value[:, [0] * video_length], value[:, [0] * video_length]], dim=2)
525 | # value = value[:, former_frame_index]
526 | value = rearrange(value, "b f d c -> (b f) d c")
527 |
528 | key = self.reshape_heads_to_batch_dim(key)
529 | value = self.reshape_heads_to_batch_dim(value)
530 |
531 | if attention_mask is not None:
532 | if attention_mask.shape[-1] != query.shape[1]:
533 | target_length = query.shape[1]
534 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
535 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
536 |
537 | # attention, what we cannot get enough of
538 | if self._use_memory_efficient_attention_xformers:
539 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
540 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
541 | hidden_states = hidden_states.to(query.dtype)
542 | else:
543 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
544 | hidden_states = self._attention(query, key, value, attention_mask)
545 | else:
546 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
547 |
548 | # linear proj
549 | hidden_states = self.to_out[0](hidden_states)
550 |
551 | # dropout
552 | hidden_states = self.to_out[1](hidden_states)
553 | return hidden_states
554 |
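A quick shape check for `Transformer3DModel` may help here. This is a minimal sketch, assuming the repository and its pinned `diffusers` version are importable; the hyper-parameters are illustrative, not the ones PIA ships with.

import torch
from animatediff.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,              # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)

hidden_states = torch.randn(1, 320, 16, 8, 8)   # (batch, channels, frames, height, width)
context = torch.randn(1, 77, 768)               # CLIP-style text embeddings, repeated per frame internally

out = model(hidden_states, encoder_hidden_states=context).sample
print(out.shape)                                # torch.Size([1, 320, 16, 8, 8])

The module folds the frame axis into the batch, runs ordinary 2-D attention blocks, and restores the `b c f h w` layout at the end.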
--------------------------------------------------------------------------------
/animatediff/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class Upsample3D(nn.Module):
21 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
22 | super().__init__()
23 | self.channels = channels
24 | self.out_channels = out_channels or channels
25 | self.use_conv = use_conv
26 | self.use_conv_transpose = use_conv_transpose
27 | self.name = name
28 |
29 | # conv = None
30 | if use_conv_transpose:
31 | raise NotImplementedError
32 | elif use_conv:
33 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
34 |
35 | def forward(self, hidden_states, output_size=None):
36 | assert hidden_states.shape[1] == self.channels
37 |
38 | if self.use_conv_transpose:
39 | raise NotImplementedError
40 |
41 |         # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
42 | dtype = hidden_states.dtype
43 | if dtype == torch.bfloat16:
44 | hidden_states = hidden_states.to(torch.float32)
45 |
46 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
47 | if hidden_states.shape[0] >= 64:
48 | hidden_states = hidden_states.contiguous()
49 |
50 | # if `output_size` is passed we force the interpolation output
51 | # size and do not make use of `scale_factor=2`
52 | if output_size is None:
53 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
54 | else:
55 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
56 |
57 | # If the input is bfloat16, we cast back to bfloat16
58 | if dtype == torch.bfloat16:
59 | hidden_states = hidden_states.to(dtype)
60 |
61 | # if self.use_conv:
62 | # if self.name == "conv":
63 | # hidden_states = self.conv(hidden_states)
64 | # else:
65 | # hidden_states = self.Conv2d_0(hidden_states)
66 | hidden_states = self.conv(hidden_states)
67 |
68 | return hidden_states
69 |
70 |
71 | class Downsample3D(nn.Module):
72 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
73 | super().__init__()
74 | self.channels = channels
75 | self.out_channels = out_channels or channels
76 | self.use_conv = use_conv
77 | self.padding = padding
78 | stride = 2
79 | self.name = name
80 |
81 | if use_conv:
82 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
83 | else:
84 | raise NotImplementedError
85 |
86 | def forward(self, hidden_states):
87 | assert hidden_states.shape[1] == self.channels
88 | if self.use_conv and self.padding == 0:
89 | raise NotImplementedError
90 |
91 | assert hidden_states.shape[1] == self.channels
92 | hidden_states = self.conv(hidden_states)
93 |
94 | return hidden_states
95 |
96 |
97 | class ResnetBlock3D(nn.Module):
98 | def __init__(
99 | self,
100 | *,
101 | in_channels,
102 | out_channels=None,
103 | conv_shortcut=False,
104 | dropout=0.0,
105 | temb_channels=512,
106 | groups=32,
107 | groups_out=None,
108 | pre_norm=True,
109 | eps=1e-6,
110 | non_linearity="swish",
111 | time_embedding_norm="default",
112 | output_scale_factor=1.0,
113 | use_in_shortcut=None,
114 | ):
115 | super().__init__()
116 | self.pre_norm = pre_norm
117 | self.pre_norm = True
118 | self.in_channels = in_channels
119 | out_channels = in_channels if out_channels is None else out_channels
120 | self.out_channels = out_channels
121 | self.use_conv_shortcut = conv_shortcut
122 | self.time_embedding_norm = time_embedding_norm
123 | self.output_scale_factor = output_scale_factor
124 |
125 | if groups_out is None:
126 | groups_out = groups
127 |
128 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
129 |
130 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
131 |
132 | if temb_channels is not None:
133 | if self.time_embedding_norm == "default":
134 | time_emb_proj_out_channels = out_channels
135 | elif self.time_embedding_norm == "scale_shift":
136 | time_emb_proj_out_channels = out_channels * 2
137 | else:
138 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
139 |
140 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
141 | else:
142 | self.time_emb_proj = None
143 |
144 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
145 | self.dropout = torch.nn.Dropout(dropout)
146 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
147 |
148 | if non_linearity == "swish":
149 | self.nonlinearity = lambda x: F.silu(x)
150 | elif non_linearity == "mish":
151 | self.nonlinearity = Mish()
152 | elif non_linearity == "silu":
153 | self.nonlinearity = nn.SiLU()
154 |
155 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
156 |
157 | self.conv_shortcut = None
158 | if self.use_in_shortcut:
159 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
160 |
161 | def forward(self, input_tensor, temb):
162 | hidden_states = input_tensor
163 |
164 | hidden_states = self.norm1(hidden_states)
165 | hidden_states = self.nonlinearity(hidden_states)
166 |
167 | hidden_states = self.conv1(hidden_states)
168 |
169 | if temb is not None:
170 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
171 |
172 | if temb is not None and self.time_embedding_norm == "default":
173 | hidden_states = hidden_states + temb
174 |
175 | hidden_states = self.norm2(hidden_states)
176 |
177 | if temb is not None and self.time_embedding_norm == "scale_shift":
178 | scale, shift = torch.chunk(temb, 2, dim=1)
179 | hidden_states = hidden_states * (1 + scale) + shift
180 |
181 | hidden_states = self.nonlinearity(hidden_states)
182 |
183 | hidden_states = self.dropout(hidden_states)
184 | hidden_states = self.conv2(hidden_states)
185 |
186 | if self.conv_shortcut is not None:
187 | input_tensor = self.conv_shortcut(input_tensor)
188 |
189 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
190 |
191 | return output_tensor
192 |
193 |
194 | class Mish(torch.nn.Module):
195 | def forward(self, hidden_states):
196 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
197 |
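A minimal shape sketch for the inflated building blocks above (assumes the repository is importable; sizes are illustrative): the same 2-D convolution is applied to every frame of a 5-D video tensor, so only the spatial dimensions change.

import torch
from animatediff.models.resnet import Downsample3D, InflatedConv3d, Upsample3D

x = torch.randn(2, 64, 8, 32, 32)       # (batch, channels, frames, height, width)

conv = InflatedConv3d(64, 128, kernel_size=3, padding=1)
print(conv(x).shape)                    # torch.Size([2, 128, 8, 32, 32]) - same 2-D conv on every frame

down = Downsample3D(64, use_conv=True)
print(down(x).shape)                    # torch.Size([2, 64, 8, 16, 16]) - spatial stride 2, frame count unchanged

up = Upsample3D(64, use_conv=True)
print(up(x).shape)                      # torch.Size([2, 64, 8, 64, 64]) - nearest-neighbour 2x spatial upsample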
--------------------------------------------------------------------------------
/animatediff/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .i2v_pipeline import I2VPipeline
2 | from .pipeline_animation import AnimationPipeline
3 | from .validation_pipeline import ValidationPipeline
4 |
5 |
6 | __all__ = ["I2VPipeline", "AnimationPipeline", "ValidationPipeline"]
7 |
--------------------------------------------------------------------------------
/animatediff/pipelines/pipeline_animation.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2 |
3 | import inspect
4 | from dataclasses import dataclass
5 | from typing import Callable, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from einops import rearrange
10 | from packaging import version
11 | from tqdm import tqdm
12 | from transformers import CLIPTextModel, CLIPTokenizer
13 |
14 | from diffusers.configuration_utils import FrozenDict
15 | from diffusers.models import AutoencoderKL
16 | from diffusers.pipelines import DiffusionPipeline
17 | from diffusers.schedulers import (
18 | DDIMScheduler,
19 | DPMSolverMultistepScheduler,
20 | EulerAncestralDiscreteScheduler,
21 | EulerDiscreteScheduler,
22 | LMSDiscreteScheduler,
23 | PNDMScheduler,
24 | )
25 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
26 |
27 | from ..models.unet import UNet3DConditionModel
28 |
29 |
30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31 |
32 |
33 | @dataclass
34 | class AnimationPipelineOutput(BaseOutput):
35 | videos: Union[torch.Tensor, np.ndarray]
36 |
37 |
38 | class AnimationPipeline(DiffusionPipeline):
39 | _optional_components = []
40 |
41 | def __init__(
42 | self,
43 | vae: AutoencoderKL,
44 | text_encoder: CLIPTextModel,
45 | tokenizer: CLIPTokenizer,
46 | unet: UNet3DConditionModel,
47 | scheduler: Union[
48 | DDIMScheduler,
49 | PNDMScheduler,
50 | LMSDiscreteScheduler,
51 | EulerDiscreteScheduler,
52 | EulerAncestralDiscreteScheduler,
53 | DPMSolverMultistepScheduler,
54 | ],
55 | ):
56 | super().__init__()
57 |
58 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
59 | deprecation_message = (
60 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
61 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
62 |                 "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
63 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
64 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
65 | " file"
66 | )
67 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
68 | new_config = dict(scheduler.config)
69 | new_config["steps_offset"] = 1
70 | scheduler._internal_dict = FrozenDict(new_config)
71 |
72 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
73 | deprecation_message = (
74 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
75 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
76 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
77 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
78 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
79 | )
80 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
81 | new_config = dict(scheduler.config)
82 | new_config["clip_sample"] = False
83 | scheduler._internal_dict = FrozenDict(new_config)
84 |
85 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
86 | version.parse(unet.config._diffusers_version).base_version
87 | ) < version.parse("0.9.0.dev0")
88 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
89 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
90 | deprecation_message = (
91 | "The configuration file of the unet has set the default `sample_size` to smaller than"
92 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
93 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
94 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
95 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
96 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
97 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
98 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
99 | " the `unet/config.json` file"
100 | )
101 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
102 | new_config = dict(unet.config)
103 | new_config["sample_size"] = 64
104 | unet._internal_dict = FrozenDict(new_config)
105 |
106 | self.register_modules(
107 | vae=vae,
108 | text_encoder=text_encoder,
109 | tokenizer=tokenizer,
110 | unet=unet,
111 | scheduler=scheduler,
112 | )
113 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
114 |
115 | def enable_vae_slicing(self):
116 | self.vae.enable_slicing()
117 |
118 | def disable_vae_slicing(self):
119 | self.vae.disable_slicing()
120 |
121 | def enable_sequential_cpu_offload(self, gpu_id=0):
122 | if is_accelerate_available():
123 | from accelerate import cpu_offload
124 | else:
125 | raise ImportError("Please install accelerate via `pip install accelerate`")
126 |
127 | device = torch.device(f"cuda:{gpu_id}")
128 |
129 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
130 | if cpu_offloaded_model is not None:
131 | cpu_offload(cpu_offloaded_model, device)
132 |
133 | @property
134 | def _execution_device(self):
135 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
136 | return self.device
137 | for module in self.unet.modules():
138 | if (
139 | hasattr(module, "_hf_hook")
140 | and hasattr(module._hf_hook, "execution_device")
141 | and module._hf_hook.execution_device is not None
142 | ):
143 | return torch.device(module._hf_hook.execution_device)
144 | return self.device
145 |
146 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
147 | batch_size = len(prompt) if isinstance(prompt, list) else 1
148 |
149 | text_inputs = self.tokenizer(
150 | prompt,
151 | padding="max_length",
152 | max_length=self.tokenizer.model_max_length,
153 | truncation=True,
154 | return_tensors="pt",
155 | )
156 | text_input_ids = text_inputs.input_ids
157 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
158 |
159 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
160 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
161 | logger.warning(
162 | "The following part of your input was truncated because CLIP can only handle sequences up to"
163 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
164 | )
165 |
166 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
167 | attention_mask = text_inputs.attention_mask.to(device)
168 | else:
169 | attention_mask = None
170 |
171 | text_embeddings = self.text_encoder(
172 | text_input_ids.to(device),
173 | attention_mask=attention_mask,
174 | )
175 | text_embeddings = text_embeddings[0]
176 |
177 | # duplicate text embeddings for each generation per prompt, using mps friendly method
178 | bs_embed, seq_len, _ = text_embeddings.shape
179 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
180 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
181 |
182 | # get unconditional embeddings for classifier free guidance
183 | if do_classifier_free_guidance:
184 | uncond_tokens: List[str]
185 | if negative_prompt is None:
186 | uncond_tokens = [""] * batch_size
187 | elif type(prompt) is not type(negative_prompt):
188 | raise TypeError(
189 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
190 | f" {type(prompt)}."
191 | )
192 | elif isinstance(negative_prompt, str):
193 | uncond_tokens = [negative_prompt]
194 | elif batch_size != len(negative_prompt):
195 | raise ValueError(
196 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
197 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
198 | " the batch size of `prompt`."
199 | )
200 | else:
201 | uncond_tokens = negative_prompt
202 |
203 | max_length = text_input_ids.shape[-1]
204 | uncond_input = self.tokenizer(
205 | uncond_tokens,
206 | padding="max_length",
207 | max_length=max_length,
208 | truncation=True,
209 | return_tensors="pt",
210 | )
211 |
212 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
213 | attention_mask = uncond_input.attention_mask.to(device)
214 | else:
215 | attention_mask = None
216 |
217 | uncond_embeddings = self.text_encoder(
218 | uncond_input.input_ids.to(device),
219 | attention_mask=attention_mask,
220 | )
221 | uncond_embeddings = uncond_embeddings[0]
222 |
223 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
224 | seq_len = uncond_embeddings.shape[1]
225 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
226 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
227 |
228 | # For classifier free guidance, we need to do two forward passes.
229 | # Here we concatenate the unconditional and text embeddings into a single batch
230 | # to avoid doing two forward passes
231 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
232 |
233 | return text_embeddings
234 |
235 | def decode_latents(self, latents):
236 | video_length = latents.shape[2]
237 | latents = 1 / 0.18215 * latents
238 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
239 | # video = self.vae.decode(latents).sample
240 | video = []
241 | for frame_idx in tqdm(range(latents.shape[0])):
242 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
243 | video = torch.cat(video)
244 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
245 | video = (video / 2 + 0.5).clamp(0, 1)
246 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
247 | video = video.cpu().float().numpy()
248 | return video
249 |
250 | def prepare_extra_step_kwargs(self, generator, eta):
251 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
252 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
253 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
254 | # and should be between [0, 1]
255 |
256 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
257 | extra_step_kwargs = {}
258 | if accepts_eta:
259 | extra_step_kwargs["eta"] = eta
260 |
261 | # check if the scheduler accepts generator
262 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
263 | if accepts_generator:
264 | extra_step_kwargs["generator"] = generator
265 | return extra_step_kwargs
266 |
267 | def check_inputs(self, prompt, height, width, callback_steps):
268 | if not isinstance(prompt, str) and not isinstance(prompt, list):
269 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
270 |
271 | if height % 8 != 0 or width % 8 != 0:
272 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
273 |
274 | if (callback_steps is None) or (
275 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
276 | ):
277 | raise ValueError(
278 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
279 | f" {type(callback_steps)}."
280 | )
281 |
282 | def prepare_latents(
283 | self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
284 | ):
285 | shape = (
286 | batch_size,
287 | num_channels_latents,
288 | video_length,
289 | height // self.vae_scale_factor,
290 | width // self.vae_scale_factor,
291 | )
292 | if isinstance(generator, list) and len(generator) != batch_size:
293 | raise ValueError(
294 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
295 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
296 | )
297 | if latents is None:
298 | rand_device = "cpu" if device.type == "mps" else device
299 |
300 | if isinstance(generator, list):
301 |                 # one full-shape latent is drawn per generator and the results are concatenated along the batch dim
302 | # shape = (1,) + shape[1:]
303 | latents = [
304 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
305 | for i in range(batch_size)
306 | ]
307 | latents = torch.cat(latents, dim=0).to(device)
308 | else:
309 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
310 | else:
311 | if latents.shape != shape:
312 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
313 | latents = latents.to(device)
314 |
315 | # scale the initial noise by the standard deviation required by the scheduler
316 | latents = latents * self.scheduler.init_noise_sigma
317 | return latents
318 |
319 | @torch.no_grad()
320 | def __call__(
321 | self,
322 | prompt: Union[str, List[str]],
323 | video_length: Optional[int],
324 | height: Optional[int] = None,
325 | width: Optional[int] = None,
326 | num_inference_steps: int = 50,
327 | guidance_scale: float = 7.5,
328 | negative_prompt: Optional[Union[str, List[str]]] = None,
329 | num_videos_per_prompt: Optional[int] = 1,
330 | eta: float = 0.0,
331 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
332 | latents: Optional[torch.FloatTensor] = None,
333 | output_type: Optional[str] = "tensor",
334 | return_dict: bool = True,
335 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
336 | callback_steps: Optional[int] = 1,
337 | **kwargs,
338 | ):
339 | # Default height and width to unet
340 | height = height or self.unet.config.sample_size * self.vae_scale_factor
341 | width = width or self.unet.config.sample_size * self.vae_scale_factor
342 |
343 | # Check inputs. Raise error if not correct
344 | self.check_inputs(prompt, height, width, callback_steps)
345 |
346 | # Define call parameters
347 | # batch_size = 1 if isinstance(prompt, str) else len(prompt)
348 | batch_size = 1
349 | if latents is not None:
350 | batch_size = latents.shape[0]
351 | if isinstance(prompt, list):
352 | batch_size = len(prompt)
353 |
354 | device = self._execution_device
355 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
356 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
357 | # corresponds to doing no classifier free guidance.
358 | do_classifier_free_guidance = guidance_scale > 1.0
359 |
360 | # Encode input prompt
361 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
362 | if negative_prompt is not None:
363 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
364 | text_embeddings = self._encode_prompt(
365 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
366 | )
367 |
368 | # Prepare timesteps
369 | self.scheduler.set_timesteps(num_inference_steps, device=device)
370 | timesteps = self.scheduler.timesteps
371 |
372 | # Prepare latent variables
373 | num_channels_latents = self.unet.in_channels
374 | latents = self.prepare_latents(
375 | batch_size * num_videos_per_prompt,
376 | num_channels_latents,
377 | video_length,
378 | height,
379 | width,
380 | text_embeddings.dtype,
381 | device,
382 | generator,
383 | latents,
384 | )
385 | latents_dtype = latents.dtype
386 |
387 | # Prepare extra step kwargs.
388 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
389 |
390 | # Denoising loop
391 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
392 | with self.progress_bar(total=num_inference_steps) as progress_bar:
393 | for i, t in enumerate(timesteps):
394 | # expand the latents if we are doing classifier free guidance
395 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
396 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
397 |
398 | # predict the noise residual
399 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(
400 | dtype=latents_dtype
401 | )
402 |
403 | # perform guidance
404 | if do_classifier_free_guidance:
405 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
406 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
407 |
408 | # compute the previous noisy sample x_t -> x_t-1
409 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
410 |
411 | # call the callback, if provided
412 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
413 | progress_bar.update()
414 | if callback is not None and i % callback_steps == 0:
415 | callback(i, t, latents)
416 |
417 | # Post-processing
418 | video = self.decode_latents(latents)
419 |
420 | # Convert to tensor
421 | if output_type == "tensor":
422 | video = torch.from_numpy(video)
423 |
424 | if not return_dict:
425 | return video
426 |
427 | return AnimationPipelineOutput(videos=video)
428 |
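The guidance step inside the denoising loop above is a weighted extrapolation between the unconditional and text-conditioned noise predictions. Below is a self-contained sketch with random tensors standing in for the UNet output (shapes are illustrative):

import torch

guidance_scale = 7.5
latents = torch.randn(1, 4, 16, 64, 64)             # (batch, latent channels, frames, h/8, w/8)

# the pipeline runs the UNet on a doubled batch: [unconditional, text-conditioned]
latent_model_input = torch.cat([latents] * 2)
noise_pred = torch.randn_like(latent_model_input)   # stand-in for self.unet(...).sample

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)                             # torch.Size([1, 4, 16, 64, 64])

With `guidance_scale = 1` the correction term vanishes, which is why the pipeline only enables classifier-free guidance when `guidance_scale > 1.0`.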
--------------------------------------------------------------------------------
/animatediff/pipelines/validation_pipeline.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2 | import inspect
3 | from dataclasses import dataclass
4 | from typing import Callable, List, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | from einops import rearrange
9 | from packaging import version
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from transformers import CLIPTextModel, CLIPTokenizer
13 |
14 | from animatediff.models.unet import UNet3DConditionModel
15 | from animatediff.utils.util import prepare_mask_coef
16 | from diffusers.configuration_utils import FrozenDict
17 | from diffusers.models import AutoencoderKL
18 | from diffusers.pipelines import DiffusionPipeline
19 | from diffusers.schedulers import (
20 | DDIMScheduler,
21 | DPMSolverMultistepScheduler,
22 | EulerAncestralDiscreteScheduler,
23 | EulerDiscreteScheduler,
24 | LMSDiscreteScheduler,
25 | PNDMScheduler,
26 | )
27 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
28 |
29 |
30 | PIL_INTERPOLATION = {
31 | "linear": Image.Resampling.BILINEAR,
32 | "bilinear": Image.Resampling.BILINEAR,
33 | "bicubic": Image.Resampling.BICUBIC,
34 | "lanczos": Image.Resampling.LANCZOS,
35 | "nearest": Image.Resampling.NEAREST,
36 | }
37 |
38 |
39 | def preprocess_image(image):
40 | if isinstance(image, torch.Tensor):
41 | return image
42 | elif isinstance(image, Image.Image):
43 | image = [image]
44 |
45 | if isinstance(image[0], Image.Image):
46 | w, h = image[0].size
47 | w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
48 |
49 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
50 | image = np.concatenate(image, axis=0)
51 | if len(image.shape) == 3:
52 | image = image.reshape(image.shape[0], image.shape[1], image.shape[2], 1)
53 | image = np.array(image).astype(np.float32) / 255.0
54 | image = image.transpose(0, 3, 1, 2)
55 | image = 2.0 * image - 1.0
56 | image = torch.from_numpy(image)
57 | elif isinstance(image[0], torch.Tensor):
58 | image = torch.cat(image, dim=0)
59 | return image
60 |
61 |
62 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63 |
64 |
65 | @dataclass
66 | class AnimationPipelineOutput(BaseOutput):
67 | videos: Union[torch.Tensor, np.ndarray]
68 |
69 |
70 | class ValidationPipeline(DiffusionPipeline):
71 | _optional_components = []
72 |
73 | def __init__(
74 | self,
75 | vae: AutoencoderKL,
76 | text_encoder: CLIPTextModel,
77 | tokenizer: CLIPTokenizer,
78 | unet: UNet3DConditionModel,
79 | scheduler: Union[
80 | DDIMScheduler,
81 | PNDMScheduler,
82 | LMSDiscreteScheduler,
83 | EulerDiscreteScheduler,
84 | EulerAncestralDiscreteScheduler,
85 | DPMSolverMultistepScheduler,
86 | ],
87 | ):
88 | super().__init__()
89 |
90 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
91 | deprecation_message = (
92 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
93 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
94 |                 "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
95 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
96 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
97 | " file"
98 | )
99 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
100 | new_config = dict(scheduler.config)
101 | new_config["steps_offset"] = 1
102 | scheduler._internal_dict = FrozenDict(new_config)
103 |
104 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
105 | deprecation_message = (
106 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
107 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
108 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
109 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
110 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
111 | )
112 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
113 | new_config = dict(scheduler.config)
114 | new_config["clip_sample"] = False
115 | scheduler._internal_dict = FrozenDict(new_config)
116 |
117 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
118 | version.parse(unet.config._diffusers_version).base_version
119 | ) < version.parse("0.9.0.dev0")
120 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
121 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
122 | deprecation_message = (
123 | "The configuration file of the unet has set the default `sample_size` to smaller than"
124 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
125 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
126 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
127 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
128 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
129 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
130 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
131 | " the `unet/config.json` file"
132 | )
133 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
134 | new_config = dict(unet.config)
135 | new_config["sample_size"] = 64
136 | unet._internal_dict = FrozenDict(new_config)
137 |
138 | self.register_modules(
139 | vae=vae,
140 | text_encoder=text_encoder,
141 | tokenizer=tokenizer,
142 | unet=unet,
143 | scheduler=scheduler,
144 | )
145 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
146 |
147 | def enable_vae_slicing(self):
148 | self.vae.enable_slicing()
149 |
150 | def disable_vae_slicing(self):
151 | self.vae.disable_slicing()
152 |
153 | def enable_sequential_cpu_offload(self, gpu_id=0):
154 | if is_accelerate_available():
155 | from accelerate import cpu_offload
156 | else:
157 | raise ImportError("Please install accelerate via `pip install accelerate`")
158 |
159 | device = torch.device(f"cuda:{gpu_id}")
160 |
161 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
162 | if cpu_offloaded_model is not None:
163 | cpu_offload(cpu_offloaded_model, device)
164 |
165 | @property
166 | def _execution_device(self):
167 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
168 | return self.device
169 | for module in self.unet.modules():
170 | if (
171 | hasattr(module, "_hf_hook")
172 | and hasattr(module._hf_hook, "execution_device")
173 | and module._hf_hook.execution_device is not None
174 | ):
175 | return torch.device(module._hf_hook.execution_device)
176 | return self.device
177 |
178 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
179 | batch_size = len(prompt) if isinstance(prompt, list) else 1
180 |
181 | text_inputs = self.tokenizer(
182 | prompt,
183 | padding="max_length",
184 | max_length=self.tokenizer.model_max_length,
185 | truncation=True,
186 | return_tensors="pt",
187 | )
188 | text_input_ids = text_inputs.input_ids
189 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
190 |
191 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
192 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
193 | logger.warning(
194 | "The following part of your input was truncated because CLIP can only handle sequences up to"
195 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
196 | )
197 |
198 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
199 | attention_mask = text_inputs.attention_mask.to(device)
200 | else:
201 | attention_mask = None
202 |
203 | text_embeddings = self.text_encoder(
204 | text_input_ids.to(device),
205 | attention_mask=attention_mask,
206 | )
207 | text_embeddings = text_embeddings[0]
208 |
209 | # duplicate text embeddings for each generation per prompt, using mps friendly method
210 | bs_embed, seq_len, _ = text_embeddings.shape
211 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
212 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
213 |
214 | # get unconditional embeddings for classifier free guidance
215 | if do_classifier_free_guidance:
216 | uncond_tokens: List[str]
217 | if negative_prompt is None:
218 | uncond_tokens = [""] * batch_size
219 | elif type(prompt) is not type(negative_prompt):
220 | raise TypeError(
221 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
222 | f" {type(prompt)}."
223 | )
224 | elif isinstance(negative_prompt, str):
225 | uncond_tokens = [negative_prompt]
226 | elif batch_size != len(negative_prompt):
227 | raise ValueError(
228 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
229 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
230 | " the batch size of `prompt`."
231 | )
232 | else:
233 | uncond_tokens = negative_prompt
234 |
235 | max_length = text_input_ids.shape[-1]
236 | uncond_input = self.tokenizer(
237 | uncond_tokens,
238 | padding="max_length",
239 | max_length=max_length,
240 | truncation=True,
241 | return_tensors="pt",
242 | )
243 |
244 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
245 | attention_mask = uncond_input.attention_mask.to(device)
246 | else:
247 | attention_mask = None
248 |
249 | uncond_embeddings = self.text_encoder(
250 | uncond_input.input_ids.to(device),
251 | attention_mask=attention_mask,
252 | )
253 | uncond_embeddings = uncond_embeddings[0]
254 |
255 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
256 | seq_len = uncond_embeddings.shape[1]
257 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
258 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
259 |
260 | # For classifier free guidance, we need to do two forward passes.
261 | # Here we concatenate the unconditional and text embeddings into a single batch
262 | # to avoid doing two forward passes
263 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
264 |
265 | return text_embeddings
266 |
267 | def decode_latents(self, latents):
268 | video_length = latents.shape[2]
269 | latents = 1 / 0.18215 * latents
270 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
271 | # video = self.vae.decode(latents).sample
272 | video = []
273 | for frame_idx in tqdm(range(latents.shape[0])):
274 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
275 | video = torch.cat(video)
276 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
277 | video = (video / 2 + 0.5).clamp(0, 1)
278 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
279 | video = video.cpu().float().numpy()
280 | return video
281 |
282 | def prepare_extra_step_kwargs(self, generator, eta):
283 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
284 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
285 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
286 | # and should be between [0, 1]
287 |
288 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
289 | extra_step_kwargs = {}
290 | if accepts_eta:
291 | extra_step_kwargs["eta"] = eta
292 |
293 | # check if the scheduler accepts generator
294 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
295 | if accepts_generator:
296 | extra_step_kwargs["generator"] = generator
297 | return extra_step_kwargs
298 |
299 | def check_inputs(self, prompt, height, width, callback_steps):
300 | if not isinstance(prompt, str) and not isinstance(prompt, list):
301 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
302 |
303 | if height % 8 != 0 or width % 8 != 0:
304 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
305 |
306 | if (callback_steps is None) or (
307 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
308 | ):
309 | raise ValueError(
310 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
311 | f" {type(callback_steps)}."
312 | )
313 |
314 | def prepare_latents(
315 | self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
316 | ):
317 | shape = (
318 | batch_size,
319 | num_channels_latents,
320 | video_length,
321 | height // self.vae_scale_factor,
322 | width // self.vae_scale_factor,
323 | )
324 |
325 | if isinstance(generator, list) and len(generator) != batch_size:
326 | raise ValueError(
327 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
328 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
329 | )
330 | if latents is None:
331 | rand_device = "cpu" if device.type == "mps" else device
332 |
333 |             if isinstance(generator, list):
334 |                 # draw one latent per generator, then concatenate along the batch dimension
335 |                 shape = (1,) + shape[1:]
336 | latents = [
337 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
338 | for i in range(batch_size)
339 | ]
340 | latents = torch.cat(latents, dim=0).to(device)
341 | else:
342 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
343 | else:
344 | if latents.shape != shape:
345 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
346 | latents = latents.to(device)
347 |
348 | # scale the initial noise by the standard deviation required by the scheduler
349 | latents = latents * self.scheduler.init_noise_sigma
350 | return latents
351 |
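The `init_noise_sigma` scaling above is scheduler-dependent: DDIM leaves the latents at unit scale, while sigma-based schedulers rescale them substantially. A minimal check, assuming only that diffusers is installed:

from diffusers import DDIMScheduler, EulerDiscreteScheduler

for cls in (DDIMScheduler, EulerDiscreteScheduler):
    sched = cls(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="linear")
    sched.set_timesteps(25)
    # DDIM reports 1.0; EulerDiscrete reports the (much larger) maximum sigma
    print(cls.__name__, float(sched.init_noise_sigma))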
352 | @torch.no_grad()
353 | def __call__(
354 | self,
355 | prompt: Union[str, List[str]],
356 | use_image: bool,
357 | video_length: Optional[int],
358 | height: Optional[int] = None,
359 | width: Optional[int] = None,
360 | num_inference_steps: int = 50,
361 | guidance_scale: float = 7.5,
362 | negative_prompt: Optional[Union[str, List[str]]] = None,
363 | num_videos_per_prompt: Optional[int] = 1,
364 | eta: float = 0.0,
365 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
366 | latents: Optional[torch.FloatTensor] = None,
367 | output_type: Optional[str] = "tensor",
368 | return_dict: bool = True,
369 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
370 | callback_steps: Optional[int] = 1,
371 | **kwargs,
372 | ):
373 | # Default height and width to unet
374 | height = height or self.unet.config.sample_size * self.vae_scale_factor
375 | width = width or self.unet.config.sample_size * self.vae_scale_factor
376 |
377 | # Check inputs. Raise error if not correct
378 | self.check_inputs(prompt, height, width, callback_steps)
379 |
380 | # Define call parameters
381 | # batch_size = 1 if isinstance(prompt, str) else len(prompt)
382 | batch_size = 1
383 | if latents is not None:
384 | batch_size = latents.shape[0]
385 | if isinstance(prompt, list):
386 | batch_size = len(prompt)
387 |
388 | device = self._execution_device
389 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
390 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
391 | # corresponds to doing no classifier free guidance.
392 | do_classifier_free_guidance = guidance_scale > 1.0
393 |
394 | # Encode input prompt
395 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
396 | if negative_prompt is not None:
397 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
398 | text_embeddings = self._encode_prompt(
399 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
400 | )
401 |
402 | # Prepare timesteps
403 | self.scheduler.set_timesteps(num_inference_steps, device=device)
404 | timesteps = self.scheduler.timesteps
405 |
406 | # Prepare latent variables
407 | num_channels_latents = 4
408 | # num_channels_latents = self.unet.in_channels
409 | latents = self.prepare_latents(
410 | batch_size * num_videos_per_prompt,
411 | num_channels_latents,
412 | video_length,
413 | height,
414 | width,
415 | text_embeddings.dtype,
416 | device,
417 | generator,
418 | latents,
419 | )
420 | latents_dtype = latents.dtype
421 |
422 | if use_image:
423 | shape = (
424 | batch_size,
425 | num_channels_latents,
426 | video_length,
427 | height // self.vae_scale_factor,
428 | width // self.vae_scale_factor,
429 | )
430 |
431 | image = Image.open(f"test_image/init_image{use_image}.png").convert("RGB")
432 | image = preprocess_image(image).to(device)
433 | if isinstance(generator, list):
434 | image_latent = [
435 | self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
436 | ]
437 | image_latent = torch.cat(image_latent, dim=0).to(device=device)
438 | else:
439 | image_latent = self.vae.encode(image).latent_dist.sample(generator).to(device=device)
440 |
441 | image_latent = torch.nn.functional.interpolate(image_latent, size=[shape[-2], shape[-1]])
442 | image_latent_padding = image_latent.clone() * 0.18215
443 | mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device)
444 | mask_coef = prepare_mask_coef(video_length, 0, kwargs["mask_sim_range"])
445 |
446 | add_noise = torch.randn(shape).to(device)
447 | masked_image = torch.zeros(shape).to(device)
448 | for f in range(video_length):
449 | mask[:, :, f, :, :] = mask_coef[f]
450 | masked_image[:, :, f, :, :] = image_latent_padding.clone()
451 | mask = mask.to(device)
452 | else:
453 | shape = (
454 | batch_size,
455 | num_channels_latents,
456 | video_length,
457 | height // self.vae_scale_factor,
458 | width // self.vae_scale_factor,
459 | )
460 | add_noise = torch.zeros_like(latents).to(device)
461 | masked_image = add_noise
462 | mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device)
463 |
464 | # Prepare extra step kwargs.
465 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
466 | mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
467 | masked_image = torch.cat([masked_image] * 2) if do_classifier_free_guidance else masked_image
468 | # Denoising loop
469 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
470 | with self.progress_bar(total=num_inference_steps) as progress_bar:
471 | for i, t in enumerate(timesteps):
472 | # expand the latents if we are doing classifier free guidance
473 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
474 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
475 |
476 | # predict the noise residual
477 | noise_pred = self.unet(
478 | latent_model_input, mask, masked_image, t, encoder_hidden_states=text_embeddings
479 | ).sample.to(dtype=latents_dtype)
480 |
481 | # perform guidance
482 | if do_classifier_free_guidance:
483 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
484 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
485 |
486 | # compute the previous noisy sample x_t -> x_t-1
487 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
488 |
489 | # call the callback, if provided
490 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
491 | progress_bar.update()
492 | if callback is not None and i % callback_steps == 0:
493 | callback(i, t, latents)
494 |
495 | # Post-processing
496 | video = self.decode_latents(latents)
497 |
498 | # Convert to tensor
499 | if output_type == "tensor":
500 | video = torch.from_numpy(video)
501 |
502 | if not return_dict:
503 | return video
504 |
505 | return AnimationPipelineOutput(videos=video)
506 |
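A minimal call sketch for the pipeline defined above; `pipe` is a hypothetical instance built elsewhere (the constructor is not shown in this excerpt). With a truthy `use_image`, the pipeline additionally reads `test_image/init_image{use_image}.png` and expects a `mask_sim_range` keyword for `prepare_mask_coef`.

# `pipe` is assumed to be an instance of the pipeline class above
out = pipe(
    prompt="a golden labrador running on the beach",
    use_image=False,          # text-only path; no conditioning image
    video_length=16,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
)
video = out.videos            # (b, c, f, h, w) tensor with values in [0, 1]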
--------------------------------------------------------------------------------
/animatediff/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Conversion script for the LoRA's safetensors checkpoints."""
16 |
17 | import argparse
18 |
19 | import torch
20 | from diffusers import StableDiffusionPipeline
21 | from safetensors.torch import load_file
22 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
23 | # load base model
24 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
25 |
26 | # load LoRA weight from .safetensors
27 | # state_dict = load_file(checkpoint_path)
28 |
29 | visited = []
30 |
31 | # directly update weight in diffusers model
32 | for key in state_dict:
33 | # it is suggested to print out the key, it usually will be something like below
34 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
35 |
36 |         # alpha is supplied as an argument, so skip the per-key `.alpha` entries
37 | if ".alpha" in key or key in visited:
38 | continue
39 |
40 | if "text" in key:
41 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
42 | curr_layer = pipeline.text_encoder
43 | else:
44 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
45 | curr_layer = pipeline.unet
46 |
47 | # find the target layer
48 | temp_name = layer_infos.pop(0)
49 | while len(layer_infos) > -1:
50 | try:
51 | curr_layer = curr_layer.__getattr__(temp_name)
52 | if len(layer_infos) > 0:
53 | temp_name = layer_infos.pop(0)
54 | elif len(layer_infos) == 0:
55 | break
56 | except Exception:
57 | if len(temp_name) > 0:
58 | temp_name += "_" + layer_infos.pop(0)
59 | else:
60 | temp_name = layer_infos.pop(0)
61 |
62 | pair_keys = []
63 | if "lora_down" in key:
64 | pair_keys.append(key.replace("lora_down", "lora_up"))
65 | pair_keys.append(key)
66 | else:
67 | pair_keys.append(key)
68 | pair_keys.append(key.replace("lora_up", "lora_down"))
69 |
70 | # update weight
71 | if len(state_dict[pair_keys[0]].shape) == 4:
72 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
73 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
74 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(
75 | curr_layer.weight.data.device
76 | )
77 | else:
78 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
79 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
80 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
81 |
82 | # update visited list
83 | for item in pair_keys:
84 | visited.append(item)
85 |
86 | return pipeline
87 |
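The per-layer update applied above is the usual low-rank merge W <- W + alpha * (up @ down). A self-contained illustration on a single linear layer (random weights, purely for demonstration):

import torch

alpha = 0.6
layer = torch.nn.Linear(768, 320, bias=False)      # stand-in for a cross-attention projection
lora_down = torch.randn(4, 768) * 0.01             # rank-4 "lora_down" factor
lora_up = torch.randn(320, 4) * 0.01               # rank-4 "lora_up" factor

with torch.no_grad():
    layer.weight += alpha * (lora_up @ lora_down)  # W <- W + alpha * delta_W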
88 |
89 | def convert_lora_model_level(
90 | state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6
91 | ):
92 |     """Convert LoRA weights at the model level instead of the pipeline level."""
93 |
94 | visited = []
95 |
96 | # directly update weight in diffusers model
97 | for key in state_dict:
98 | # it is suggested to print out the key, it usually will be something like below
99 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
100 |
101 |         # alpha is supplied as an argument, so skip the per-key `.alpha` entries
102 | if ".alpha" in key or key in visited:
103 | continue
104 |
105 | if "text" in key:
106 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
107 | assert text_encoder is not None, "text_encoder must be passed since lora contains text encoder layers"
108 | curr_layer = text_encoder
109 | else:
110 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
111 | curr_layer = unet
112 |
113 | # find the target layer
114 | temp_name = layer_infos.pop(0)
115 | while len(layer_infos) > -1:
116 | try:
117 | curr_layer = curr_layer.__getattr__(temp_name)
118 | if len(layer_infos) > 0:
119 | temp_name = layer_infos.pop(0)
120 | elif len(layer_infos) == 0:
121 | break
122 | except Exception:
123 | if len(temp_name) > 0:
124 | temp_name += "_" + layer_infos.pop(0)
125 | else:
126 | temp_name = layer_infos.pop(0)
127 |
128 | pair_keys = []
129 | if "lora_down" in key:
130 | pair_keys.append(key.replace("lora_down", "lora_up"))
131 | pair_keys.append(key)
132 | else:
133 | pair_keys.append(key)
134 | pair_keys.append(key.replace("lora_up", "lora_down"))
135 |
136 | # update weight
137 | # NOTE: load lycon, maybe have bugs :(
138 | if "conv_in" in pair_keys[0]:
139 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
140 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
141 | weight_up = weight_up.view(weight_up.size(0), -1)
142 | weight_down = weight_down.view(weight_down.size(0), -1)
143 | shape = list(curr_layer.weight.data.shape)
144 | shape[1] = 4
145 | curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape)
146 | elif "conv" in pair_keys[0]:
147 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
148 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
149 | weight_up = weight_up.view(weight_up.size(0), -1)
150 | weight_down = weight_down.view(weight_down.size(0), -1)
151 | shape = list(curr_layer.weight.data.shape)
152 | curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape)
153 | elif len(state_dict[pair_keys[0]].shape) == 4:
154 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
155 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
156 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(
157 | curr_layer.weight.data.device
158 | )
159 | else:
160 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
161 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
162 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
163 |
164 | # update visited list
165 | for item in pair_keys:
166 | visited.append(item)
167 |
168 | return unet, text_encoder
169 |
170 |
171 | if __name__ == "__main__":
172 | parser = argparse.ArgumentParser()
173 |
174 | parser.add_argument(
175 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
176 | )
177 | parser.add_argument(
178 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
179 | )
180 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
181 | parser.add_argument(
182 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
183 | )
184 | parser.add_argument(
185 | "--lora_prefix_text_encoder",
186 | default="lora_te",
187 | type=str,
188 | help="The prefix of text encoder weight in safetensors",
189 | )
190 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
191 | parser.add_argument(
192 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
193 | )
194 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
195 |
196 | args = parser.parse_args()
197 |
198 | base_model_path = args.base_model_path
199 | checkpoint_path = args.checkpoint_path
200 | dump_path = args.dump_path
201 | lora_prefix_unet = args.lora_prefix_unet
202 | lora_prefix_text_encoder = args.lora_prefix_text_encoder
203 | alpha = args.alpha
204 | 
205 |     pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
206 |     state_dict = load_file(checkpoint_path)
207 |     pipe = convert_lora(pipeline, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
208 | 
209 |     pipe = pipe.to(args.device)
210 |     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
211 |
--------------------------------------------------------------------------------
/animatediff/utils/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Optional, Union
4 |
5 | import cv2
6 | import imageio
7 | import moviepy.editor as mpy
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 | import torchvision
12 | from einops import rearrange
13 | from PIL import Image
14 | from tqdm import tqdm
15 |
16 |
17 | # We recommend using the following affinity scores (motion magnitude).
18 | # You are also encouraged to construct your own score schedules.
19 | RANGE_LIST = [
20 | [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion
21 | [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion
22 | [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion
23 | [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop
24 | [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop
25 | [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop
26 | [0.5, 0.2], # Style Transfer Large Motion
27 | [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion
28 | [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion
29 | ]
30 |
31 |
32 | def zero_rank_print(s):
33 | if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
34 | print("### " + s)
35 |
36 |
37 | def save_videos_mp4(video: torch.Tensor, path: str, fps: int = 8):
38 | video = rearrange(video, "b c t h w -> t b c h w")
39 | num_frames, batch_size, channels, height, width = video.shape
40 | assert batch_size == 1, "Only support batch size == 1"
41 | video = video.squeeze(1)
42 | video = rearrange(video, "t c h w -> t h w c")
43 |
44 | def make_frame(t):
45 | frame_tensor = video[int(t * fps)]
46 | frame_np = (frame_tensor * 255).numpy().astype("uint8")
47 | return frame_np
48 |
49 | clip = mpy.VideoClip(make_frame, duration=num_frames / fps)
50 | clip.write_videofile(path, fps=fps, codec="libx264")
51 |
52 |
53 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
54 | videos = rearrange(videos, "b c t h w -> t b c h w")
55 | outputs = []
56 | for x in videos:
57 | x = torchvision.utils.make_grid(x, nrow=n_rows)
58 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
59 | if rescale:
60 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
61 | x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)
62 | outputs.append(x)
63 |
64 | os.makedirs(os.path.dirname(path), exist_ok=True)
65 | imageio.mimsave(path, outputs, fps=fps)
66 |
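A minimal usage sketch for `save_videos_grid`, assuming the repository root is on PYTHONPATH so the module imports as `animatediff.utils.util`:

import torch
from animatediff.utils.util import save_videos_grid

fake_videos = torch.rand(2, 3, 8, 64, 64)          # (b, c, t, h, w), values already in [0, 1]
save_videos_grid(fake_videos, "outputs/demo.gif", n_rows=2, fps=8)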
67 |
68 | # DDIM Inversion
69 | @torch.no_grad()
70 | def init_prompt(prompt, pipeline):
71 | uncond_input = pipeline.tokenizer(
72 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
73 | )
74 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
75 | text_input = pipeline.tokenizer(
76 | [prompt],
77 | padding="max_length",
78 | max_length=pipeline.tokenizer.model_max_length,
79 | truncation=True,
80 | return_tensors="pt",
81 | )
82 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
83 | context = torch.cat([uncond_embeddings, text_embeddings])
84 |
85 | return context
86 |
87 |
88 | def next_step(
89 | model_output: Union[torch.FloatTensor, np.ndarray],
90 | timestep: int,
91 | sample: Union[torch.FloatTensor, np.ndarray],
92 | ddim_scheduler,
93 | ):
94 | timestep, next_timestep = (
95 | min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
96 | timestep,
97 | )
98 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
99 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
100 | beta_prod_t = 1 - alpha_prod_t
101 | next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
102 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
103 | next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
104 | return next_sample
105 |
106 |
107 | def get_noise_pred_single(latents, t, context, unet):
108 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
109 | return noise_pred
110 |
111 |
112 | @torch.no_grad()
113 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
114 | context = init_prompt(prompt, pipeline)
115 | uncond_embeddings, cond_embeddings = context.chunk(2)
116 | all_latent = [latent]
117 | latent = latent.clone().detach()
118 | for i in tqdm(range(num_inv_steps)):
119 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
120 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
121 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
122 | all_latent.append(latent)
123 | return all_latent
124 |
125 |
126 | @torch.no_grad()
127 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
128 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
129 | return ddim_latents
130 |
131 |
132 | def prepare_mask_coef(video_length: int, cond_frame: int, sim_range: list = [0.2, 1.0]):
133 |     assert len(sim_range) == 2, "sim_range should have length 2, i.e. the min and max similarity"
134 |
135 | assert video_length > 1, "video_length should be greater than 1"
136 |
137 | assert video_length > cond_frame, "video_length should be greater than cond_frame"
138 |
139 | diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1)
140 | coef = [1.0] * video_length
141 | for f in range(video_length):
142 | f_diff = diff * abs(cond_frame - f)
143 | f_diff = 1 - f_diff
144 | coef[f] *= f_diff
145 |
146 | return coef
147 |
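A worked example of the linear schedule above: the coefficient drops by |sim_range[0] - sim_range[1]| / (video_length - 1) for every frame away from `cond_frame`.

from animatediff.utils.util import prepare_mask_coef

coef = prepare_mask_coef(video_length=5, cond_frame=0, sim_range=[0.2, 1.0])
print([round(c, 2) for c in coef])   # [1.0, 0.8, 0.6, 0.4, 0.2]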
148 |
149 | def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int):
150 | assert video_length > 0, "video_length should be greater than 0"
151 |
152 | assert video_length > cond_frame, "video_length should be greater than cond_frame"
153 |
154 | range_list = RANGE_LIST
155 |
156 |     assert sim_range < len(range_list), f"sim_range index {sim_range} is not implemented"
157 |
158 | coef = range_list[sim_range]
159 | coef = coef + ([coef[-1]] * (video_length - len(coef)))
160 |
161 | order = [abs(i - cond_frame) for i in range(video_length)]
162 | coef = [coef[order[i]] for i in range(video_length)]
163 |
164 | return coef
165 |
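Here `sim_range` is an index into `RANGE_LIST` above; the chosen template is padded with its last value and reordered by distance from `cond_frame`. For example:

from animatediff.utils.util import prepare_mask_coef_by_statistics

# index 0 selects the "Small Motion" template [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],
# which is padded with 0.8 up to 16 frames when cond_frame is 0
coef = prepare_mask_coef_by_statistics(video_length=16, cond_frame=0, sim_range=0)
print(coef)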
166 |
167 | def prepare_mask_coef_multi_cond(video_length: int, cond_frames: list, sim_range: list = [0.2, 1.0]):
168 |     assert len(sim_range) == 2, "sim_range should have length 2, i.e. the min and max similarity"
169 |
170 | assert video_length > 1, "video_length should be greater than 1"
171 |
172 | assert isinstance(cond_frames, list), "cond_frames should be a list"
173 |
174 | assert video_length > max(cond_frames), "video_length should be greater than cond_frame"
175 |
176 | if max(sim_range) == min(sim_range):
177 | cond_coefs = [sim_range[0]] * video_length
178 | return cond_coefs
179 |
180 | cond_coefs = []
181 |
182 | for cond_frame in cond_frames:
183 | cond_coef = prepare_mask_coef(video_length, cond_frame, sim_range)
184 | cond_coefs.append(cond_coef)
185 |
186 | mixed_coef = [0] * video_length
187 | for conds in range(len(cond_frames)):
188 | for f in range(video_length):
189 | mixed_coef[f] = abs(cond_coefs[conds][f] - mixed_coef[f])
190 |
191 | if conds > 0:
192 | min_num = min(mixed_coef)
193 | max_num = max(mixed_coef)
194 |
195 | for f in range(video_length):
196 | mixed_coef[f] = (mixed_coef[f] - min_num) / (max_num - min_num)
197 |
198 | mixed_max = max(mixed_coef)
199 | mixed_min = min(mixed_coef)
200 | for f in range(video_length):
201 | mixed_coef[f] = (max(sim_range) - min(sim_range)) * (mixed_coef[f] - mixed_min) / (
202 | mixed_max - mixed_min
203 | ) + min(sim_range)
204 |
205 | mixed_coef = [
206 | x if min(sim_range) <= x <= max(sim_range) else min(sim_range) if x < min(sim_range) else max(sim_range)
207 | for x in mixed_coef
208 | ]
209 |
210 | return mixed_coef
211 |
212 |
213 | def prepare_masked_latent_cond(video_length: int, cond_frames: list):
214 | for cond_frame in cond_frames:
215 | assert cond_frame < video_length, "cond_frame should be smaller than video_length"
216 | assert cond_frame > -1, f"cond_frame should be in the range of [0, {video_length}]"
217 |
218 | cond_frames.sort()
219 | nearest = [cond_frames[0]] * video_length
220 | for f in range(video_length):
221 | for cond_frame in cond_frames:
222 | if abs(nearest[f] - f) > abs(cond_frame - f):
223 | nearest[f] = cond_frame
224 |
225 |     masked_latent_cond = nearest
226 | 
227 |     return masked_latent_cond
228 |
229 |
230 | def estimated_kernel_size(frame_width: int, frame_height: int) -> int:
231 | """Estimate kernel size based on video resolution."""
232 | # TODO: This equation is based on manual estimation from a few videos.
233 | # Create a more comprehensive test suite to optimize against.
234 | size: int = 4 + round(math.sqrt(frame_width * frame_height) / 192)
235 | if size % 2 == 0:
236 | size += 1
237 | return size
238 |
239 |
240 | def detect_edges(lum: np.ndarray) -> np.ndarray:
241 | """Detect edges using the luma channel of a frame.
242 |
243 | Arguments:
244 | lum: 2D 8-bit image representing the luma channel of a frame.
245 |
246 | Returns:
247 | 2D 8-bit image of the same size as the input, where pixels with values of 255
248 | represent edges, and all other pixels are 0.
249 | """
250 | # Initialize kernel.
251 | kernel_size = estimated_kernel_size(lum.shape[1], lum.shape[0])
252 | kernel = np.ones((kernel_size, kernel_size), np.uint8)
253 |
254 | # Estimate levels for thresholding.
255 | # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc.
256 | sigma: float = 1.0 / 3.0
257 | median = np.median(lum)
258 | low = int(max(0, (1.0 - sigma) * median))
259 | high = int(min(255, (1.0 + sigma) * median))
260 |
261 | # Calculate edges using Canny algorithm, and reduce noise by dilating the edges.
262 | # This increases edge overlap leading to improved robustness against noise and slow
263 | # camera movement. Note that very large kernel sizes can negatively affect accuracy.
264 | edges = cv2.Canny(lum, low, high)
265 | return cv2.dilate(edges, kernel)
266 |
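A small sketch of how `detect_edges` can be driven, approximating the luma channel with a grayscale conversion (the synthetic frame is only a stand-in):

import cv2
import numpy as np
from animatediff.utils.util import detect_edges

frame = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)   # stand-in BGR frame
lum = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
edges = detect_edges(lum)                                       # uint8 map, 255 on edge pixels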
267 |
268 | def prepare_mask_coef_by_score(
269 | video_shape: list,
270 | cond_frame_idx: list,
271 | sim_range: list = [0.2, 1.0],
272 | statistic: list = [1, 100],
273 | coef_max: int = 0.98,
274 | score: Optional[torch.Tensor] = None,
275 | ):
276 | """
277 |     video_shape is (b, f), i.e. (batch_size, frame_num)
278 |     cond_frame_idx is a list, with length of batch_size
279 |     the shape of statistic is (f, 2)
280 |     the shape of score is (b, f)
281 |     the shape of the returned coef is (b, f)
282 |     """
283 |     assert (
284 |         len(video_shape) == 2
285 |     ), f"video_shape should be (batch_size, frame_num), but got {len(video_shape)} elements"
286 |
287 | batch_size, frame_num = video_shape[0], video_shape[1]
288 |
289 | score = score.permute(0, 2, 1).squeeze(0)
290 |
291 | # list -> b 1
292 | cond_fram_mat = torch.tensor(cond_frame_idx).unsqueeze(-1)
293 |
294 | statistic = torch.tensor(statistic)
295 | # (f 2) -> (b f 2)
296 | statistic = statistic.repeat(batch_size, 1, 1)
297 |
298 | # shape of order (b f), shape of cond_mat (b f)
299 | order = torch.arange(0, frame_num, 1)
300 | order = order.repeat(batch_size, 1)
301 | cond_mat = torch.ones((batch_size, frame_num)) * cond_fram_mat
302 | order = abs(order - cond_mat)
303 |
304 | statistic = statistic[:, order.to(torch.long)][0, :, :, :]
305 |
306 | # score (b f) max_s (b f 1)
307 | max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype)
308 | min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype)
309 |
310 | score[score > max_stats] = max_stats[score > max_stats] * 0.95
311 | score[score < min_stats] = min_stats[score < min_stats]
312 |
313 | eps = 1e-10
314 | coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range)))
315 |
316 | indices = torch.arange(coef.shape[0]).unsqueeze(1)
317 | coef[indices, cond_fram_mat] = 1.0
318 |
319 | return coef
320 |
321 |
322 | def preprocess_img(img_path, max_size: int = 512):
323 | ori_image = Image.open(img_path).convert("RGB")
324 |
325 | width, height = ori_image.size
326 |
327 | long_edge = max(width, height)
328 | if long_edge > max_size:
329 | scale_factor = max_size / long_edge
330 | else:
331 | scale_factor = 1
332 | width = int(width * scale_factor)
333 | height = int(height * scale_factor)
334 | ori_image = ori_image.resize((width, height))
335 |
336 | if (width % 8 != 0) or (height % 8 != 0):
337 | in_width = (width // 8) * 8
338 | in_height = (height // 8) * 8
339 | else:
340 | in_width = width
341 | in_height = height
342 | in_image = ori_image
343 |
344 | in_image = ori_image.resize((in_width, in_height))
345 | # in_image = ori_image.resize((512, 512))
346 | in_image_np = np.array(in_image)
347 | return in_image_np, in_height, in_width
348 |
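To make the resizing rules above concrete: the long edge is capped at `max_size`, then both sides are floored to multiples of 8 (the VAE downsamples by 8). For a hypothetical 1000x700 input with `max_size=512`, the scale factor is 0.512, giving 512x358, which is floored to 512x352.

from animatediff.utils.util import preprocess_img

img_np, in_height, in_width = preprocess_img("example/img/labrador.png", max_size=512)
assert in_height % 8 == 0 and in_width % 8 == 0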
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import random
5 | from argparse import ArgumentParser
6 | from datetime import datetime
7 | from glob import glob
8 |
9 | import gradio as gr
10 | import numpy as np
11 | import torch
12 | from omegaconf import OmegaConf
13 | from PIL import Image
14 |
15 | from animatediff.pipelines import I2VPipeline
16 | from animatediff.utils.util import save_videos_grid
17 | from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
18 |
19 |
20 | sample_idx = 0
21 | scheduler_dict = {
22 | "DDIM": DDIMScheduler,
23 | "Euler": EulerDiscreteScheduler,
24 | "PNDM": PNDMScheduler,
25 | }
26 |
27 | css = """
28 | .toolbutton {
29 |     margin-bottom: 0em;
30 | max-width: 2.5em;
31 | min-width: 2.5em !important;
32 | height: 2.5em;
33 | }
34 | """
35 |
36 | parser = ArgumentParser()
37 | parser.add_argument("--config", type=str, default="example/config/base.yaml")
38 | parser.add_argument("--server-name", type=str, default="0.0.0.0")
39 | parser.add_argument("--port", type=int, default=7860)
40 | parser.add_argument("--share", action="store_true")
41 |
42 | parser.add_argument("--save-path", default="samples")
43 |
44 | args = parser.parse_args()
45 |
46 |
47 | N_PROMPT = (
48 | "wrong white balance, dark, sketches,worst quality,low quality, "
49 | "deformed, distorted, disfigured, bad eyes, wrong lips, "
50 | "weird mouth, bad teeth, mutated hands and fingers, bad anatomy,"
51 | "wrong anatomy, amputation, extra limb, missing limb, "
52 | "floating,limbs, disconnected limbs, mutation, ugly, disgusting, "
53 | "bad_pictures, negative_hand-neg"
54 | )
55 |
56 |
57 | def preprocess_img(img_np, max_size: int = 512):
58 | ori_image = Image.fromarray(img_np).convert("RGB")
59 |
60 | width, height = ori_image.size
61 |
62 | long_edge = max(width, height)
63 | if long_edge > max_size:
64 | scale_factor = max_size / long_edge
65 | else:
66 | scale_factor = 1
67 | width = int(width * scale_factor)
68 | height = int(height * scale_factor)
69 | ori_image = ori_image.resize((width, height))
70 |
71 | if (width % 8 != 0) or (height % 8 != 0):
72 | in_width = (width // 8) * 8
73 | in_height = (height // 8) * 8
74 | else:
75 | in_width = width
76 | in_height = height
77 | in_image = ori_image
78 |
79 | in_image = ori_image.resize((in_width, in_height))
80 | in_image_np = np.array(in_image)
81 | return in_image_np, in_height, in_width
82 |
83 |
84 | class AnimateController:
85 | def __init__(self):
86 | # config dirs
87 | self.basedir = os.getcwd()
88 | self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
89 | self.ip_adapter_dir = os.path.join(self.basedir, "models", "IP_Adapter")
90 | self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
91 | self.savedir_sample = os.path.join(self.savedir, "sample")
92 | os.makedirs(self.savedir, exist_ok=True)
93 |
94 | self.stable_diffusion_list = []
95 | self.motion_module_list = []
96 | self.personalized_model_list = []
97 |
98 | self.refresh_personalized_model()
99 |
100 | self.pipeline = None
101 |
102 | self.inference_config = OmegaConf.load(args.config)
103 | self.stable_diffusion_dir = self.inference_config.pretrained_model_path
104 | self.pia_path = self.inference_config.generate.model_path
105 | self.loaded = False
106 |
107 | def refresh_personalized_model(self):
108 | personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
109 | self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
110 |
111 |     def get_ip_adapter_folder(self):
112 | file_list = os.listdir(self.ip_adapter_dir)
113 | if not file_list:
114 | return False
115 |
116 |         if "ip-adapter_sd15.bin" not in file_list:
117 |             print(f'Cannot find "ip-adapter_sd15.bin" under {self.ip_adapter_dir}')
118 |             return False
119 |         if "image_encoder" not in file_list:
120 |             print(f'Cannot find "image_encoder" under {self.ip_adapter_dir}')
121 |             return False
122 |
123 | return True
124 |
125 | def load_model(self, dreambooth_path=None, lora_path=None, lora_alpha=1.0, enable_ip_adapter=True):
126 | gr.Info("Start Load Models...")
127 | print("Start Load Models...")
128 |
129 | if lora_path and lora_path.upper() != "NONE":
130 | lora_path = osp.join(self.personalized_model_dir, lora_path)
131 | else:
132 | lora_path = None
133 |
134 | if dreambooth_path and dreambooth_path.upper() != "NONE":
135 | dreambooth_path = osp.join(self.personalized_model_dir, dreambooth_path)
136 | else:
137 | dreambooth_path = None
138 |
139 | if enable_ip_adapter:
140 |             if not self.get_ip_adapter_folder():
141 | print("Load IP-Adapter from remote.")
142 | ip_adapter_path = "h94/IP-Adapter"
143 | else:
144 | ip_adapter_path = self.ip_adapter_dir
145 | else:
146 | ip_adapter_path = None
147 |
148 | self.pipeline = I2VPipeline.build_pipeline(
149 | self.inference_config,
150 | self.stable_diffusion_dir,
151 | unet_path=self.pia_path,
152 | dreambooth_path=dreambooth_path,
153 | lora_path=lora_path,
154 | lora_alpha=lora_alpha,
155 | ip_adapter_path=ip_adapter_path,
156 | )
157 | gr.Info("Load Finish!")
158 | print("Load Finish!")
159 | self.loaded = True
160 |
161 | return "Load"
162 |
163 | def animate(
164 | self,
165 | init_img,
166 | motion_scale,
167 | prompt_textbox,
168 | negative_prompt_textbox,
169 | sampler_dropdown,
170 | sample_step_slider,
171 | length_slider,
172 | cfg_scale_slider,
173 | seed_textbox,
174 | ip_adapter_scale,
175 | max_size,
176 | progress=gr.Progress(),
177 | ):
178 | if not self.loaded:
179 | raise gr.Error("Please load model first!")
180 |
181 | if seed_textbox != -1 and seed_textbox != "":
182 | torch.manual_seed(int(seed_textbox))
183 | else:
184 | torch.seed()
185 | seed = torch.initial_seed()
186 | init_img, h, w = preprocess_img(init_img, max_size)
187 | sample = self.pipeline(
188 | image=init_img,
189 | prompt=prompt_textbox,
190 | negative_prompt=negative_prompt_textbox,
191 | num_inference_steps=sample_step_slider,
192 | guidance_scale=cfg_scale_slider,
193 | width=w,
194 | height=h,
195 | video_length=16,
196 | mask_sim_template_idx=motion_scale,
197 | ip_adapter_scale=ip_adapter_scale,
198 | progress_fn=progress,
199 | ).videos
200 |
201 | save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
202 | save_videos_grid(sample, save_sample_path)
203 |
204 | sample_config = {
205 | "prompt": prompt_textbox,
206 | "n_prompt": negative_prompt_textbox,
207 | "sampler": sampler_dropdown,
208 | "num_inference_steps": sample_step_slider,
209 | "guidance_scale": cfg_scale_slider,
210 | "width": w,
211 | "height": h,
212 | "video_length": length_slider,
213 | "seed": seed,
214 | "motion": motion_scale,
215 | }
216 | json_str = json.dumps(sample_config, indent=4)
217 | with open(os.path.join(self.savedir, "logs.json"), "a") as f:
218 | f.write(json_str)
219 | f.write("\n\n")
220 |
221 | return save_sample_path
222 |
223 |
224 | controller = AnimateController()
225 |
226 |
227 | def ui():
228 | with gr.Blocks(css=css) as demo:
229 | motion_idx = gr.State(0)
230 |         gr.HTML(
231 |             "Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models"
232 |         )
233 | with gr.Row():
234 | gr.Markdown(
235 | "" # noqa
239 | )
240 |
241 | with gr.Column(variant="panel"):
242 | gr.Markdown(
243 | """
244 | ### 1. Model checkpoints (select pretrained model path first).
245 | """
246 | )
247 | with gr.Row():
248 | base_model_dropdown = gr.Dropdown(
249 | label="Select base Dreambooth model",
250 | choices=["none"] + controller.personalized_model_list,
251 | value="none",
252 | interactive=True,
253 | )
254 |
255 | lora_model_dropdown = gr.Dropdown(
256 | label="Select LoRA model (optional)",
257 | choices=["none"] + controller.personalized_model_list,
258 | value="none",
259 | interactive=True,
260 | )
261 |
262 | lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0, minimum=0, maximum=2, interactive=True)
263 |
264 | personalized_refresh_button = gr.Button(value="\U0001f503", elem_classes="toolbutton")
265 |
266 | def update_personalized_model():
267 | controller.refresh_personalized_model()
268 | return [controller.personalized_model_list, ["none"] + controller.personalized_model_list]
269 |
270 | personalized_refresh_button.click(
271 | fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown]
272 | )
273 |
274 | load_model_button = gr.Button(value="Load")
275 | load_model_button.click(
276 | fn=controller.load_model,
277 | inputs=[
278 | base_model_dropdown,
279 | lora_model_dropdown,
280 | lora_alpha_slider,
281 | ],
282 | outputs=[load_model_button],
283 | )
284 |
285 | with gr.Column(variant="panel"):
286 | gr.Markdown(
287 | """
288 | ### 2. Configs for PIA.
289 | """
290 | )
291 |
292 | prompt_textbox = gr.Textbox(label="Prompt", lines=2)
293 | negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
294 |
295 | with gr.Row(equal_height=False):
296 | with gr.Column():
297 | with gr.Row():
298 | init_img = gr.Image(label="Input Image")
299 |
300 | with gr.Row():
301 | sampler_dropdown = gr.Dropdown(
302 | label="Sampling method",
303 | choices=list(scheduler_dict.keys()),
304 | value=list(scheduler_dict.keys())[0],
305 | )
306 | sample_step_slider = gr.Slider(
307 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1
308 | )
309 |
310 | max_size_slider = gr.Slider(
311 | label="Max size (The long edge of the input image will be resized to this value, larger value means slower inference speed)",
312 | value=512,
313 | step=64,
314 | minimum=512,
315 | maximum=1024,
316 | )
317 |
318 | length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
319 | cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
320 | motion_scale_silder = gr.Slider(
321 | label="Motion Scale", value=motion_idx.value, step=1, minimum=0, maximum=2
322 | )
323 |                         ip_adapter_scale = gr.Slider(label="IP-Adapter Scale", value=0.0, minimum=0, maximum=1)
324 |
325 | def GenerationMode(motion_scale_silder, option):
326 | if option == "Animation":
327 | motion_idx = motion_scale_silder
328 | elif option == "Style Transfer":
329 | motion_idx = motion_scale_silder * -1 - 1
330 | elif option == "Loop Video":
331 | motion_idx = motion_scale_silder + 3
332 | return motion_idx
333 |
334 | with gr.Row():
335 | style_selection = gr.Radio(
336 | ["Animation", "Style Transfer", "Loop Video"], label="Generation Mode", value="Animation"
337 | )
338 | style_selection.change(
339 | fn=GenerationMode, inputs=[motion_scale_silder, style_selection], outputs=[motion_idx]
340 | )
341 | motion_scale_silder.change(
342 | fn=GenerationMode, inputs=[motion_scale_silder, style_selection], outputs=[motion_idx]
343 | )
344 |
345 | with gr.Row():
346 | seed_textbox = gr.Textbox(label="Seed", value=-1)
347 | seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
348 |                         seed_button.click(fn=lambda: random.randint(1, int(1e8)), outputs=[seed_textbox], queue=False)
349 |
350 | generate_button = gr.Button(value="Generate", variant="primary")
351 |
352 | result_video = gr.Video(label="Generated Animation", interactive=False)
353 |
354 | generate_button.click(
355 | fn=controller.animate,
356 | inputs=[
357 | init_img,
358 | motion_idx,
359 | prompt_textbox,
360 | negative_prompt_textbox,
361 | sampler_dropdown,
362 | sample_step_slider,
363 | length_slider,
364 | cfg_scale_slider,
365 | seed_textbox,
366 | ip_adapter_scale,
367 | max_size_slider,
368 | ],
369 | outputs=[result_video],
370 | )
371 |
372 | return demo
373 |
374 |
375 | if __name__ == "__main__":
376 | demo = ui()
377 | demo.queue(3)
378 | demo.launch(server_name=args.server_name, server_port=args.port, share=args.share, allowed_paths=["pia.png"])
379 |
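The `GenerationMode` helper above remaps the Motion Scale slider onto indices of `RANGE_LIST` in `animatediff/utils/util.py`: Animation uses 0-2 (small/moderate/large motion), Loop Video uses 3-5 (loop templates), and Style Transfer uses -1 to -3, i.e. negative indices that pick the style-transfer templates from the end of the list. A standalone illustration of that mapping:

def to_motion_idx(option: str, slider: int) -> int:
    # mirrors GenerationMode() above; slider is the 0-2 "Motion Scale" value
    if option == "Style Transfer":
        return -slider - 1      # -1, -2, -3
    if option == "Loop Video":
        return slider + 3       # 3, 4, 5
    return slider               # "Animation": 0, 1, 2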
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 | python_version: "3.11"
7 | system_packages:
8 | - "libgl1-mesa-glx"
9 | - "libglib2.0-0"
10 | python_packages:
11 | - torch==2.0.1
12 | - torchvision==0.15.2
13 | - diffusers==0.24.0
14 | - transformers==4.36.0
15 | - accelerate==0.25.0
16 | - imageio==2.27.0
17 | - decord==0.6.0
18 | - einops==0.7.0
19 | - omegaconf==2.3.0
20 | - safetensors==0.4.1
21 | - opencv-python==4.8.1.78
22 | - moviepy==1.0.3
23 | run:
24 | - pip install xformers
25 | predict: "predict.py:Predictor"
26 |
--------------------------------------------------------------------------------
/download_bashscripts/1-RealisticVision.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors https://huggingface.co/frankjoshua/realisticVisionV51_v51VAE/resolve/main/realisticVisionV51_v51VAE.safetensors?download=true --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/download_bashscripts/2-RcnzCartoon.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors https://civitai.com/api/download/models/71009 --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/download_bashscripts/3-MajicMix.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors https://civitai.com/api/download/models/82446 --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/environment-pt2.yaml:
--------------------------------------------------------------------------------
1 | name: pia
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python=3.10
7 | - pytorch=2.0.0
8 | - torchvision=0.15.0
9 | - pytorch-cuda=11.8
10 | - pip
11 | - pip:
12 | - diffusers==0.24.0
13 | - transformers==4.25.1
14 | - xformers
15 | - imageio==2.33.1
16 | - decord==0.6.0
17 | - gdown
18 | - einops
19 | - omegaconf
20 | - safetensors
21 | - gradio
22 | - wandb
23 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: pia
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python=3.10
7 | - pytorch=1.13.1
8 | - torchvision=0.14.1
9 | - torchaudio=0.13.1
10 | - pytorch-cuda=11.7
11 | - pip
12 | - pip:
13 | - diffusers==0.24.0
14 | - transformers==4.25.1
15 | - xformers==0.0.16
16 | - imageio==2.27.0
17 | - decord==0.6.0
18 | - gdown
19 | - einops
20 | - omegaconf
21 | - safetensors
22 | - gradio
23 | - wandb
24 |
--------------------------------------------------------------------------------
/example/config/anya.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - 1girl smiling
4 | - 1girl open mouth
5 | - 1girl crying, pout
6 | n_prompt:
7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8 | validation_data:
9 | input_name: 'anya'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [-1]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 10201304011203481429
17 | lora_path: "models/DreamBooth_LoRA/cyberpunk.safetensors"
18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
19 | lora_alpha: 0.8
20 |
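A minimal sketch of how a per-example config like this one can be resolved against its `base:` file with OmegaConf; this mirrors a common loading convention and is an assumption, not a copy of the actual logic in inference.py (which is not shown here).

from omegaconf import OmegaConf

cfg = OmegaConf.load("example/config/anya.yaml")
base = OmegaConf.load(cfg.base)
cfg = OmegaConf.merge(base, cfg)      # per-example keys override base.yaml defaults
print(cfg.generate.model_path, cfg.validation_data.input_name)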
--------------------------------------------------------------------------------
/example/config/base.yaml:
--------------------------------------------------------------------------------
1 | generate:
2 | model_path: "models/PIA/pia.ckpt"
3 | use_image: true
4 | use_video: false
5 | sample_width: 512
6 | sample_height: 512
7 | video_length: 16
8 |
9 | validation_data:
10 | mask_sim_range: [0, 1]
11 | cond_frame: 0
12 | num_inference_steps: 25
13 |
14 | img_mask: ''
15 |
16 | noise_scheduler_kwargs:
17 | num_train_timesteps: 1000
18 | beta_start: 0.00085
19 | beta_end: 0.012
20 | beta_schedule: "linear"
21 | steps_offset: 1
22 | clip_sample: false
23 |
24 | pretrained_model_path: "models/StableDiffusion/"
25 | unet_additional_kwargs:
26 | use_motion_module : true
27 | motion_module_resolutions : [ 1,2,4,8 ]
28 | unet_use_cross_frame_attention : false
29 | unet_use_temporal_attention : false
30 |
31 | motion_module_type: Vanilla
32 | motion_module_kwargs:
33 | num_attention_heads : 8
34 | num_transformer_block : 1
35 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
36 | temporal_position_encoding : true
37 | temporal_position_encoding_max_len : 32
38 | temporal_attention_dim_div : 1
39 | zero_initialize : true
40 |
--------------------------------------------------------------------------------
/example/config/bear.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - 1bear walking in a shop, best quality, 4k
4 | n_prompt:
5 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
6 | validation_data:
7 | input_name: 'bear'
8 | validation_input_path: 'example/img'
9 | save_path: 'example/result'
10 | mask_sim_range: [0, 1, 2]
11 | generate:
12 | use_lora: false
13 | use_db: true
14 | global_seed: 10201034102130841429
15 | lora_path: "models/DreamBooth_LoRA/cyberpunk.safetensors"
16 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
17 | lora_alpha: 0.4
18 |
--------------------------------------------------------------------------------
/example/config/concert.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - 1man is smiling, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose
4 | - 1man is crying, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose
5 | - 1man is singing, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose
6 | n_prompt:
7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8 | validation_data:
9 | input_name: 'concert'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [-3]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 4292543217695451092 # To get 3d style shown in github, you can use seed: 4292543217695451088
17 | lora_path: ""
18 | db_path: "models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors"
19 | lora_alpha: 0.8
20 |
--------------------------------------------------------------------------------
/example/config/genshin.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k
4 | n_prompt:
5 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
6 | validation_data:
7 | input_name: 'genshin'
8 | validation_input_path: 'example/img'
9 | save_path: 'example/result'
10 | mask_sim_range: [0, 1, 2]
11 | generate:
12 | use_lora: false
13 | use_db: true
14 | sample_width: 512
15 | sample_height: 768
16 | global_seed: 10041042941301238026
17 | lora_path: "models/DreamBooth_LoRA/genshin.safetensors"
18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
19 | lora_alpha: 0.4
20 |
--------------------------------------------------------------------------------
/example/config/harry.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - 1boy smiling
4 | - 1boy playing magic fire
5 | - 1boy is waving hands
6 | n_prompt:
7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8 | validation_data:
9 | input_name: 'harry'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [1]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 10201403011320481249
17 | lora_path: ""
18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
19 | lora_alpha: 0.8
20 |
--------------------------------------------------------------------------------
/example/config/labrador.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - a golden labrador jump
4 | - a golden labrador walking
5 | - a golden labrador is running
6 | n_prompt:
7 | - 'collar, leashes, collars, wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly'
8 | validation_data:
9 | input_name: 'labrador'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [0, 1, 2]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 4292543217695451000
17 | lora_path: ""
18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
19 | lora_alpha: 0.8
20 |
--------------------------------------------------------------------------------
/example/config/lighthouse.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - lightning, lighthouse
4 | - sun rising, lighthouse
5 | - fireworks, lighthouse
6 | n_prompt:
7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8 | validation_data:
9 | input_name: 'lighthouse'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [0]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 5658137986800322011
17 | lora_path: ""
18 | db_path: "models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors"
19 | lora_alpha: 0.8
20 |
--------------------------------------------------------------------------------
/example/config/majic_girl.yaml:
--------------------------------------------------------------------------------
1 | base: 'example/config/base.yaml'
2 | prompts:
3 | - - 1girl is smiling, lowres,watermark
4 | - 1girl is crying, lowres,watermark
5 | - 1girl, snowing dark night
6 | n_prompt:
7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8 | validation_data:
9 | input_name: 'majic_girl'
10 | validation_input_path: 'example/img'
11 | save_path: 'example/result'
12 | mask_sim_range: [1]
13 | generate:
14 | use_lora: false
15 | use_db: true
16 | global_seed: 10021403011302841249
17 | lora_path: ""
18 | db_path: "models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors"
19 | lora_alpha: 0.8
20 |
--------------------------------------------------------------------------------
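The example configs above override only a handful of keys (prompts, negative prompts, validation inputs, seeds, checkpoint paths); everything else is inherited from the file named in their `base:` key. A minimal sketch of that resolution, mirroring the OmegaConf merge performed in inference.py further below (paths taken from the configs above):

    from omegaconf import OmegaConf

    config = OmegaConf.load("example/config/labrador.yaml")
    base_config = OmegaConf.load(config.base)      # 'example/config/base.yaml'
    config = OmegaConf.merge(base_config, config)  # keys from the example file override base keys

    print(config.validation_data.mask_sim_range)   # [0, 1, 2], i.e. three motion magnitudes
    print(config.generate.db_path)                 # DreamBooth checkpoint resolved for this example

Because the example config is passed as the second argument to merge, per-example settings such as mask_sim_range and global_seed replace the defaults in base.yaml.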
/example/config/train.yaml:
--------------------------------------------------------------------------------
1 | image_finetune: false
2 |
3 | output_dir: "outputs"
4 | pretrained_model_path: "./models/StableDiffusion/"
5 | pretrained_motion_module_path: './models/Motion_Module/mm_sd_v15_v2.ckpt'
6 |
7 |
8 | unet_additional_kwargs:
9 | use_motion_module : true
10 | motion_module_resolutions : [ 1,2,4,8 ]
11 | unet_use_cross_frame_attention : false
12 | unet_use_temporal_attention : false
13 |
14 | motion_module_type: Vanilla
15 | motion_module_kwargs:
16 | num_attention_heads : 8
17 | num_transformer_block : 1
18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
19 | temporal_position_encoding : true
20 | temporal_position_encoding_max_len : 32
21 | temporal_attention_dim_div : 1
22 | zero_initialize : true
23 |
24 | mask_sim_range: [0.2, 1.0]
25 |
26 | noise_scheduler_kwargs:
27 | num_train_timesteps: 1000
28 | beta_start: 0.00085
29 | beta_end: 0.012
30 | beta_schedule: "linear"
31 | steps_offset: 1
32 | clip_sample: false
33 |
34 | train_data:
35 | csv_path: "./results_10M_train.csv"
36 | video_folder: "data/WebVid10M/" # local path: replace it with yours
37 | # video_folder: "webvideo:s3://WebVid10M/" # petreloss path: replace it with yours
38 | sample_size: 256
39 | sample_stride: 4
40 | sample_n_frames: 16
41 |   use_petreloss: false # set this to true if you want to use a petreloss path
42 | conf_path: "~/petreloss.conf"
43 |
44 | validation_data:
45 | prompts:
46 |     - "waves, ocean flows, sand, clean sea, breath-taking beautiful beach, tropical beach."
47 | - "1girl walking on the street"
48 | - "Robot dancing in times square."
49 | - "Pacific coast, carmel by the sea ocean and waves."
50 | num_inference_steps: 25
51 | guidance_scale: 8.
52 | mask_sim_range: [0.2, 1.0]
53 |
54 | trainable_modules:
55 | - 'conv_in.'
56 | - 'motion_modules'
57 |
58 | # set the path to the finetuned unet's image layers
59 | # according to
60 | # https://github.com/guoyww/AnimateDiff/blob/main/__assets__/docs/animatediff.md#training
61 | unet_checkpoint_path: "models/mm_sd_v15_v2_full.ckpt"
62 |
63 | learning_rate: 1.e-4
64 | train_batch_size: 4
65 | gradient_accumulation_steps: 16
66 |
67 | max_train_epoch: -1
68 | max_train_steps: 500000
69 | checkpointing_epochs: -1
70 | checkpointing_steps: 60
71 |
72 | validation_steps: 3000
73 | validation_steps_tuple: [2, 50, 1000]
74 |
75 | global_seed: 42
76 | mixed_precision_training: true
77 | enable_xformers_memory_efficient_attention: true
78 |
79 | is_debug: false
80 |
81 | # precalculated statistics
82 | statistic: [[0., 0.],
83 | [0.3535855, 24.23687346],
84 | [0.91609545, 30.65091947],
85 | [1.41165152, 34.40093286],
86 | [1.56943881, 36.99639585],
87 | [1.73182842, 39.42044163],
88 | [1.82733002, 40.94703526],
89 | [1.88060527, 42.66233244],
90 | [1.96208071, 43.73070788],
91 | [2.02723091, 44.25965378],
92 | [2.10820894, 45.66120213],
93 | [2.21115041, 46.29561324],
94 | [2.23412351, 47.08810863],
95 | [2.29430165, 47.9515062],
96 | [2.32986362, 48.69085638],
97 | [2.37310751, 49.19931439]]
98 |
--------------------------------------------------------------------------------
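train.py itself is not reproduced in this dump, so the following is only a sketch, assuming the training script forwards noise_scheduler_kwargs to a diffusers noise scheduler whose constructor accepts exactly these fields (DDIMScheduler is used here purely for illustration):

    from diffusers import DDIMScheduler
    from omegaconf import OmegaConf

    train_cfg = OmegaConf.load("example/config/train.yaml")
    # Unpack the YAML block above into keyword arguments:
    # num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
    # beta_schedule="linear", steps_offset=1, clip_sample=False
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(train_cfg.noise_scheduler_kwargs))

With train_batch_size: 4 and gradient_accumulation_steps: 16, each optimizer step effectively sees 64 samples per process.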
/example/img/anya.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/anya.jpg
--------------------------------------------------------------------------------
/example/img/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/bear.jpg
--------------------------------------------------------------------------------
/example/img/concert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/concert.png
--------------------------------------------------------------------------------
/example/img/genshin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/genshin.jpg
--------------------------------------------------------------------------------
/example/img/harry.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/harry.png
--------------------------------------------------------------------------------
/example/img/labrador.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/labrador.png
--------------------------------------------------------------------------------
/example/img/lighthouse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/lighthouse.jpg
--------------------------------------------------------------------------------
/example/img/majic_girl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/majic_girl.jpg
--------------------------------------------------------------------------------
/example/openxlab/1-realistic.yaml:
--------------------------------------------------------------------------------
1 | dreambooth: 'realisticVisionV51_v51VAE.safetensors'
2 |
3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
4 |
5 | guidance_scale: 7
6 |
--------------------------------------------------------------------------------
/example/openxlab/3-3d.yaml:
--------------------------------------------------------------------------------
1 | dreambooth: 'rcnzCartoon3d_v10.safetensors'
2 |
3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
4 | prefix: ''
5 |
6 | guidance_scale: 7
7 |
8 | ip_adapter_scale: 0.0
9 |
--------------------------------------------------------------------------------
/example/replicate/1-realistic.yaml:
--------------------------------------------------------------------------------
1 | dreambooth: 'realisticVisionV51_v51VAE.safetensors'
2 |
3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
4 |
5 | guidance_scale: 7
6 |
--------------------------------------------------------------------------------
/example/replicate/3-3d.yaml:
--------------------------------------------------------------------------------
1 | dreambooth: 'rcnzCartoon3d_v10.safetensors'
2 |
3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
4 | prefix: ''
5 |
6 | guidance_scale: 7
7 |
8 | ip_adapter_scale: 0.0
9 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2 | import argparse
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | from omegaconf import OmegaConf
8 |
9 | from animatediff.pipelines import I2VPipeline
10 | from animatediff.utils.util import preprocess_img, save_videos_grid
11 |
12 |
13 | def seed_everything(seed):
14 | import random
15 |
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 | np.random.seed(seed % (2**32))
19 | random.seed(seed)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | functional_group = parser.add_mutually_exclusive_group()
25 | parser.add_argument("--config", type=str, default="configs/test.yaml")
26 | parser.add_argument(
27 | "--magnitude", type=int, default=None, choices=[0, 1, 2, -1, -2, -3]
28 |     )  # negative values are for style transfer
29 | functional_group.add_argument("--loop", action="store_true")
30 | functional_group.add_argument("--style_transfer", action="store_true")
31 | args = parser.parse_args()
32 |
33 | config = OmegaConf.load(args.config)
34 | base_config = OmegaConf.load(config.base)
35 | config = OmegaConf.merge(base_config, config)
36 |
37 | if args.magnitude is not None:
38 | config.validation_data.mask_sim_range = [args.magnitude]
39 |
40 | if args.style_transfer:
41 | config.validation_data.mask_sim_range = [
42 | -1 * magnitude - 1 if magnitude >= 0 else magnitude for magnitude in config.validation_data.mask_sim_range
43 | ]
44 | elif args.loop:
45 | config.validation_data.mask_sim_range = [
46 | magnitude + 3 if magnitude >= 0 else magnitude for magnitude in config.validation_data.mask_sim_range
47 | ]
48 |
49 | os.makedirs(config.validation_data.save_path, exist_ok=True)
50 | folder_num = len(os.listdir(config.validation_data.save_path))
51 | target_dir = f"{config.validation_data.save_path}/{folder_num}/"
52 |
53 | # prepare paths and pipeline
54 | base_model_path = config.pretrained_model_path
55 | unet_path = config.generate.model_path
56 | dreambooth_path = config.generate.db_path
57 | if config.generate.use_lora:
58 | lora_path = config.generate.get("lora_path", None)
59 | lora_alpha = config.generate.get("lora_alpha", 0)
60 | else:
61 | lora_path = None
62 | lora_alpha = 0
63 | validation_pipeline = I2VPipeline.build_pipeline(
64 | config,
65 | base_model_path,
66 | unet_path,
67 | dreambooth_path,
68 | lora_path,
69 | lora_alpha,
70 | )
71 | generator = torch.Generator(device="cuda")
72 | generator.manual_seed(config.generate.global_seed)
73 |
74 | global_inf_num = 0
75 |
76 | # if not os.path.exists(target_dir):
77 | os.makedirs(target_dir, exist_ok=True)
78 |
79 | # print(" >>> Begin test >>>")
80 | print(f"using unet : {unet_path}")
81 | print(f"using DreamBooth: {dreambooth_path}")
82 | print(f"using Lora : {lora_path}")
83 |
84 | sim_ranges = config.validation_data.mask_sim_range
85 | if isinstance(sim_ranges, int):
86 | sim_ranges = [sim_ranges]
87 |
88 | OmegaConf.save(config, os.path.join(target_dir, "config.yaml"))
89 | generator.manual_seed(config.generate.global_seed)
90 | seed_everything(config.generate.global_seed)
91 |
92 | # load image
93 | img_root = config.validation_data.validation_input_path
94 | input_name = config.validation_data.input_name
95 | if os.path.exists(os.path.join(img_root, f"{input_name}.jpg")):
96 | image_name = os.path.join(img_root, f"{input_name}.jpg")
97 | elif os.path.exists(os.path.join(img_root, f"{input_name}.png")):
98 | image_name = os.path.join(img_root, f"{input_name}.png")
99 | else:
100 |         raise ValueError(f"Expected {input_name}.jpg or {input_name}.png under {img_root}")
101 | # image = np.array(Image.open(image_name))
102 | image, gen_height, gen_width = preprocess_img(image_name)
103 | config.generate.sample_height = gen_height
104 | config.generate.sample_width = gen_width
105 |
106 | for sim_range in sim_ranges:
107 | print(f"using sim_range : {sim_range}")
108 | config.validation_data.mask_sim_range = sim_range
109 | prompt_num = 0
110 | for prompt, n_prompt in zip(config.prompts, config.n_prompt):
111 | print(f"using n_prompt : {n_prompt}")
112 | prompt_num += 1
113 | for single_prompt in prompt:
114 | print(f" >>> Begin test {global_inf_num} >>>")
115 | global_inf_num += 1
116 | image_path = ""
117 | sample = validation_pipeline(
118 | image=image,
119 | prompt=single_prompt,
120 | generator=generator,
121 | # global_inf_num = global_inf_num,
122 | video_length=config.generate.video_length,
123 | height=config.generate.sample_height,
124 | width=config.generate.sample_width,
125 | negative_prompt=n_prompt,
126 | mask_sim_template_idx=config.validation_data.mask_sim_range,
127 | **config.validation_data,
128 | ).videos
129 | save_videos_grid(sample, target_dir + f"{global_inf_num}_sim_{sim_range}.gif")
130 | print(f" <<< test {global_inf_num} Done <<<")
131 | print(" <<< Test Done <<<")
132 |
--------------------------------------------------------------------------------
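In inference.py above, --magnitude, --loop and --style_transfer all act on validation_data.mask_sim_range before the pipeline is built. The remapping is copied out below as a standalone sketch (the helper name is ours, introduced only for illustration):

    # Same arithmetic as inference.py, isolated for illustration.
    def remap_sim_range(sim_range, style_transfer=False, loop=False):
        if style_transfer:
            # non-negative magnitudes 0, 1, 2 become -1, -2, -3 (style-transfer templates)
            return [-m - 1 if m >= 0 else m for m in sim_range]
        if loop:
            # non-negative magnitudes 0, 1, 2 become 3, 4, 5 (looping templates)
            return [m + 3 if m >= 0 else m for m in sim_range]
        return sim_range

    print(remap_sim_range([0, 1, 2], style_transfer=True))  # [-1, -2, -3]
    print(remap_sim_range([0, 1, 2], loop=True))            # [3, 4, 5]

So, for example, running the labrador config with --loop turns its mask_sim_range of [0, 1, 2] into [3, 4, 5] before each value is passed to the pipeline as mask_sim_template_idx.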
/models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt
--------------------------------------------------------------------------------
/models/IP_Adapter/Put IP-Adapter checkpoints here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/IP_Adapter/Put IP-Adapter checkpoints here.txt
--------------------------------------------------------------------------------
/models/Motion_Module/Put motion module checkpoints here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/Motion_Module/Put motion module checkpoints here.txt
--------------------------------------------------------------------------------
/models/VAE/Put VAE checkpoints here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/VAE/Put VAE checkpoints here.txt
--------------------------------------------------------------------------------
/pia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/pia.png
--------------------------------------------------------------------------------
/pia.yml:
--------------------------------------------------------------------------------
1 | name: pia
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - bzip2=1.0.8=h5eee18b_6
10 | - ca-certificates=2024.3.11=h06a4308_0
11 | - certifi=2024.6.2=py310h06a4308_0
12 | - cuda-cudart=11.8.89=0
13 | - cuda-cupti=11.8.87=0
14 | - cuda-libraries=11.8.0=0
15 | - cuda-nvrtc=11.8.89=0
16 | - cuda-nvtx=11.8.86=0
17 | - cuda-runtime=11.8.0=0
18 | - cuda-version=12.4=hbda6634_3
19 | - ld_impl_linux-64=2.38=h1181459_1
20 | - libcublas=11.11.3.6=0
21 | - libcufft=10.9.0.58=0
22 | - libcufile=1.9.1.3=h99ab3db_1
23 | - libcurand=10.3.5.147=h99ab3db_1
24 | - libcusolver=11.4.1.48=0
25 | - libcusparse=11.7.5.86=0
26 | - libffi=3.4.4=h6a678d5_1
27 | - libgcc-ng=11.2.0=h1234567_1
28 | - libgomp=11.2.0=h1234567_1
29 | - libnpp=11.8.0.86=0
30 | - libnvjpeg=11.9.0.86=0
31 | - libstdcxx-ng=11.2.0=h1234567_1
32 | - libuuid=1.41.5=h5eee18b_0
33 | - ncurses=6.4=h6a678d5_0
34 | - openssl=3.0.13=h7f8727e_2
35 | - pip=24.0=py310h06a4308_0
36 | - python=3.10.14=h955ad1f_1
37 | - pytorch-cuda=11.8=h7e8668a_5
38 | - readline=8.2=h5eee18b_0
39 | - setuptools=69.5.1=py310h06a4308_0
40 | - sqlite=3.45.3=h5eee18b_0
41 | - tk=8.6.14=h39e8969_0
42 | - wheel=0.43.0=py310h06a4308_0
43 | - xz=5.4.6=h5eee18b_1
44 | - zlib=1.2.13=h5eee18b_1
45 | - pip:
46 | - accelerate==0.31.0
47 | - aiofiles==23.2.1
48 | - altair==5.3.0
49 | - annotated-types==0.7.0
50 | - antlr4-python3-runtime==4.9.3
51 | - anyio==4.4.0
52 | - attrs==23.2.0
53 | - beautifulsoup4==4.12.3
54 | - blessed==1.20.0
55 | - boto3==1.34.125
56 | - botocore==1.34.125
57 | - cfgv==3.4.0
58 | - charset-normalizer==3.3.2
59 | - click==8.1.7
60 | - coloredlogs==15.0.1
61 | - contourpy==1.2.1
62 | - cycler==0.12.1
63 | - decorator==4.4.2
64 | - decord==0.6.0
65 | - diffusers==0.24.0
66 | - distlib==0.3.8
67 | - dnspython==2.6.1
68 | - docker-pycreds==0.4.0
69 | - einops==0.8.0
70 | - email-validator==2.1.1
71 | - environs==11.0.0
72 | - exceptiongroup==1.2.1
73 | - fastapi==0.111.0
74 | - fastapi-cli==0.0.4
75 | - ffmpy==0.3.2
76 | - filelock==3.14.0
77 | - fonttools==4.53.0
78 | - fsspec==2024.6.0
79 | - gdown==5.2.0
80 | - git-lfs==1.6
81 | - gitdb==4.0.11
82 | - gitpython==3.1.43
83 | - gpustat==1.1.1
84 | - gradio==4.36.0
85 | - gradio-client==1.0.1
86 | - h11==0.14.0
87 | - httpcore==1.0.5
88 | - httptools==0.6.1
89 | - httpx==0.27.0
90 | - huggingface-hub==0.23.3
91 | - humanfriendly==10.0
92 | - humanize==4.9.0
93 | - identify==2.5.36
94 | - idna==3.7
95 | - imageio==2.33.1
96 | - imageio-ffmpeg==0.5.1
97 | - importlib-metadata==7.1.0
98 | - importlib-resources==6.4.0
99 | - jinja2==3.1.3
100 | - jmespath==1.0.1
101 | - jsonschema==4.22.0
102 | - jsonschema-specifications==2023.12.1
103 | - kiwisolver==1.4.5
104 | - markdown-it-py==3.0.0
105 | - markupsafe==2.1.5
106 | - marshmallow==3.21.3
107 | - matplotlib==3.9.0
108 | - mdurl==0.1.2
109 | - moviepy==1.0.3
110 | - mpmath==1.3.0
111 | - multiprocessing-logging==0.3.4
112 | - networkx==3.2.1
113 | - nodeenv==1.9.1
114 | - numpy==1.26.4
115 | - nvidia-cublas-cu11==11.11.3.6
116 | - nvidia-cublas-cu12==12.1.3.1
117 | - nvidia-cuda-cupti-cu11==11.8.87
118 | - nvidia-cuda-cupti-cu12==12.1.105
119 | - nvidia-cuda-nvrtc-cu11==11.8.89
120 | - nvidia-cuda-nvrtc-cu12==12.1.105
121 | - nvidia-cuda-runtime-cu11==11.8.89
122 | - nvidia-cuda-runtime-cu12==12.1.105
123 | - nvidia-cudnn-cu11==8.7.0.84
124 | - nvidia-cudnn-cu12==8.9.2.26
125 | - nvidia-cufft-cu11==10.9.0.58
126 | - nvidia-cufft-cu12==11.0.2.54
127 | - nvidia-curand-cu11==10.3.0.86
128 | - nvidia-curand-cu12==10.3.2.106
129 | - nvidia-cusolver-cu11==11.4.1.48
130 | - nvidia-cusolver-cu12==11.4.5.107
131 | - nvidia-cusparse-cu11==11.7.5.86
132 | - nvidia-cusparse-cu12==12.1.0.106
133 | - nvidia-ml-py==12.555.43
134 | - nvidia-nccl-cu11==2.20.5
135 | - nvidia-nccl-cu12==2.20.5
136 | - nvidia-nvjitlink-cu12==12.5.40
137 | - nvidia-nvtx-cu11==11.8.86
138 | - nvidia-nvtx-cu12==12.1.105
139 | - omegaconf==2.3.0
140 | - opencv-python==4.10.0.82
141 | - orjson==3.10.3
142 | - packaging==24.0
143 | - pandas==2.2.2
144 | - pillow==10.2.0
145 | - platformdirs==4.2.2
146 | - pre-commit==3.7.1
147 | - proglog==0.1.10
148 | - protobuf==5.27.1
149 | - psutil==5.9.8
150 | - pydantic==2.7.3
151 | - pydantic-core==2.18.4
152 | - pydub==0.25.1
153 | - pygments==2.18.0
154 | - pyparsing==3.1.2
155 | - python-dateutil==2.9.0.post0
156 | - python-dotenv==1.0.1
157 | - python-multipart==0.0.9
158 | - pytz==2024.1
159 | - pyyaml==6.0.1
160 | - referencing==0.35.1
161 | - regex==2024.5.15
162 | - requests==2.32.3
163 | - rich==13.7.1
164 | - rpds-py==0.18.1
165 | - ruff==0.4.8
166 | - s3transfer==0.10.1
167 | - safetensors==0.4.3
168 | - semantic-version==2.10.0
169 | - sentry-sdk==2.5.0
170 | - setproctitle==1.3.3
171 | - shellingham==1.5.4
172 | - six==1.16.0
173 | - smmap==5.0.1
174 | - sniffio==1.3.1
175 | - soupsieve==2.5
176 | - starlette==0.37.2
177 | - sympy==1.12
178 | - tokenizers==0.19.1
179 | - tomlkit==0.12.0
180 | - toolz==0.12.1
181 | - torch==2.3.1+cu118
182 | - torchaudio==2.3.1+cu118
183 | - torchvision==0.18.1+cu118
184 | - tqdm==4.66.4
185 | - transformers==4.41.2
186 | - triton==2.3.1
187 | - typer==0.12.3
188 | - typing-extensions==4.9.0
189 | - tzdata==2024.1
190 | - ujson==5.10.0
191 | - urllib3==2.2.1
192 | - uvicorn==0.30.1
193 | - uvloop==0.19.0
194 | - virtualenv==20.26.2
195 | - wandb==0.17.1
196 | - watchfiles==0.22.0
197 | - wcwidth==0.2.13
198 | - websockets==11.0.3
199 | - xformers==0.0.26.post1
200 | - zipp==3.19.2
201 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 |
4 | import os.path as osp
5 |
6 | import numpy as np
7 | import torch
8 | from cog import BasePredictor, Input, Path
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 |
12 | from animatediff.pipelines import I2VPipeline
13 | from animatediff.utils.util import save_videos_grid
14 |
15 |
16 | N_PROMPT = (
17 | "wrong white balance, dark, sketches,worst quality,low quality, "
18 | "deformed, distorted, disfigured, bad eyes, wrong lips, "
19 | "weird mouth, bad teeth, mutated hands and fingers, bad anatomy,"
20 | "wrong anatomy, amputation, extra limb, missing limb, "
21 | "floating,limbs, disconnected limbs, mutation, ugly, disgusting, "
22 | "bad_pictures, negative_hand-neg"
23 | )
24 |
25 |
26 | BASE_CONFIG = "example/config/base.yaml"
27 | STYLE_CONFIG_LIST = {
28 | "realistic": "example/replicate/1-realistic.yaml",
29 | "3d_cartoon": "example/replicate/3-3d.yaml",
30 | }
31 |
32 |
33 | PIA_PATH = "models/PIA"
34 | VAE_PATH = "models/VAE"
35 | DreamBooth_LoRA_PATH = "models/DreamBooth_LoRA"
36 | STABLE_DIFFUSION_PATH = "models/StableDiffusion"
37 |
38 |
39 | class Predictor(BasePredictor):
40 | def setup(self) -> None:
41 | """Load the model into memory to make running multiple predictions efficient"""
42 |
43 | self.ip_adapter_dir = "models/IP_Adapter/h94/IP-Adapter/models" # cached h94/IP-Adapter
44 |
45 | self.inference_config = OmegaConf.load("example/config/base.yaml")
46 | self.stable_diffusion_dir = self.inference_config.pretrained_model_path
47 | self.pia_path = self.inference_config.generate.model_path
48 | self.style_configs = {k: OmegaConf.load(v) for k, v in STYLE_CONFIG_LIST.items()}
49 | self.pipeline_dict = self.load_model_list()
50 |
51 | def load_model_list(self):
52 | pipeline_dict = {}
53 | for style, cfg in self.style_configs.items():
54 | print(f"Loading {style}")
55 | dreambooth_path = cfg.get("dreambooth", "none")
56 | if dreambooth_path and dreambooth_path.upper() != "NONE":
57 | dreambooth_path = osp.join(DreamBooth_LoRA_PATH, dreambooth_path)
58 | lora_path = cfg.get("lora", None)
59 | if lora_path is not None:
60 | lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
61 | lora_alpha = cfg.get("lora_alpha", 0.0)
62 | vae_path = cfg.get("vae", None)
63 | if vae_path is not None:
64 | vae_path = osp.join(VAE_PATH, vae_path)
65 |
66 | pipeline_dict[style] = I2VPipeline.build_pipeline(
67 | self.inference_config,
68 | STABLE_DIFFUSION_PATH,
69 | unet_path=osp.join(PIA_PATH, "pia.ckpt"),
70 | dreambooth_path=dreambooth_path,
71 | lora_path=lora_path,
72 | lora_alpha=lora_alpha,
73 | vae_path=vae_path,
74 | ip_adapter_path=self.ip_adapter_dir,
75 | ip_adapter_scale=0.1,
76 | )
77 | return pipeline_dict
78 |
79 | def predict(
80 | self,
81 | prompt: str = Input(description="Input prompt."),
82 | image: Path = Input(description="Input image"),
83 |         negative_prompt: str = Input(description="Things you do not want to see in the output.", default=N_PROMPT),
84 | style: str = Input(
85 | description="Choose a style",
86 | choices=["3d_cartoon", "realistic"],
87 | default="3d_cartoon",
88 | ),
89 | max_size: int = Input(
90 |             description="Max size (the long edge of the input image will be resized to this value; "
91 |             "a larger value means slower inference)",
92 | default=512,
93 | choices=[512, 576, 640, 704, 768, 832, 896, 960, 1024],
94 | ),
95 | motion_scale: int = Input(
96 |             description="A larger value means larger motion but less identity consistency.",
97 | ge=1,
98 | le=3,
99 | default=1,
100 | ),
101 | sampling_steps: int = Input(description="Number of denoising steps", ge=10, le=100, default=25),
102 |         animation_length: int = Input(description="Length of the output animation, in frames", ge=8, le=24, default=16),
103 | guidance_scale: float = Input(
104 | description="Scale for classifier-free guidance",
105 | ge=1.0,
106 | le=20.0,
107 | default=7.5,
108 | ),
109 | ip_adapter_scale: float = Input(
110 |             description="Scale for IP-Adapter image conditioning",
111 | ge=0.0,
112 | le=1.0,
113 | default=0.0,
114 | ),
115 | seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None),
116 | ) -> Path:
117 | """Run a single prediction on the model"""
118 | if seed is None:
119 | torch.seed()
120 | seed = torch.initial_seed()
121 | else:
122 | torch.manual_seed(seed)
123 | print(f"Using seed: {seed}")
124 |
125 | pipeline = self.pipeline_dict[style]
126 |
127 | init_img, h, w = preprocess_img(str(image), max_size)
128 |
129 | sample = pipeline(
130 | image=init_img,
131 | prompt=prompt,
132 | negative_prompt=negative_prompt,
133 | num_inference_steps=sampling_steps,
134 | guidance_scale=guidance_scale,
135 | width=w,
136 | height=h,
137 | video_length=animation_length,
138 | mask_sim_template_idx=motion_scale,
139 | ip_adapter_scale=ip_adapter_scale,
140 | ).videos
141 |
142 | out_path = "/tmp/out.mp4"
143 | save_videos_grid(sample, out_path)
144 | return Path(out_path)
145 |
146 |
147 | def preprocess_img(img_path, max_size: int = 512):
148 |     ori_image = Image.open(img_path).convert("RGB")
149 |
150 | width, height = ori_image.size
151 |
152 | long_edge = max(width, height)
153 | if long_edge > max_size:
154 | scale_factor = max_size / long_edge
155 | else:
156 | scale_factor = 1
157 | width = int(width * scale_factor)
158 | height = int(height * scale_factor)
159 | ori_image = ori_image.resize((width, height))
160 |
161 | if (width % 8 != 0) or (height % 8 != 0):
162 | in_width = (width // 8) * 8
163 | in_height = (height // 8) * 8
164 | else:
165 | in_width = width
166 | in_height = height
167 |
168 | in_image = ori_image.resize((in_width, in_height))
169 | in_image_np = np.array(in_image)
170 | return in_image_np, in_height, in_width
171 |
--------------------------------------------------------------------------------
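For reference, preprocess_img in predict.py first caps the long edge of the input at max_size and then rounds both sides down to multiples of 8 (Stable Diffusion's 8x-downsampling VAE expects pixel dimensions divisible by 8). A quick worked example of that arithmetic, with illustrative numbers not tied to any image under example/img:

    # 1024x600 input, max_size=512: long edge capped first, then rounded down to multiples of 8.
    width, height, max_size = 1024, 600, 512

    long_edge = max(width, height)
    scale = max_size / long_edge if long_edge > max_size else 1
    width, height = int(width * scale), int(height * scale)       # 512 x 300 after scaling
    in_width, in_height = (width // 8) * 8, (height // 8) * 8     # 512 x 296 after rounding
    print(in_width, in_height)  # 512 296

predict.py then passes the rounded height and width straight to the pipeline, so the generated frames match the preprocessed image exactly.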
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | # Never enforce `E501` (line length violations).
3 | ignore = ["C901", "E501", "E741", "F402", "F823"]
4 | select = ["C", "E", "F", "I", "W"]
5 | line-length = 119
6 |
7 | # Ignore import violations in all `__init__.py` files.
8 | [tool.ruff.per-file-ignores]
9 | "__init__.py" = ["E402", "F401", "F403", "F811"]
10 | "src/diffusers/utils/dummy_*.py" = ["F401"]
11 |
12 | [tool.ruff.isort]
13 | lines-after-imports = 2
14 | known-first-party = ["diffusers"]
15 |
16 | [tool.ruff.format]
17 | # Like Black, use double quotes for strings.
18 | quote-style = "double"
19 |
20 | # Like Black, indent with spaces, rather than tabs.
21 | indent-style = "space"
22 |
23 | # Like Black, respect magic trailing commas.
24 | skip-magic-trailing-comma = false
25 |
26 | # Like Black, automatically detect the appropriate line ending.
27 | line-ending = "auto"
28 |
--------------------------------------------------------------------------------