├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── __assets__ └── image_animation │ ├── loop │ ├── labrador │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ └── lighthouse │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ ├── magnitude │ ├── bear │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ ├── genshin │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ └── labrador │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ ├── majic │ ├── 1.gif │ ├── 2.gif │ └── 3.gif │ ├── rcnz │ ├── 1.gif │ ├── 2.gif │ └── 3.gif │ ├── real │ ├── 1.gif │ ├── 2.gif │ └── 3.gif │ ├── style_transfer │ ├── anya │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ ├── bear │ │ ├── 1.gif │ │ ├── 2.gif │ │ └── 3.gif │ └── concert │ │ ├── 1.gif │ │ ├── 2.gif │ │ ├── 3.gif │ │ ├── 4.gif │ │ ├── 5.gif │ │ └── 6.gif │ └── teaser │ └── teaser.gif ├── animatediff ├── data │ ├── dataset.py │ └── video_transformer.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ ├── __init__.py │ ├── i2v_pipeline.py │ ├── pipeline_animation.py │ └── validation_pipeline.py └── utils │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── app.py ├── cog.yaml ├── download_bashscripts ├── 1-RealisticVision.sh ├── 2-RcnzCartoon.sh └── 3-MajicMix.sh ├── environment-pt2.yaml ├── environment.yaml ├── example ├── config │ ├── anya.yaml │ ├── base.yaml │ ├── bear.yaml │ ├── concert.yaml │ ├── genshin.yaml │ ├── harry.yaml │ ├── labrador.yaml │ ├── lighthouse.yaml │ ├── majic_girl.yaml │ └── train.yaml ├── img │ ├── anya.jpg │ ├── bear.jpg │ ├── concert.png │ ├── genshin.jpg │ ├── harry.png │ ├── labrador.png │ ├── lighthouse.jpg │ └── majic_girl.jpg ├── openxlab │ ├── 1-realistic.yaml │ └── 3-3d.yaml └── replicate │ ├── 1-realistic.yaml │ └── 3-3d.yaml ├── inference.py ├── models ├── DreamBooth_LoRA │ └── Put personalized T2I checkpoints here.txt ├── IP_Adapter │ └── Put IP-Adapter checkpoints here.txt ├── Motion_Module │ └── Put motion module checkpoints here.txt └── VAE │ └── Put VAE checkpoints here.txt ├── pia.png ├── pia.yml ├── predict.py ├── pyproject.toml └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.pt 3 | *.mov 4 | *.pth 5 | *.json 6 | *.mov 7 | *.npz 8 | *.npy 9 | *.boj 10 | *.onnx 11 | *.tar 12 | *.bin 13 | cache* 14 | batch* 15 | *.jpg 16 | *.png 17 | *.mp4 18 | *.gif 19 | *.ckpt 20 | *.safetensors 21 | *.zip 22 | *.csv 23 | *.log 24 | 25 | **/__pycache__/ 26 | samples/ 27 | wandb/ 28 | outputs/ 29 | example/result 30 | models/StableDiffusion 31 | models/PIA 32 | 33 | !pia.png 34 | .DS_Store 35 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.3.5 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [ --fix ] 9 | # Run the formatter. 
10 | - id: ruff-format 11 | - repo: https://github.com/codespell-project/codespell 12 | rev: v2.2.1 13 | hooks: 14 | - id: codespell 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v4.3.0 17 | hooks: 18 | - id: trailing-whitespace 19 | - id: check-yaml 20 | - id: end-of-file-fixer 21 | - id: requirements-txt-fixer 22 | - id: fix-encoding-pragma 23 | args: ["--remove"] 24 | - id: mixed-line-ending 25 | args: ["--fix=lf"] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CVPR 2024 | PIA:Personalized Image Animator 2 | 3 | [**PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models**](https://arxiv.org/abs/2312.13964) 4 | 5 | [Yiming Zhang*](https://github.com/ymzhang0319), [Zhening Xing*](https://github.com/LeoXing1996/), [Yanhong Zeng†](https://zengyh1900.github.io/), [Youqing Fang](https://github.com/FangYouqing), [Kai Chen†](https://chenkai.site/) 6 | 7 | (*equal contribution, †corresponding Author) 8 | 9 | 10 | [![arXiv](https://img.shields.io/badge/arXiv-2312.13964-b31b1b.svg)](https://arxiv.org/abs/2312.13964) 11 | [![Project Page](https://img.shields.io/badge/PIA-Website-green)](https://pi-animator.github.io) 12 | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia) 13 | [![Third Party Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/PIA-colab/blob/main/PIA_colab.ipynb) 14 | [![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/Leoxing/PIA) 15 | 16 | Open in HugginFace 17 | 18 | [![Replicate](https://replicate.com/cjwbw/pia/badge)](https://replicate.com/cjwbw/pia) 19 | 20 | 21 | PIA is a personalized image animation method which can generate videos with **high motion controllability** and **strong text and image alignment**. 22 | 23 | If you find our project helpful, please give it a star :star: or [cite](#bibtex) it, we would be very grateful :sparkling_heart: . 24 | 25 | 26 | 27 | 28 | ## What's New 29 | - [x] `2024/01/03` [Replicate Demo & API](https://replicate.com/cjwbw/pia) support! 30 | - [x] `2024/01/03` [Colab](https://github.com/camenduru/PIA-colab) support from [camenduru](https://github.com/camenduru)! 31 | - [x] `2023/12/28` Support `scaled_dot_product_attention` for 1024x1024 images with just 16GB of GPU memory. 32 | - [x] `2023/12/25` HuggingFace demo is available now! [🤗 Hub](https://huggingface.co/spaces/Leoxing/PIA/) 33 | - [x] `2023/12/22` Release the demo of PIA on [OpenXLab](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia) and checkpoints on [Google Drive](https://drive.google.com/file/d/1RL3Fp0Q6pMD8PbGPULYUnvjqyRQXGHwN/view?usp=drive_link) or [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/zhangyiming/PIA) 34 | 35 | ## Setup 36 | ### Prepare Environment 37 | 38 | Use the following command to install a conda environment for PIA from scratch: 39 | 40 | ``` 41 | conda env create -f pia.yml 42 | conda activate pia 43 | ``` 44 | You may also want to install it based on an existing environment, then you can use `environment-pt2.yaml` for Pytorch==2.0.0. If you want to use lower version of Pytorch (e.g. 1.13.1), you can use the following command: 45 | 46 | ``` 47 | conda env create -f environment.yaml 48 | conda activate pia 49 | ``` 50 | 51 | We strongly recommend you to use Pytorch==2.0.0 which supports `scaled_dot_product_attention` for memory-efficient image animation. 52 | 53 | ### Download checkpoints 54 |
  • Download the Stable Diffusion v1-5
  55 | 56 | ``` 57 | conda install git-lfs 58 | git lfs install 59 | git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/ 60 | ``` 61 | 62 |
  • Download PIA
  63 | 64 | ``` 65 | git clone https://huggingface.co/Leoxing/PIA models/PIA/ 66 | ``` 67 | 68 |
  • Download Personalized Models
  69 | 70 | ``` 71 | bash download_bashscripts/1-RealisticVision.sh 72 | bash download_bashscripts/2-RcnzCartoon.sh 73 | bash download_bashscripts/3-MajicMix.sh 74 | ``` 75 | 76 | 77 | You can also download *pia.ckpt* manually via the links on [Google Drive](https://drive.google.com/file/d/1RL3Fp0Q6pMD8PbGPULYUnvjqyRQXGHwN/view?usp=drive_link) 78 | or [HuggingFace](https://huggingface.co/Leoxing/PIA). 79 | 80 | Put the checkpoints as follows: 81 | ``` 82 | └── models 83 | ├── DreamBooth_LoRA 84 | │ ├── ... 85 | ├── PIA 86 | │ ├── pia.ckpt 87 | └── StableDiffusion 88 | ├── vae 89 | ├── unet 90 | └── ... 91 | ``` 92 | 93 | ## Inference 94 | ### Image Animation 95 | Image-to-video results can be obtained by: 96 | ``` 97 | python inference.py --config=example/config/lighthouse.yaml 98 | python inference.py --config=example/config/harry.yaml 99 | python inference.py --config=example/config/majic_girl.yaml 100 | ``` 101 | Run the commands above, and you can find the results in `example/result`: 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 |

    [Result GIF grid — captions:]
    - Input Image | lightning, lighthouse | sun rising, lighthouse | fireworks, lighthouse
    - Input Image | 1boy smiling | 1boy playing the magic fire | 1boy is waving hands
    - Input Image | 1girl is smiling | 1girl is crying | 1girl, snowing
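If you want to render every bundled example in one go instead of one config at a time, the same command can be driven from a small shell loop (a sketch, assuming a bash-like shell and that `base.yaml` and `train.yaml` in `example/config/` are shared/training configs rather than standalone inference examples):

```sh
# Run inference for every bundled example config.
for cfg in example/config/*.yaml; do
    case "$(basename "${cfg}")" in
        base.yaml|train.yaml) continue ;;  # assumed not to be standalone inference configs
    esac
    python inference.py --config="${cfg}"
done
```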

    141 | 142 | 151 | 152 | ### Motion Magnitude 153 | You can control the motion magnitude through the parameter **magnitude**: 154 | ```sh 155 | python inference.py --config=example/config/xxx.yaml --magnitude=0 # Small Motion 156 | python inference.py --config=example/config/xxx.yaml --magnitude=1 # Moderate Motion 157 | python inference.py --config=example/config/xxx.yaml --magnitude=2 # Large Motion 158 | ``` 159 | Examples: 160 | 161 | ```sh 162 | python inference.py --config=example/config/labrador.yaml 163 | python inference.py --config=example/config/bear.yaml 164 | python inference.py --config=example/config/genshin.yaml 165 | ``` 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 |

    [Result GIF grid — column headers: Input Image & Prompt | Small Motion | Moderate Motion | Large Motion; row prompts:]
    - a golden labrador is running
    - 1bear is walking, ...
    - cherry blossom, ...
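To compare all three magnitude levels on the same input, you can simply sweep the `--magnitude` flag shown above (a minimal sketch, assuming a bash-like shell and the bundled `labrador.yaml` example):

```sh
# 0 = small, 1 = moderate, 2 = large motion, as listed above.
for magnitude in 0 1 2; do
    python inference.py --config=example/config/labrador.yaml --magnitude=${magnitude}
done
```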
    193 | 194 | ### Style Transfer 195 | To achieve style transfer, you can run the following command (*please don't forget to set the base model in xxx.yaml*): 196 | 197 | Examples: 198 | 199 | ```sh 200 | python inference.py --config example/config/concert.yaml --style_transfer 201 | python inference.py --config example/config/anya.yaml --style_transfer 202 | ``` 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 |

    [Result GIF grid — captions:]
    - Input Image & Base Model | 1man is smiling | 1man is crying | 1man is singing
    - Realistic Vision
    - RCNZ Cartoon 3d
    - 1girl smiling | 1girl open mouth | 1girl is crying, pout
    - RCNZ Cartoon 3d
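Because style transfer relies on the personalized base model named in the config, make sure that checkpoint has been downloaded before running inference. A sketch is below; it assumes the `anya` example uses the RCNZ Cartoon checkpoint fetched by `download_bashscripts/2-RcnzCartoon.sh` (which matches the table above, but please verify against `example/config/anya.yaml`):

```sh
# Fetch the personalized T2I checkpoint referenced by the config, then run style transfer.
bash download_bashscripts/2-RcnzCartoon.sh   # assumption: anya.yaml points at this base model
python inference.py --config example/config/anya.yaml --style_transfer
```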
    235 | 236 | ### Loop Video 237 | 238 | You can generate a looping video by using the parameter `--loop`: 239 | 240 | ```sh 241 | python inference.py --config=example/config/xxx.yaml --loop 242 | ``` 243 | 244 | Examples: 245 | ```sh 246 | python inference.py --config=example/config/lighthouse.yaml --loop 247 | python inference.py --config=example/config/labrador.yaml --loop 248 | ``` 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 |

    [Result GIF grid — captions:]
    - Input Image | lightning, lighthouse | sun rising, lighthouse | fireworks, lighthouse
    - Input Image | labrador jumping | labrador walking | labrador running
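`--loop` is an ordinary command-line option, so it should combine with the other flags documented above, for example `--magnitude` (a sketch under that assumption — check `inference.py` if the combination is not supported):

```sh
# Looping video with large motion (assumes --loop and --magnitude can be combined).
python inference.py --config=example/config/lighthouse.yaml --loop --magnitude=2
```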

    276 | 277 | 278 | ## Training 279 | 280 | We provide [training script]("train.py") for PIA. It borrows from [AnimateDiff](https://github.com/guoyww/AnimateDiff/tree/main) heavily, so please prepare the dataset and configuration files according to the [guideline](https://github.com/guoyww/AnimateDiff/blob/main/__assets__/docs/animatediff.md#steps-for-training). 281 | 282 | After preparation, you can train the model by running the following command using torchrun: 283 | 284 | ```shell 285 | torchrun --nnodes=1 --nproc_per_node=1 train.py --config example/config/train.yaml 286 | ``` 287 | 288 | or by slurm, 289 | ```shell 290 | srun --quotatype=reserved --job-name=pia --gres=gpu:8 --ntasks-per-node=8 --ntasks=8 --cpus-per-task=4 --kill-on-bad-exit=1 python train.py --config example/config/train.yaml 291 | ``` 292 | 293 | 294 | ## AnimateBench 295 | We have open-sourced AnimateBench on [HuggingFace](https://huggingface.co/datasets/ymzhang319/AnimateBench) which includes images, prompts and configs to evaluate PIA and other image animation methods. 296 | 297 | 298 | ## BibTex 299 | ``` 300 | @inproceedings{zhang2024pia, 301 | title={Pia: Your personalized image animator via plug-and-play modules in text-to-image models}, 302 | author={Zhang, Yiming and Xing, Zhening and Zeng, Yanhong and Fang, Youqing and Chen, Kai}, 303 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 304 | pages={7747--7756}, 305 | year={2024} 306 | } 307 | ``` 308 | 309 | 310 | 311 | 312 | ## Contact Us 313 | **Yiming Zhang**: zhangyiming@pjlab.org.cn 314 | 315 | **Zhening Xing**: xingzhening@pjlab.org.cn 316 | 317 | **Yanhong Zeng**: zengyh1900@gmail.com 318 | 319 | ## Acknowledgements 320 | The code is built upon [AnimateDiff](https://github.com/guoyww/AnimateDiff), [Tune-a-Video](https://github.com/showlab/Tune-A-Video) and [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) 321 | 322 | You may also want to try other project from our team: 323 | 324 | MMagic 325 | 326 | -------------------------------------------------------------------------------- /__assets__/image_animation/loop/labrador/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/loop/labrador/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/loop/labrador/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/labrador/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/loop/lighthouse/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/1.gif -------------------------------------------------------------------------------- 
/__assets__/image_animation/loop/lighthouse/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/loop/lighthouse/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/loop/lighthouse/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/bear/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/bear/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/bear/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/bear/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/genshin/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/genshin/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/genshin/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/genshin/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/labrador/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/magnitude/labrador/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/2.gif -------------------------------------------------------------------------------- 
/__assets__/image_animation/magnitude/labrador/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/magnitude/labrador/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/majic/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/majic/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/majic/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/majic/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/rcnz/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/rcnz/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/rcnz/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/rcnz/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/real/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/real/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/real/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/real/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/anya/1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/anya/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/anya/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/anya/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/bear/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/bear/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/bear/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/bear/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/1.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/2.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/3.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/4.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/5.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/5.gif -------------------------------------------------------------------------------- /__assets__/image_animation/style_transfer/concert/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/style_transfer/concert/6.gif -------------------------------------------------------------------------------- /__assets__/image_animation/teaser/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/__assets__/image_animation/teaser/teaser.gif -------------------------------------------------------------------------------- /animatediff/data/dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import io 3 | import os 4 | import random 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms as transforms 10 | from decord import VideoReader 11 | from torch.utils.data.dataset import Dataset 12 | 13 | import animatediff.data.video_transformer as video_transforms 14 | from animatediff.utils.util import detect_edges, zero_rank_print 15 | 16 | 17 | try: 18 | from petrel_client.client import Client 19 | except ImportError as e: 20 | print(e) 21 | 22 | 23 | def get_score(video_data, cond_frame_idx, weight=[1.0, 1.0, 1.0, 1.0], use_edge=True): 24 | """ 25 | Similar to get_score under utils/util.py/detect_edges 26 | """ 27 | """ 28 | the shape of video_data is f c h w, np.ndarray 29 | """ 30 | h, w = video_data.shape[1], video_data.shape[2] 31 | 32 | cond_frame = video_data[cond_frame_idx] 33 | cond_hsv_list = list(cv2.split(cv2.cvtColor(cond_frame.astype(np.float32), cv2.COLOR_RGB2HSV))) 34 | 35 | if use_edge: 36 | cond_frame_lum = cond_hsv_list[-1] 37 | cond_frame_edge = detect_edges(cond_frame_lum.astype(np.uint8)) 38 | cond_hsv_list.append(cond_frame_edge) 39 | 40 | score_sum = [] 41 | 42 | for frame_idx in range(video_data.shape[0]): 43 | frame = video_data[frame_idx] 44 | hsv_list = list(cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV))) 45 | 46 | if use_edge: 47 | frame_img_lum = hsv_list[-1] 48 | frame_img_edge = detect_edges(lum=frame_img_lum.astype(np.uint8)) 49 | hsv_list.append(frame_img_edge) 50 | 51 | hsv_diff = [np.abs(hsv_list[c] - cond_hsv_list[c]) for c in range(len(weight))] 52 | hsv_mse = [np.sum(hsv_diff[c]) * weight[c] for c in range(len(weight))] 53 | score_sum.append(sum(hsv_mse) / (h * w) / (sum(weight))) 54 | 55 | return score_sum 56 | 57 | 58 | class WebVid10M(Dataset): 59 | def __init__( 60 | self, 61 | csv_path, 62 | video_folder, 63 | sample_size=256, 64 | sample_stride=4, 65 | sample_n_frames=16, 66 | is_image=False, 67 | use_petreloss=False, 68 | conf_path=None, 69 | ): 70 | if use_petreloss: 71 | self._client = Client(conf_path=conf_path, enable_mc=True) 72 | else: 73 | self._client = None 74 | self.video_folder = video_folder 75 | self.sample_stride = sample_stride 76 | self.sample_n_frames = sample_n_frames 77 | self.is_image = is_image 78 | self.temporal_sampler = video_transforms.TemporalRandomCrop(sample_n_frames * sample_stride) 79 | 80 | zero_rank_print(f"loading 
annotations from {csv_path} ...") 81 | with open(csv_path, "r") as csvfile: 82 | self.dataset = list(csv.DictReader(csvfile)) 83 | self.length = len(self.dataset) 84 | zero_rank_print(f"data scale: {self.length}") 85 | 86 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 87 | self.pixel_transforms = transforms.Compose( 88 | [ 89 | transforms.RandomHorizontalFlip(), 90 | transforms.Resize(sample_size[0]), 91 | transforms.CenterCrop(sample_size), 92 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 93 | ] 94 | ) 95 | 96 | def get_batch(self, idx): 97 | video_dict = self.dataset[idx] 98 | videoid, name, page_dir = video_dict["videoid"], video_dict["name"], video_dict["page_dir"] 99 | 100 | if self._client is not None: 101 | video_dir = os.path.join(self.video_folder, page_dir, f"{videoid}.mp4") 102 | video_bytes = self._client.Get(video_dir) 103 | video_bytes = io.BytesIO(video_bytes) 104 | # ensure not reading zero byte 105 | assert video_bytes.getbuffer().nbytes != 0 106 | video_reader = VideoReader(video_bytes) 107 | else: 108 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 109 | video_reader = VideoReader(video_dir) 110 | 111 | total_frames = len(video_reader) 112 | if not self.is_image: 113 | start_frame_ind, end_frame_ind = self.temporal_sampler(total_frames) 114 | assert end_frame_ind - start_frame_ind >= self.sample_n_frames 115 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int) 116 | else: 117 | frame_indice = [random.randint(0, total_frames - 1)] 118 | 119 | pixel_values_np = video_reader.get_batch(frame_indice).asnumpy() 120 | cond_frames = random.randint(0, self.sample_n_frames - 1) 121 | 122 | # f h w c -> f c h w 123 | pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous() 124 | pixel_values = pixel_values / 255.0 125 | del video_reader 126 | 127 | if self.is_image: 128 | pixel_values = pixel_values[0] 129 | 130 | return pixel_values, name, cond_frames, videoid 131 | 132 | def __len__(self): 133 | return self.length 134 | 135 | def __getitem__(self, idx): 136 | while True: 137 | try: 138 | video, name, cond_frames, videoid = self.get_batch(idx) 139 | break 140 | 141 | except Exception: 142 | zero_rank_print("Error loading video, retrying...") 143 | idx = random.randint(0, self.length - 1) 144 | 145 | video = self.pixel_transforms(video) 146 | video_ = video.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5 147 | video_ = video_ * 255 148 | score = get_score(video_, cond_frame_idx=cond_frames) 149 | del video_ 150 | sample = {"pixel_values": video, "text": name, "score": score, "cond_frames": cond_frames, "vid": videoid} 151 | return sample 152 | -------------------------------------------------------------------------------- /animatediff/data/video_transformer.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | 4 | import torch 5 | 6 | 7 | def _is_tensor_video_clip(clip): 8 | if not torch.is_tensor(clip): 9 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 10 | 11 | if not clip.ndimension() == 4: 12 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 13 | 14 | return True 15 | 16 | 17 | def crop(clip, i, j, h, w): 18 | """ 19 | Args: 20 | clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) 21 | """ 22 | if len(clip.size()) != 4: 23 | raise ValueError("clip should be a 4D tensor") 24 | return clip[..., i : i + h, j : j + w] 25 | 26 | 27 | def resize(clip, target_size, interpolation_mode): 28 | if len(target_size) != 2: 29 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 30 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) 31 | 32 | 33 | def resize_scale(clip, target_size, interpolation_mode): 34 | if len(target_size) != 2: 35 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 36 | _, _, H, W = clip.shape 37 | scale_ = target_size[0] / min(H, W) 38 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 39 | 40 | 41 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 42 | """ 43 | Do spatial cropping and resizing to the video clip 44 | Args: 45 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 46 | i (int): i in (i,j) i.e coordinates of the upper left corner. 47 | j (int): j in (i,j) i.e coordinates of the upper left corner. 48 | h (int): Height of the cropped region. 49 | w (int): Width of the cropped region. 50 | size (tuple(int, int)): height and width of resized clip 51 | Returns: 52 | clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) 53 | """ 54 | if not _is_tensor_video_clip(clip): 55 | raise ValueError("clip should be a 4D torch.tensor") 56 | clip = crop(clip, i, j, h, w) 57 | clip = resize(clip, size, interpolation_mode) 58 | return clip 59 | 60 | 61 | def center_crop(clip, crop_size): 62 | if not _is_tensor_video_clip(clip): 63 | raise ValueError("clip should be a 4D torch.tensor") 64 | h, w = clip.size(-2), clip.size(-1) 65 | th, tw = crop_size 66 | if h < th or w < tw: 67 | raise ValueError("height and width must be no smaller than crop_size") 68 | 69 | i = int(round((h - th) / 2.0)) 70 | j = int(round((w - tw) / 2.0)) 71 | return crop(clip, i, j, th, tw) 72 | 73 | 74 | def random_shift_crop(clip): 75 | """ 76 | Slide along the long edge, with the short edge as crop size 77 | """ 78 | if not _is_tensor_video_clip(clip): 79 | raise ValueError("clip should be a 4D torch.tensor") 80 | h, w = clip.size(-2), clip.size(-1) 81 | 82 | if h <= w: 83 | # long_edge = w 84 | short_edge = h 85 | else: 86 | # long_edge = h 87 | short_edge = w 88 | 89 | th, tw = short_edge, short_edge 90 | 91 | i = torch.randint(0, h - th + 1, size=(1,)).item() 92 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 93 | return crop(clip, i, j, th, tw) 94 | 95 | 96 | def to_tensor(clip): 97 | """ 98 | Convert tensor data type from uint8 to float, divide value by 255.0 and 99 | permute the dimensions of clip tensor 100 | Args: 101 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 102 | Return: 103 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 104 | """ 105 | _is_tensor_video_clip(clip) 106 | if not clip.dtype == torch.uint8: 107 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 108 | # return clip.float().permute(3, 0, 1, 2) / 255.0 109 | return clip.float() / 255.0 110 | 111 | 112 | def normalize(clip, mean, std, inplace=False): 113 | """ 114 | Args: 115 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 116 | mean (tuple): pixel RGB mean. Size is (3) 117 | std (tuple): pixel standard deviation. 
Size is (3) 118 | Returns: 119 | normalized clip (torch.tensor): Size is (T, C, H, W) 120 | """ 121 | if not _is_tensor_video_clip(clip): 122 | raise ValueError("clip should be a 4D torch.tensor") 123 | if not inplace: 124 | clip = clip.clone() 125 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 126 | print(mean) 127 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 128 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 129 | return clip 130 | 131 | 132 | def hflip(clip): 133 | """ 134 | Args: 135 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 136 | Returns: 137 | flipped clip (torch.tensor): Size is (T, C, H, W) 138 | """ 139 | if not _is_tensor_video_clip(clip): 140 | raise ValueError("clip should be a 4D torch.tensor") 141 | return clip.flip(-1) 142 | 143 | 144 | class RandomCropVideo: 145 | def __init__(self, size): 146 | if isinstance(size, numbers.Number): 147 | self.size = (int(size), int(size)) 148 | else: 149 | self.size = size 150 | 151 | def __call__(self, clip): 152 | """ 153 | Args: 154 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 155 | Returns: 156 | torch.tensor: randomly cropped video clip. 157 | size is (T, C, OH, OW) 158 | """ 159 | i, j, h, w = self.get_params(clip) 160 | return crop(clip, i, j, h, w) 161 | 162 | def get_params(self, clip): 163 | h, w = clip.shape[-2:] 164 | th, tw = self.size 165 | 166 | if h < th or w < tw: 167 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") 168 | 169 | if w == tw and h == th: 170 | return 0, 0, h, w 171 | 172 | i = torch.randint(0, h - th + 1, size=(1,)).item() 173 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 174 | 175 | return i, j, th, tw 176 | 177 | def __repr__(self) -> str: 178 | return f"{self.__class__.__name__}(size={self.size})" 179 | 180 | 181 | class UCFCenterCropVideo: 182 | def __init__( 183 | self, 184 | size, 185 | interpolation_mode="bilinear", 186 | ): 187 | if isinstance(size, tuple): 188 | if len(size) != 2: 189 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 190 | self.size = size 191 | else: 192 | self.size = (size, size) 193 | 194 | self.interpolation_mode = interpolation_mode 195 | 196 | def __call__(self, clip): 197 | """ 198 | Args: 199 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 200 | Returns: 201 | torch.tensor: scale resized / center cropped video clip. 202 | size is (T, C, crop_size, crop_size) 203 | """ 204 | clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) 205 | clip_center_crop = center_crop(clip_resize, self.size) 206 | return clip_center_crop 207 | 208 | def __repr__(self) -> str: 209 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 210 | 211 | 212 | class KineticsRandomCropResizeVideo: 213 | """ 214 | Slide along the long edge, with the short edge as crop size. And resize to the desired size. 
215 | """ 216 | 217 | def __init__( 218 | self, 219 | size, 220 | interpolation_mode="bilinear", 221 | ): 222 | if isinstance(size, tuple): 223 | if len(size) != 2: 224 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 225 | self.size = size 226 | else: 227 | self.size = (size, size) 228 | 229 | self.interpolation_mode = interpolation_mode 230 | 231 | def __call__(self, clip): 232 | clip_random_crop = random_shift_crop(clip) 233 | clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) 234 | return clip_resize 235 | 236 | 237 | class CenterCropVideo: 238 | def __init__( 239 | self, 240 | size, 241 | interpolation_mode="bilinear", 242 | ): 243 | if isinstance(size, tuple): 244 | if len(size) != 2: 245 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 246 | self.size = size 247 | else: 248 | self.size = (size, size) 249 | 250 | self.interpolation_mode = interpolation_mode 251 | 252 | def __call__(self, clip): 253 | """ 254 | Args: 255 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 256 | Returns: 257 | torch.tensor: center cropped video clip. 258 | size is (T, C, crop_size, crop_size) 259 | """ 260 | clip_center_crop = center_crop(clip, self.size) 261 | return clip_center_crop 262 | 263 | def __repr__(self) -> str: 264 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 265 | 266 | 267 | class NormalizeVideo: 268 | """ 269 | Normalize the video clip by mean subtraction and division by standard deviation 270 | Args: 271 | mean (3-tuple): pixel RGB mean 272 | std (3-tuple): pixel RGB standard deviation 273 | inplace (boolean): whether do in-place normalization 274 | """ 275 | 276 | def __init__(self, mean, std, inplace=False): 277 | self.mean = mean 278 | self.std = std 279 | self.inplace = inplace 280 | 281 | def __call__(self, clip): 282 | """ 283 | Args: 284 | clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) 285 | """ 286 | return normalize(clip, self.mean, self.std, self.inplace) 287 | 288 | def __repr__(self) -> str: 289 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 290 | 291 | 292 | class ToTensorVideo: 293 | """ 294 | Convert tensor data type from uint8 to float, divide value by 255.0 and 295 | permute the dimensions of clip tensor 296 | """ 297 | 298 | def __init__(self): 299 | pass 300 | 301 | def __call__(self, clip): 302 | """ 303 | Args: 304 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 305 | Return: 306 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 307 | """ 308 | return to_tensor(clip) 309 | 310 | def __repr__(self) -> str: 311 | return self.__class__.__name__ 312 | 313 | 314 | class RandomHorizontalFlipVideo: 315 | """ 316 | Flip the video clip along the horizontal direction with a given probability 317 | Args: 318 | p (float): probability of the clip being flipped. 
Default value is 0.5 319 | """ 320 | 321 | def __init__(self, p=0.5): 322 | self.p = p 323 | 324 | def __call__(self, clip): 325 | """ 326 | Args: 327 | clip (torch.tensor): Size is (T, C, H, W) 328 | Return: 329 | clip (torch.tensor): Size is (T, C, H, W) 330 | """ 331 | if random.random() < self.p: 332 | clip = hflip(clip) 333 | return clip 334 | 335 | def __repr__(self) -> str: 336 | return f"{self.__class__.__name__}(p={self.p})" 337 | 338 | 339 | # ------------------------------------------------------------ 340 | # --------------------- Sampling --------------------------- 341 | # ------------------------------------------------------------ 342 | class TemporalRandomCrop(object): 343 | """Temporally crop the given frame indices at a random location. 344 | 345 | Args: 346 | size (int): Desired length of frames will be seen in the model. 347 | """ 348 | 349 | def __init__(self, size): 350 | self.size = size 351 | 352 | def __call__(self, total_frames): 353 | rand_end = max(0, total_frames - self.size - 1) 354 | begin_index = random.randint(0, rand_end) 355 | end_index = min(begin_index + self.size, total_frames) 356 | return begin_index, end_index 357 | 358 | 359 | if __name__ == "__main__": 360 | import os 361 | 362 | import numpy as np 363 | import torchvision.io as io 364 | from torchvision import transforms 365 | from torchvision.utils import save_image 366 | 367 | vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW") 368 | 369 | trans = transforms.Compose( 370 | [ 371 | ToTensorVideo(), 372 | RandomHorizontalFlipVideo(), 373 | UCFCenterCropVideo(512), 374 | # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 375 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 376 | ] 377 | ) 378 | 379 | target_video_len = 32 380 | frame_interval = 1 381 | total_frames = len(vframes) 382 | print(total_frames) 383 | 384 | temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) 385 | 386 | # Sampling video frames 387 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 388 | # print(start_frame_ind) 389 | # print(end_frame_ind) 390 | assert end_frame_ind - start_frame_ind >= target_video_len 391 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) 392 | 393 | select_vframes = vframes[frame_indice] 394 | 395 | select_vframes_trans = trans(select_vframes) 396 | 397 | select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) 398 | 399 | io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) 400 | 401 | for i in range(target_video_len): 402 | save_image( 403 | select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1) 404 | ) 405 | -------------------------------------------------------------------------------- /animatediff/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/animatediff/models/__init__.py -------------------------------------------------------------------------------- /animatediff/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 
| from einops import rearrange, repeat 9 | from torch import nn 10 | 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models import ModelMixin 13 | from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward 14 | from diffusers.utils import BaseOutput 15 | from diffusers.utils.import_utils import is_xformers_available 16 | 17 | 18 | @dataclass 19 | class Transformer3DModelOutput(BaseOutput): 20 | sample: torch.FloatTensor 21 | 22 | 23 | if is_xformers_available(): 24 | import xformers 25 | import xformers.ops 26 | else: 27 | xformers = None 28 | 29 | 30 | class Transformer3DModel(ModelMixin, ConfigMixin): 31 | @register_to_config 32 | def __init__( 33 | self, 34 | num_attention_heads: int = 16, 35 | attention_head_dim: int = 88, 36 | in_channels: Optional[int] = None, 37 | num_layers: int = 1, 38 | dropout: float = 0.0, 39 | norm_num_groups: int = 32, 40 | cross_attention_dim: Optional[int] = None, 41 | attention_bias: bool = False, 42 | activation_fn: str = "geglu", 43 | num_embeds_ada_norm: Optional[int] = None, 44 | use_linear_projection: bool = False, 45 | only_cross_attention: bool = False, 46 | upcast_attention: bool = False, 47 | unet_use_cross_frame_attention=None, 48 | unet_use_temporal_attention=None, 49 | ): 50 | super().__init__() 51 | self.use_linear_projection = use_linear_projection 52 | self.num_attention_heads = num_attention_heads 53 | self.attention_head_dim = attention_head_dim 54 | inner_dim = num_attention_heads * attention_head_dim 55 | 56 | # Define input layers 57 | self.in_channels = in_channels 58 | 59 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 60 | if use_linear_projection: 61 | self.proj_in = nn.Linear(in_channels, inner_dim) 62 | else: 63 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 64 | 65 | # Define transformers blocks 66 | self.transformer_blocks = nn.ModuleList( 67 | [ 68 | BasicTransformerBlock( 69 | inner_dim, 70 | num_attention_heads, 71 | attention_head_dim, 72 | dropout=dropout, 73 | cross_attention_dim=cross_attention_dim, 74 | activation_fn=activation_fn, 75 | num_embeds_ada_norm=num_embeds_ada_norm, 76 | attention_bias=attention_bias, 77 | only_cross_attention=only_cross_attention, 78 | upcast_attention=upcast_attention, 79 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 80 | unet_use_temporal_attention=unet_use_temporal_attention, 81 | ) 82 | for d in range(num_layers) 83 | ] 84 | ) 85 | 86 | # 4. Define output layers 87 | if use_linear_projection: 88 | self.proj_out = nn.Linear(in_channels, inner_dim) 89 | else: 90 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 91 | 92 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 93 | # Input 94 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
95 | video_length = hidden_states.shape[2] 96 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 97 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length) 98 | 99 | batch, channel, height, weight = hidden_states.shape 100 | residual = hidden_states 101 | 102 | hidden_states = self.norm(hidden_states) 103 | if not self.use_linear_projection: 104 | hidden_states = self.proj_in(hidden_states) 105 | inner_dim = hidden_states.shape[1] 106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 107 | else: 108 | inner_dim = hidden_states.shape[1] 109 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 110 | hidden_states = self.proj_in(hidden_states) 111 | 112 | # Blocks 113 | for block in self.transformer_blocks: 114 | hidden_states = block( 115 | hidden_states, 116 | encoder_hidden_states=encoder_hidden_states, 117 | timestep=timestep, 118 | video_length=video_length, 119 | ) 120 | 121 | # Output 122 | if not self.use_linear_projection: 123 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 124 | hidden_states = self.proj_out(hidden_states) 125 | else: 126 | hidden_states = self.proj_out(hidden_states) 127 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | 129 | output = hidden_states + residual 130 | 131 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 132 | if not return_dict: 133 | return (output,) 134 | 135 | return Transformer3DModelOutput(sample=output) 136 | 137 | 138 | class BasicTransformerBlock(nn.Module): 139 | def __init__( 140 | self, 141 | dim: int, 142 | num_attention_heads: int, 143 | attention_head_dim: int, 144 | dropout=0.0, 145 | cross_attention_dim: Optional[int] = None, 146 | activation_fn: str = "geglu", 147 | num_embeds_ada_norm: Optional[int] = None, 148 | attention_bias: bool = False, 149 | only_cross_attention: bool = False, 150 | upcast_attention: bool = False, 151 | unet_use_cross_frame_attention=None, 152 | unet_use_temporal_attention=None, 153 | ): 154 | super().__init__() 155 | self.only_cross_attention = only_cross_attention 156 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 157 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 158 | self.unet_use_temporal_attention = unet_use_temporal_attention 159 | 160 | # SC-Attn 161 | assert unet_use_cross_frame_attention is not None 162 | if unet_use_cross_frame_attention: 163 | self.attn1 = SparseCausalAttention( 164 | query_dim=dim, 165 | heads=num_attention_heads, 166 | dim_head=attention_head_dim, 167 | dropout=dropout, 168 | bias=attention_bias, 169 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 170 | upcast_attention=upcast_attention, 171 | ) 172 | else: 173 | self.attn1 = Attention( 174 | query_dim=dim, 175 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 176 | heads=num_attention_heads, 177 | dim_head=attention_head_dim, 178 | dropout=dropout, 179 | bias=attention_bias, 180 | upcast_attention=upcast_attention, 181 | ) 182 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 183 | 184 | # Cross-Attn 185 | if cross_attention_dim is not None: 186 | self.attn2 = Attention( 187 | query_dim=dim, 188 | cross_attention_dim=cross_attention_dim, 189 | heads=num_attention_heads, 190 | 
dim_head=attention_head_dim, 191 | dropout=dropout, 192 | bias=attention_bias, 193 | upcast_attention=upcast_attention, 194 | ) 195 | else: 196 | self.attn2 = None 197 | 198 | if cross_attention_dim is not None: 199 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 200 | else: 201 | self.norm2 = None 202 | 203 | # Feed-forward 204 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 205 | self.norm3 = nn.LayerNorm(dim) 206 | 207 | # Temp-Attn 208 | assert unet_use_temporal_attention is not None 209 | if unet_use_temporal_attention: 210 | self.attn_temp = Attention( 211 | query_dim=dim, 212 | heads=num_attention_heads, 213 | dim_head=attention_head_dim, 214 | dropout=dropout, 215 | bias=attention_bias, 216 | upcast_attention=upcast_attention, 217 | ) 218 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 219 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 220 | 221 | def forward( 222 | self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None 223 | ): 224 | # SparseCausal-Attention 225 | norm_hidden_states = ( 226 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 227 | ) 228 | 229 | # if self.only_cross_attention: 230 | # hidden_states = ( 231 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 232 | # ) 233 | # else: 234 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 235 | 236 | if self.unet_use_cross_frame_attention: 237 | hidden_states = ( 238 | self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) 239 | + hidden_states 240 | ) 241 | else: 242 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 243 | 244 | if self.attn2 is not None: 245 | # Cross-Attention 246 | norm_hidden_states = ( 247 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 248 | ) 249 | hidden_states = ( 250 | self.attn2( 251 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | # Feed-forward 257 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 258 | 259 | # Temporal-Attention 260 | if self.unet_use_temporal_attention: 261 | d = hidden_states.shape[1] 262 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 263 | norm_hidden_states = ( 264 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 265 | ) 266 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 267 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 268 | 269 | return hidden_states 270 | 271 | 272 | class CrossAttention(nn.Module): 273 | r""" 274 | A cross attention layer. 275 | 276 | Parameters: 277 | query_dim (`int`): The number of channels in the query. 278 | cross_attention_dim (`int`, *optional*): 279 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. 280 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 281 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 
282 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 283 | bias (`bool`, *optional*, defaults to False): 284 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 285 | """ 286 | 287 | def __init__( 288 | self, 289 | query_dim: int, 290 | cross_attention_dim: Optional[int] = None, 291 | heads: int = 8, 292 | dim_head: int = 64, 293 | dropout: float = 0.0, 294 | bias=False, 295 | upcast_attention: bool = False, 296 | upcast_softmax: bool = False, 297 | added_kv_proj_dim: Optional[int] = None, 298 | norm_num_groups: Optional[int] = None, 299 | ): 300 | super().__init__() 301 | inner_dim = dim_head * heads 302 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 303 | self.upcast_attention = upcast_attention 304 | self.upcast_softmax = upcast_softmax 305 | 306 | self.scale = dim_head**-0.5 307 | 308 | self.heads = heads 309 | # for slice_size > 0 the attention score computation 310 | # is split across the batch axis to save memory 311 | # You can set slice_size with `set_attention_slice` 312 | self.sliceable_head_dim = heads 313 | self._slice_size = None 314 | self._use_memory_efficient_attention_xformers = False 315 | self.added_kv_proj_dim = added_kv_proj_dim 316 | 317 | if norm_num_groups is not None: 318 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) 319 | else: 320 | self.group_norm = None 321 | 322 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) 323 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 324 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 325 | 326 | if self.added_kv_proj_dim is not None: 327 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 328 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 329 | 330 | self.to_out = nn.ModuleList([]) 331 | self.to_out.append(nn.Linear(inner_dim, query_dim)) 332 | self.to_out.append(nn.Dropout(dropout)) 333 | 334 | def reshape_heads_to_batch_dim(self, tensor): 335 | batch_size, seq_len, dim = tensor.shape 336 | head_size = self.heads 337 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 338 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 339 | return tensor 340 | 341 | def reshape_batch_dim_to_heads(self, tensor): 342 | batch_size, seq_len, dim = tensor.shape 343 | head_size = self.heads 344 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 345 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 346 | return tensor 347 | 348 | def set_attention_slice(self, slice_size): 349 | if slice_size is not None and slice_size > self.sliceable_head_dim: 350 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") 351 | 352 | self._slice_size = slice_size 353 | 354 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 355 | batch_size, sequence_length, _ = hidden_states.shape 356 | 357 | encoder_hidden_states = encoder_hidden_states 358 | 359 | if self.group_norm is not None: 360 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 361 | 362 | query = self.to_q(hidden_states) 363 | dim = query.shape[-1] 364 | query = self.reshape_heads_to_batch_dim(query) 365 | 366 | if self.added_kv_proj_dim is not None: 367 | key = self.to_k(hidden_states) 368 | value = 
self.to_v(hidden_states) 369 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) 370 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) 371 | 372 | key = self.reshape_heads_to_batch_dim(key) 373 | value = self.reshape_heads_to_batch_dim(value) 374 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) 375 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) 376 | 377 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) 378 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) 379 | else: 380 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 381 | key = self.to_k(encoder_hidden_states) 382 | value = self.to_v(encoder_hidden_states) 383 | 384 | key = self.reshape_heads_to_batch_dim(key) 385 | value = self.reshape_heads_to_batch_dim(value) 386 | 387 | if attention_mask is not None: 388 | if attention_mask.shape[-1] != query.shape[1]: 389 | target_length = query.shape[1] 390 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 391 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 392 | 393 | # attention, what we cannot get enough of 394 | if self._use_memory_efficient_attention_xformers: 395 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 396 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 397 | hidden_states = hidden_states.to(query.dtype) 398 | else: 399 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 400 | hidden_states = self._attention(query, key, value, attention_mask) 401 | else: 402 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 403 | 404 | # linear proj 405 | hidden_states = self.to_out[0](hidden_states) 406 | 407 | # dropout 408 | hidden_states = self.to_out[1](hidden_states) 409 | return hidden_states 410 | 411 | def _attention(self, query, key, value, attention_mask=None): 412 | if self.upcast_attention: 413 | query = query.float() 414 | key = key.float() 415 | 416 | attention_scores = torch.baddbmm( 417 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 418 | query, 419 | key.transpose(-1, -2), 420 | beta=0, 421 | alpha=self.scale, 422 | ) 423 | 424 | if attention_mask is not None: 425 | attention_scores = attention_scores + attention_mask 426 | 427 | if self.upcast_softmax: 428 | attention_scores = attention_scores.float() 429 | 430 | attention_probs = attention_scores.softmax(dim=-1) 431 | 432 | # cast back to the original dtype 433 | attention_probs = attention_probs.to(value.dtype) 434 | 435 | # compute attention output 436 | hidden_states = torch.bmm(attention_probs, value) 437 | 438 | # reshape hidden_states 439 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 440 | return hidden_states 441 | 442 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): 443 | batch_size_attention = query.shape[0] 444 | hidden_states = torch.zeros( 445 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 446 | ) 447 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] 448 | for i in range(hidden_states.shape[0] // slice_size): 449 | start_idx = i * slice_size 450 | end_idx = (i + 1) * 
slice_size 451 | 452 | query_slice = query[start_idx:end_idx] 453 | key_slice = key[start_idx:end_idx] 454 | 455 | if self.upcast_attention: 456 | query_slice = query_slice.float() 457 | key_slice = key_slice.float() 458 | 459 | attn_slice = torch.baddbmm( 460 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), 461 | query_slice, 462 | key_slice.transpose(-1, -2), 463 | beta=0, 464 | alpha=self.scale, 465 | ) 466 | 467 | if attention_mask is not None: 468 | attn_slice = attn_slice + attention_mask[start_idx:end_idx] 469 | 470 | if self.upcast_softmax: 471 | attn_slice = attn_slice.float() 472 | 473 | attn_slice = attn_slice.softmax(dim=-1) 474 | 475 | # cast back to the original dtype 476 | attn_slice = attn_slice.to(value.dtype) 477 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 478 | 479 | hidden_states[start_idx:end_idx] = attn_slice 480 | 481 | # reshape hidden_states 482 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 483 | return hidden_states 484 | 485 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 486 | # TODO attention_mask 487 | query = query.contiguous() 488 | key = key.contiguous() 489 | value = value.contiguous() 490 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 491 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 492 | return hidden_states 493 | 494 | 495 | class SparseCausalAttention(CrossAttention): 496 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 497 | batch_size, sequence_length, _ = hidden_states.shape 498 | 499 | encoder_hidden_states = encoder_hidden_states 500 | 501 | if self.group_norm is not None: 502 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 503 | 504 | query = self.to_q(hidden_states) 505 | dim = query.shape[-1] 506 | query = self.reshape_heads_to_batch_dim(query) 507 | 508 | if self.added_kv_proj_dim is not None: 509 | raise NotImplementedError 510 | 511 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 512 | key = self.to_k(encoder_hidden_states) 513 | value = self.to_v(encoder_hidden_states) 514 | 515 | former_frame_index = torch.arange(video_length) - 1 516 | former_frame_index[0] = 0 517 | 518 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 519 | # key = torch.cat([key[:, [0] * video_length], key[:, [0] * video_length]], dim=2) 520 | key = key[:, [0] * video_length] 521 | key = rearrange(key, "b f d c -> (b f) d c") 522 | 523 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 524 | # value = torch.cat([value[:, [0] * video_length], value[:, [0] * video_length]], dim=2) 525 | # value = value[:, former_frame_index] 526 | value = rearrange(value, "b f d c -> (b f) d c") 527 | 528 | key = self.reshape_heads_to_batch_dim(key) 529 | value = self.reshape_heads_to_batch_dim(value) 530 | 531 | if attention_mask is not None: 532 | if attention_mask.shape[-1] != query.shape[1]: 533 | target_length = query.shape[1] 534 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 535 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 536 | 537 | # attention, what we cannot get enough of 538 | if self._use_memory_efficient_attention_xformers: 539 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 540 | # Some versions of 
xformers return output in fp32, cast it back to the dtype of the input 541 | hidden_states = hidden_states.to(query.dtype) 542 | else: 543 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 544 | hidden_states = self._attention(query, key, value, attention_mask) 545 | else: 546 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 547 | 548 | # linear proj 549 | hidden_states = self.to_out[0](hidden_states) 550 | 551 | # dropout 552 | hidden_states = self.to_out[1](hidden_states) 553 | return hidden_states 554 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class Upsample3D(nn.Module): 21 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 22 | super().__init__() 23 | self.channels = channels 24 | self.out_channels = out_channels or channels 25 | self.use_conv = use_conv 26 | self.use_conv_transpose = use_conv_transpose 27 | self.name = name 28 | 29 | # conv = None 30 | if use_conv_transpose: 31 | raise NotImplementedError 32 | elif use_conv: 33 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 34 | 35 | def forward(self, hidden_states, output_size=None): 36 | assert hidden_states.shape[1] == self.channels 37 | 38 | if self.use_conv_transpose: 39 | raise NotImplementedError 40 | 41 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 42 | dtype = hidden_states.dtype 43 | if dtype == torch.bfloat16: 44 | hidden_states = hidden_states.to(torch.float32) 45 | 46 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 47 | if hidden_states.shape[0] >= 64: 48 | hidden_states = hidden_states.contiguous() 49 | 50 | # if `output_size` is passed we force the interpolation output 51 | # size and do not make use of `scale_factor=2` 52 | if output_size is None: 53 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 54 | else: 55 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 56 | 57 | # If the input is bfloat16, we cast back to bfloat16 58 | if dtype == torch.bfloat16: 59 | hidden_states = hidden_states.to(dtype) 60 | 61 | # if self.use_conv: 62 | # if self.name == "conv": 63 | # hidden_states = self.conv(hidden_states) 64 | # else: 65 | # hidden_states = self.Conv2d_0(hidden_states) 66 | hidden_states = self.conv(hidden_states) 67 | 68 | return hidden_states 69 | 70 | 71 | class Downsample3D(nn.Module): 72 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 73 | super().__init__() 74 | self.channels = channels 75 | self.out_channels = out_channels or channels 76 | self.use_conv = use_conv 77 | self.padding = padding 78 | stride = 2 79 | self.name = name 80 | 81 | if use_conv: 82 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 83 | else: 84 | raise NotImplementedError 85 | 86 | def forward(self, hidden_states): 87 | assert hidden_states.shape[1] == self.channels 88 | if self.use_conv and self.padding == 0: 89 | raise NotImplementedError 90 | 91 | assert hidden_states.shape[1] == self.channels 92 | hidden_states = self.conv(hidden_states) 93 | 94 | return hidden_states 95 | 96 | 97 | class ResnetBlock3D(nn.Module): 98 | def __init__( 99 | self, 100 | *, 101 | in_channels, 102 | out_channels=None, 103 | conv_shortcut=False, 104 | dropout=0.0, 105 | temb_channels=512, 106 | groups=32, 107 | groups_out=None, 108 | pre_norm=True, 109 | eps=1e-6, 110 | non_linearity="swish", 111 | time_embedding_norm="default", 112 | output_scale_factor=1.0, 113 | use_in_shortcut=None, 114 | ): 115 | super().__init__() 116 | self.pre_norm = pre_norm 117 | self.pre_norm = True 118 | self.in_channels = in_channels 119 | out_channels = in_channels if out_channels is None else out_channels 120 | self.out_channels = out_channels 121 | self.use_conv_shortcut = conv_shortcut 122 | self.time_embedding_norm = time_embedding_norm 123 | self.output_scale_factor = output_scale_factor 124 | 125 | if groups_out is None: 126 | groups_out = groups 127 | 128 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 129 | 130 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 131 | 132 | if temb_channels is not None: 133 | if self.time_embedding_norm == "default": 134 | time_emb_proj_out_channels = out_channels 135 | elif self.time_embedding_norm == "scale_shift": 136 | time_emb_proj_out_channels = out_channels * 2 137 | else: 138 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 139 | 140 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 141 | else: 142 | self.time_emb_proj = None 143 | 144 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 145 | self.dropout = torch.nn.Dropout(dropout) 146 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 147 | 148 | if non_linearity == "swish": 149 | 
self.nonlinearity = lambda x: F.silu(x) 150 | elif non_linearity == "mish": 151 | self.nonlinearity = Mish() 152 | elif non_linearity == "silu": 153 | self.nonlinearity = nn.SiLU() 154 | 155 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 156 | 157 | self.conv_shortcut = None 158 | if self.use_in_shortcut: 159 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 160 | 161 | def forward(self, input_tensor, temb): 162 | hidden_states = input_tensor 163 | 164 | hidden_states = self.norm1(hidden_states) 165 | hidden_states = self.nonlinearity(hidden_states) 166 | 167 | hidden_states = self.conv1(hidden_states) 168 | 169 | if temb is not None: 170 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 171 | 172 | if temb is not None and self.time_embedding_norm == "default": 173 | hidden_states = hidden_states + temb 174 | 175 | hidden_states = self.norm2(hidden_states) 176 | 177 | if temb is not None and self.time_embedding_norm == "scale_shift": 178 | scale, shift = torch.chunk(temb, 2, dim=1) 179 | hidden_states = hidden_states * (1 + scale) + shift 180 | 181 | hidden_states = self.nonlinearity(hidden_states) 182 | 183 | hidden_states = self.dropout(hidden_states) 184 | hidden_states = self.conv2(hidden_states) 185 | 186 | if self.conv_shortcut is not None: 187 | input_tensor = self.conv_shortcut(input_tensor) 188 | 189 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 190 | 191 | return output_tensor 192 | 193 | 194 | class Mish(torch.nn.Module): 195 | def forward(self, hidden_states): 196 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 197 | -------------------------------------------------------------------------------- /animatediff/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .i2v_pipeline import I2VPipeline 2 | from .pipeline_animation import AnimationPipeline 3 | from .validation_pipeline import ValidationPipeline 4 | 5 | 6 | __all__ = ["I2VPipeline", "AnimationPipeline", "ValidationPipeline"] 7 | -------------------------------------------------------------------------------- /animatediff/pipelines/pipeline_animation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | 3 | import inspect 4 | from dataclasses import dataclass 5 | from typing import Callable, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from einops import rearrange 10 | from packaging import version 11 | from tqdm import tqdm 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from diffusers.configuration_utils import FrozenDict 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.pipelines import DiffusionPipeline 17 | from diffusers.schedulers import ( 18 | DDIMScheduler, 19 | DPMSolverMultistepScheduler, 20 | EulerAncestralDiscreteScheduler, 21 | EulerDiscreteScheduler, 22 | LMSDiscreteScheduler, 23 | PNDMScheduler, 24 | ) 25 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging 26 | 27 | from ..models.unet import UNet3DConditionModel 28 | 29 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 31 | 32 | 33 | @dataclass 34 | class AnimationPipelineOutput(BaseOutput): 35 | videos: Union[torch.Tensor, 
np.ndarray] 36 | 37 | 38 | class AnimationPipeline(DiffusionPipeline): 39 | _optional_components = [] 40 | 41 | def __init__( 42 | self, 43 | vae: AutoencoderKL, 44 | text_encoder: CLIPTextModel, 45 | tokenizer: CLIPTokenizer, 46 | unet: UNet3DConditionModel, 47 | scheduler: Union[ 48 | DDIMScheduler, 49 | PNDMScheduler, 50 | LMSDiscreteScheduler, 51 | EulerDiscreteScheduler, 52 | EulerAncestralDiscreteScheduler, 53 | DPMSolverMultistepScheduler, 54 | ], 55 | ): 56 | super().__init__() 57 | 58 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 59 | deprecation_message = ( 60 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 61 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 62 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 63 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 64 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 65 | " file" 66 | ) 67 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 68 | new_config = dict(scheduler.config) 69 | new_config["steps_offset"] = 1 70 | scheduler._internal_dict = FrozenDict(new_config) 71 | 72 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 73 | deprecation_message = ( 74 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 75 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 76 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 77 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 78 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 79 | ) 80 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 81 | new_config = dict(scheduler.config) 82 | new_config["clip_sample"] = False 83 | scheduler._internal_dict = FrozenDict(new_config) 84 | 85 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 86 | version.parse(unet.config._diffusers_version).base_version 87 | ) < version.parse("0.9.0.dev0") 88 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 89 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 90 | deprecation_message = ( 91 | "The configuration file of the unet has set the default `sample_size` to smaller than" 92 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 93 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 94 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 95 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 96 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 97 | " in the config might lead to incorrect results in future versions. 
If you have downloaded this" 98 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 99 | " the `unet/config.json` file" 100 | ) 101 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 102 | new_config = dict(unet.config) 103 | new_config["sample_size"] = 64 104 | unet._internal_dict = FrozenDict(new_config) 105 | 106 | self.register_modules( 107 | vae=vae, 108 | text_encoder=text_encoder, 109 | tokenizer=tokenizer, 110 | unet=unet, 111 | scheduler=scheduler, 112 | ) 113 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 114 | 115 | def enable_vae_slicing(self): 116 | self.vae.enable_slicing() 117 | 118 | def disable_vae_slicing(self): 119 | self.vae.disable_slicing() 120 | 121 | def enable_sequential_cpu_offload(self, gpu_id=0): 122 | if is_accelerate_available(): 123 | from accelerate import cpu_offload 124 | else: 125 | raise ImportError("Please install accelerate via `pip install accelerate`") 126 | 127 | device = torch.device(f"cuda:{gpu_id}") 128 | 129 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 130 | if cpu_offloaded_model is not None: 131 | cpu_offload(cpu_offloaded_model, device) 132 | 133 | @property 134 | def _execution_device(self): 135 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 136 | return self.device 137 | for module in self.unet.modules(): 138 | if ( 139 | hasattr(module, "_hf_hook") 140 | and hasattr(module._hf_hook, "execution_device") 141 | and module._hf_hook.execution_device is not None 142 | ): 143 | return torch.device(module._hf_hook.execution_device) 144 | return self.device 145 | 146 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 147 | batch_size = len(prompt) if isinstance(prompt, list) else 1 148 | 149 | text_inputs = self.tokenizer( 150 | prompt, 151 | padding="max_length", 152 | max_length=self.tokenizer.model_max_length, 153 | truncation=True, 154 | return_tensors="pt", 155 | ) 156 | text_input_ids = text_inputs.input_ids 157 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 158 | 159 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 160 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 161 | logger.warning( 162 | "The following part of your input was truncated because CLIP can only handle sequences up to" 163 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 164 | ) 165 | 166 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 167 | attention_mask = text_inputs.attention_mask.to(device) 168 | else: 169 | attention_mask = None 170 | 171 | text_embeddings = self.text_encoder( 172 | text_input_ids.to(device), 173 | attention_mask=attention_mask, 174 | ) 175 | text_embeddings = text_embeddings[0] 176 | 177 | # duplicate text embeddings for each generation per prompt, using mps friendly method 178 | bs_embed, seq_len, _ = text_embeddings.shape 179 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 180 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 181 | 182 | # get unconditional embeddings for classifier free guidance 183 | if do_classifier_free_guidance: 184 | uncond_tokens: List[str] 185 | if negative_prompt is None: 186 | uncond_tokens = 
[""] * batch_size 187 | elif type(prompt) is not type(negative_prompt): 188 | raise TypeError( 189 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 190 | f" {type(prompt)}." 191 | ) 192 | elif isinstance(negative_prompt, str): 193 | uncond_tokens = [negative_prompt] 194 | elif batch_size != len(negative_prompt): 195 | raise ValueError( 196 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 197 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 198 | " the batch size of `prompt`." 199 | ) 200 | else: 201 | uncond_tokens = negative_prompt 202 | 203 | max_length = text_input_ids.shape[-1] 204 | uncond_input = self.tokenizer( 205 | uncond_tokens, 206 | padding="max_length", 207 | max_length=max_length, 208 | truncation=True, 209 | return_tensors="pt", 210 | ) 211 | 212 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 213 | attention_mask = uncond_input.attention_mask.to(device) 214 | else: 215 | attention_mask = None 216 | 217 | uncond_embeddings = self.text_encoder( 218 | uncond_input.input_ids.to(device), 219 | attention_mask=attention_mask, 220 | ) 221 | uncond_embeddings = uncond_embeddings[0] 222 | 223 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 224 | seq_len = uncond_embeddings.shape[1] 225 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 226 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 227 | 228 | # For classifier free guidance, we need to do two forward passes. 229 | # Here we concatenate the unconditional and text embeddings into a single batch 230 | # to avoid doing two forward passes 231 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 232 | 233 | return text_embeddings 234 | 235 | def decode_latents(self, latents): 236 | video_length = latents.shape[2] 237 | latents = 1 / 0.18215 * latents 238 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 239 | # video = self.vae.decode(latents).sample 240 | video = [] 241 | for frame_idx in tqdm(range(latents.shape[0])): 242 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 243 | video = torch.cat(video) 244 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 245 | video = (video / 2 + 0.5).clamp(0, 1) 246 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 247 | video = video.cpu().float().numpy() 248 | return video 249 | 250 | def prepare_extra_step_kwargs(self, generator, eta): 251 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 252 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
253 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 254 | # and should be between [0, 1] 255 | 256 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 257 | extra_step_kwargs = {} 258 | if accepts_eta: 259 | extra_step_kwargs["eta"] = eta 260 | 261 | # check if the scheduler accepts generator 262 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 263 | if accepts_generator: 264 | extra_step_kwargs["generator"] = generator 265 | return extra_step_kwargs 266 | 267 | def check_inputs(self, prompt, height, width, callback_steps): 268 | if not isinstance(prompt, str) and not isinstance(prompt, list): 269 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 270 | 271 | if height % 8 != 0 or width % 8 != 0: 272 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 273 | 274 | if (callback_steps is None) or ( 275 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 276 | ): 277 | raise ValueError( 278 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 279 | f" {type(callback_steps)}." 280 | ) 281 | 282 | def prepare_latents( 283 | self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None 284 | ): 285 | shape = ( 286 | batch_size, 287 | num_channels_latents, 288 | video_length, 289 | height // self.vae_scale_factor, 290 | width // self.vae_scale_factor, 291 | ) 292 | if isinstance(generator, list) and len(generator) != batch_size: 293 | raise ValueError( 294 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 295 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
296 | ) 297 | if latents is None: 298 | rand_device = "cpu" if device.type == "mps" else device 299 | 300 | if isinstance(generator, list): 301 | shape = shape 302 | # shape = (1,) + shape[1:] 303 | latents = [ 304 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 305 | for i in range(batch_size) 306 | ] 307 | latents = torch.cat(latents, dim=0).to(device) 308 | else: 309 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 310 | else: 311 | if latents.shape != shape: 312 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 313 | latents = latents.to(device) 314 | 315 | # scale the initial noise by the standard deviation required by the scheduler 316 | latents = latents * self.scheduler.init_noise_sigma 317 | return latents 318 | 319 | @torch.no_grad() 320 | def __call__( 321 | self, 322 | prompt: Union[str, List[str]], 323 | video_length: Optional[int], 324 | height: Optional[int] = None, 325 | width: Optional[int] = None, 326 | num_inference_steps: int = 50, 327 | guidance_scale: float = 7.5, 328 | negative_prompt: Optional[Union[str, List[str]]] = None, 329 | num_videos_per_prompt: Optional[int] = 1, 330 | eta: float = 0.0, 331 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 332 | latents: Optional[torch.FloatTensor] = None, 333 | output_type: Optional[str] = "tensor", 334 | return_dict: bool = True, 335 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 336 | callback_steps: Optional[int] = 1, 337 | **kwargs, 338 | ): 339 | # Default height and width to unet 340 | height = height or self.unet.config.sample_size * self.vae_scale_factor 341 | width = width or self.unet.config.sample_size * self.vae_scale_factor 342 | 343 | # Check inputs. Raise error if not correct 344 | self.check_inputs(prompt, height, width, callback_steps) 345 | 346 | # Define call parameters 347 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 348 | batch_size = 1 349 | if latents is not None: 350 | batch_size = latents.shape[0] 351 | if isinstance(prompt, list): 352 | batch_size = len(prompt) 353 | 354 | device = self._execution_device 355 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 356 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 357 | # corresponds to doing no classifier free guidance. 358 | do_classifier_free_guidance = guidance_scale > 1.0 359 | 360 | # Encode input prompt 361 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 362 | if negative_prompt is not None: 363 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 364 | text_embeddings = self._encode_prompt( 365 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 366 | ) 367 | 368 | # Prepare timesteps 369 | self.scheduler.set_timesteps(num_inference_steps, device=device) 370 | timesteps = self.scheduler.timesteps 371 | 372 | # Prepare latent variables 373 | num_channels_latents = self.unet.in_channels 374 | latents = self.prepare_latents( 375 | batch_size * num_videos_per_prompt, 376 | num_channels_latents, 377 | video_length, 378 | height, 379 | width, 380 | text_embeddings.dtype, 381 | device, 382 | generator, 383 | latents, 384 | ) 385 | latents_dtype = latents.dtype 386 | 387 | # Prepare extra step kwargs. 
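        # prepare_extra_step_kwargs (defined above) inspects the signature of self.scheduler.step and
        # forwards `eta` and `generator` only when the scheduler declares those parameters, so
        # DDIM-specific arguments are silently dropped for schedulers that do not accept them.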
388 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 389 | 390 | # Denoising loop 391 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 392 | with self.progress_bar(total=num_inference_steps) as progress_bar: 393 | for i, t in enumerate(timesteps): 394 | # expand the latents if we are doing classifier free guidance 395 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 396 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 397 | 398 | # predict the noise residual 399 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to( 400 | dtype=latents_dtype 401 | ) 402 | 403 | # perform guidance 404 | if do_classifier_free_guidance: 405 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 406 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 407 | 408 | # compute the previous noisy sample x_t -> x_t-1 409 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 410 | 411 | # call the callback, if provided 412 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 413 | progress_bar.update() 414 | if callback is not None and i % callback_steps == 0: 415 | callback(i, t, latents) 416 | 417 | # Post-processing 418 | video = self.decode_latents(latents) 419 | 420 | # Convert to tensor 421 | if output_type == "tensor": 422 | video = torch.from_numpy(video) 423 | 424 | if not return_dict: 425 | return video 426 | 427 | return AnimationPipelineOutput(videos=video) 428 | -------------------------------------------------------------------------------- /animatediff/pipelines/validation_pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | import inspect 3 | from dataclasses import dataclass 4 | from typing import Callable, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from einops import rearrange 9 | from packaging import version 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from animatediff.models.unet import UNet3DConditionModel 15 | from animatediff.utils.util import prepare_mask_coef 16 | from diffusers.configuration_utils import FrozenDict 17 | from diffusers.models import AutoencoderKL 18 | from diffusers.pipelines import DiffusionPipeline 19 | from diffusers.schedulers import ( 20 | DDIMScheduler, 21 | DPMSolverMultistepScheduler, 22 | EulerAncestralDiscreteScheduler, 23 | EulerDiscreteScheduler, 24 | LMSDiscreteScheduler, 25 | PNDMScheduler, 26 | ) 27 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging 28 | 29 | 30 | PIL_INTERPOLATION = { 31 | "linear": Image.Resampling.BILINEAR, 32 | "bilinear": Image.Resampling.BILINEAR, 33 | "bicubic": Image.Resampling.BICUBIC, 34 | "lanczos": Image.Resampling.LANCZOS, 35 | "nearest": Image.Resampling.NEAREST, 36 | } 37 | 38 | 39 | def preprocess_image(image): 40 | if isinstance(image, torch.Tensor): 41 | return image 42 | elif isinstance(image, Image.Image): 43 | image = [image] 44 | 45 | if isinstance(image[0], Image.Image): 46 | w, h = image[0].size 47 | w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 48 | 49 | image = [np.array(i.resize((w, h), 
resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 50 | image = np.concatenate(image, axis=0) 51 | if len(image.shape) == 3: 52 | image = image.reshape(image.shape[0], image.shape[1], image.shape[2], 1) 53 | image = np.array(image).astype(np.float32) / 255.0 54 | image = image.transpose(0, 3, 1, 2) 55 | image = 2.0 * image - 1.0 56 | image = torch.from_numpy(image) 57 | elif isinstance(image[0], torch.Tensor): 58 | image = torch.cat(image, dim=0) 59 | return image 60 | 61 | 62 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 63 | 64 | 65 | @dataclass 66 | class AnimationPipelineOutput(BaseOutput): 67 | videos: Union[torch.Tensor, np.ndarray] 68 | 69 | 70 | class ValidationPipeline(DiffusionPipeline): 71 | _optional_components = [] 72 | 73 | def __init__( 74 | self, 75 | vae: AutoencoderKL, 76 | text_encoder: CLIPTextModel, 77 | tokenizer: CLIPTokenizer, 78 | unet: UNet3DConditionModel, 79 | scheduler: Union[ 80 | DDIMScheduler, 81 | PNDMScheduler, 82 | LMSDiscreteScheduler, 83 | EulerDiscreteScheduler, 84 | EulerAncestralDiscreteScheduler, 85 | DPMSolverMultistepScheduler, 86 | ], 87 | ): 88 | super().__init__() 89 | 90 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 91 | deprecation_message = ( 92 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 93 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 94 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 95 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 96 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 97 | " file" 98 | ) 99 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 100 | new_config = dict(scheduler.config) 101 | new_config["steps_offset"] = 1 102 | scheduler._internal_dict = FrozenDict(new_config) 103 | 104 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 105 | deprecation_message = ( 106 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 107 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 108 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 109 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 110 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 111 | ) 112 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 113 | new_config = dict(scheduler.config) 114 | new_config["clip_sample"] = False 115 | scheduler._internal_dict = FrozenDict(new_config) 116 | 117 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 118 | version.parse(unet.config._diffusers_version).base_version 119 | ) < version.parse("0.9.0.dev0") 120 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 121 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 122 | deprecation_message = ( 123 | "The configuration file of the unet has set the default `sample_size` to smaller than" 124 | " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" 125 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 126 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 127 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 128 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 129 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 130 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 131 | " the `unet/config.json` file" 132 | ) 133 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 134 | new_config = dict(unet.config) 135 | new_config["sample_size"] = 64 136 | unet._internal_dict = FrozenDict(new_config) 137 | 138 | self.register_modules( 139 | vae=vae, 140 | text_encoder=text_encoder, 141 | tokenizer=tokenizer, 142 | unet=unet, 143 | scheduler=scheduler, 144 | ) 145 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 146 | 147 | def enable_vae_slicing(self): 148 | self.vae.enable_slicing() 149 | 150 | def disable_vae_slicing(self): 151 | self.vae.disable_slicing() 152 | 153 | def enable_sequential_cpu_offload(self, gpu_id=0): 154 | if is_accelerate_available(): 155 | from accelerate import cpu_offload 156 | else: 157 | raise ImportError("Please install accelerate via `pip install accelerate`") 158 | 159 | device = torch.device(f"cuda:{gpu_id}") 160 | 161 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 162 | if cpu_offloaded_model is not None: 163 | cpu_offload(cpu_offloaded_model, device) 164 | 165 | @property 166 | def _execution_device(self): 167 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 168 | return self.device 169 | for module in self.unet.modules(): 170 | if ( 171 | hasattr(module, "_hf_hook") 172 | and hasattr(module._hf_hook, "execution_device") 173 | and module._hf_hook.execution_device is not None 174 | ): 175 | return torch.device(module._hf_hook.execution_device) 176 | return self.device 177 | 178 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 179 | batch_size = len(prompt) if isinstance(prompt, list) else 1 180 | 181 | text_inputs = self.tokenizer( 182 | prompt, 183 | padding="max_length", 184 | max_length=self.tokenizer.model_max_length, 185 | truncation=True, 186 | return_tensors="pt", 187 | ) 188 | text_input_ids = text_inputs.input_ids 189 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 190 | 191 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 192 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 193 | logger.warning( 194 | "The following part of your input was truncated because CLIP can only handle sequences up to" 195 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 196 | ) 197 | 198 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 199 | attention_mask = text_inputs.attention_mask.to(device) 200 | else: 201 | attention_mask = None 202 | 203 | text_embeddings = self.text_encoder( 204 | text_input_ids.to(device), 205 | attention_mask=attention_mask, 206 | ) 207 | 
text_embeddings = text_embeddings[0] 208 | 209 | # duplicate text embeddings for each generation per prompt, using mps friendly method 210 | bs_embed, seq_len, _ = text_embeddings.shape 211 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 212 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 213 | 214 | # get unconditional embeddings for classifier free guidance 215 | if do_classifier_free_guidance: 216 | uncond_tokens: List[str] 217 | if negative_prompt is None: 218 | uncond_tokens = [""] * batch_size 219 | elif type(prompt) is not type(negative_prompt): 220 | raise TypeError( 221 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 222 | f" {type(prompt)}." 223 | ) 224 | elif isinstance(negative_prompt, str): 225 | uncond_tokens = [negative_prompt] 226 | elif batch_size != len(negative_prompt): 227 | raise ValueError( 228 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 229 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 230 | " the batch size of `prompt`." 231 | ) 232 | else: 233 | uncond_tokens = negative_prompt 234 | 235 | max_length = text_input_ids.shape[-1] 236 | uncond_input = self.tokenizer( 237 | uncond_tokens, 238 | padding="max_length", 239 | max_length=max_length, 240 | truncation=True, 241 | return_tensors="pt", 242 | ) 243 | 244 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 245 | attention_mask = uncond_input.attention_mask.to(device) 246 | else: 247 | attention_mask = None 248 | 249 | uncond_embeddings = self.text_encoder( 250 | uncond_input.input_ids.to(device), 251 | attention_mask=attention_mask, 252 | ) 253 | uncond_embeddings = uncond_embeddings[0] 254 | 255 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 256 | seq_len = uncond_embeddings.shape[1] 257 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 258 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 259 | 260 | # For classifier free guidance, we need to do two forward passes. 261 | # Here we concatenate the unconditional and text embeddings into a single batch 262 | # to avoid doing two forward passes 263 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 264 | 265 | return text_embeddings 266 | 267 | def decode_latents(self, latents): 268 | video_length = latents.shape[2] 269 | latents = 1 / 0.18215 * latents 270 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 271 | # video = self.vae.decode(latents).sample 272 | video = [] 273 | for frame_idx in tqdm(range(latents.shape[0])): 274 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 275 | video = torch.cat(video) 276 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 277 | video = (video / 2 + 0.5).clamp(0, 1) 278 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 279 | video = video.cpu().float().numpy() 280 | return video 281 | 282 | def prepare_extra_step_kwargs(self, generator, eta): 283 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 284 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
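        # (background note, not derived from this file: eta = 0.0 makes DDIM sampling deterministic,
        # while eta = 1.0 recovers DDPM-like stochastic sampling)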
285 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 286 | # and should be between [0, 1] 287 | 288 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 289 | extra_step_kwargs = {} 290 | if accepts_eta: 291 | extra_step_kwargs["eta"] = eta 292 | 293 | # check if the scheduler accepts generator 294 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 295 | if accepts_generator: 296 | extra_step_kwargs["generator"] = generator 297 | return extra_step_kwargs 298 | 299 | def check_inputs(self, prompt, height, width, callback_steps): 300 | if not isinstance(prompt, str) and not isinstance(prompt, list): 301 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 302 | 303 | if height % 8 != 0 or width % 8 != 0: 304 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 305 | 306 | if (callback_steps is None) or ( 307 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 308 | ): 309 | raise ValueError( 310 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 311 | f" {type(callback_steps)}." 312 | ) 313 | 314 | def prepare_latents( 315 | self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None 316 | ): 317 | shape = ( 318 | batch_size, 319 | num_channels_latents, 320 | video_length, 321 | height // self.vae_scale_factor, 322 | width // self.vae_scale_factor, 323 | ) 324 | 325 | if isinstance(generator, list) and len(generator) != batch_size: 326 | raise ValueError( 327 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 328 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
329 | ) 330 | if latents is None: 331 | rand_device = "cpu" if device.type == "mps" else device 332 | 333 | if isinstance(generator, list): 334 | shape = shape 335 | # shape = (1,) + shape[1:] 336 | latents = [ 337 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 338 | for i in range(batch_size) 339 | ] 340 | latents = torch.cat(latents, dim=0).to(device) 341 | else: 342 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 343 | else: 344 | if latents.shape != shape: 345 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 346 | latents = latents.to(device) 347 | 348 | # scale the initial noise by the standard deviation required by the scheduler 349 | latents = latents * self.scheduler.init_noise_sigma 350 | return latents 351 | 352 | @torch.no_grad() 353 | def __call__( 354 | self, 355 | prompt: Union[str, List[str]], 356 | use_image: bool, 357 | video_length: Optional[int], 358 | height: Optional[int] = None, 359 | width: Optional[int] = None, 360 | num_inference_steps: int = 50, 361 | guidance_scale: float = 7.5, 362 | negative_prompt: Optional[Union[str, List[str]]] = None, 363 | num_videos_per_prompt: Optional[int] = 1, 364 | eta: float = 0.0, 365 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 366 | latents: Optional[torch.FloatTensor] = None, 367 | output_type: Optional[str] = "tensor", 368 | return_dict: bool = True, 369 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 370 | callback_steps: Optional[int] = 1, 371 | **kwargs, 372 | ): 373 | # Default height and width to unet 374 | height = height or self.unet.config.sample_size * self.vae_scale_factor 375 | width = width or self.unet.config.sample_size * self.vae_scale_factor 376 | 377 | # Check inputs. Raise error if not correct 378 | self.check_inputs(prompt, height, width, callback_steps) 379 | 380 | # Define call parameters 381 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 382 | batch_size = 1 383 | if latents is not None: 384 | batch_size = latents.shape[0] 385 | if isinstance(prompt, list): 386 | batch_size = len(prompt) 387 | 388 | device = self._execution_device 389 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 390 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 391 | # corresponds to doing no classifier free guidance. 
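# In the denoising loop below this resolves to one batched forward pass whose two halves are
# combined as noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond),
# so guidance_scale = 1 reduces to the plain text-conditioned prediction.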
392 | do_classifier_free_guidance = guidance_scale > 1.0 393 | 394 | # Encode input prompt 395 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 396 | if negative_prompt is not None: 397 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 398 | text_embeddings = self._encode_prompt( 399 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 400 | ) 401 | 402 | # Prepare timesteps 403 | self.scheduler.set_timesteps(num_inference_steps, device=device) 404 | timesteps = self.scheduler.timesteps 405 | 406 | # Prepare latent variables 407 | num_channels_latents = 4 408 | # num_channels_latents = self.unet.in_channels 409 | latents = self.prepare_latents( 410 | batch_size * num_videos_per_prompt, 411 | num_channels_latents, 412 | video_length, 413 | height, 414 | width, 415 | text_embeddings.dtype, 416 | device, 417 | generator, 418 | latents, 419 | ) 420 | latents_dtype = latents.dtype 421 | 422 | if use_image: 423 | shape = ( 424 | batch_size, 425 | num_channels_latents, 426 | video_length, 427 | height // self.vae_scale_factor, 428 | width // self.vae_scale_factor, 429 | ) 430 | 431 | image = Image.open(f"test_image/init_image{use_image}.png").convert("RGB") 432 | image = preprocess_image(image).to(device) 433 | if isinstance(generator, list): 434 | image_latent = [ 435 | self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size) 436 | ] 437 | image_latent = torch.cat(image_latent, dim=0).to(device=device) 438 | else: 439 | image_latent = self.vae.encode(image).latent_dist.sample(generator).to(device=device) 440 | 441 | image_latent = torch.nn.functional.interpolate(image_latent, size=[shape[-2], shape[-1]]) 442 | image_latent_padding = image_latent.clone() * 0.18215 443 | mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device) 444 | mask_coef = prepare_mask_coef(video_length, 0, kwargs["mask_sim_range"]) 445 | 446 | add_noise = torch.randn(shape).to(device) 447 | masked_image = torch.zeros(shape).to(device) 448 | for f in range(video_length): 449 | mask[:, :, f, :, :] = mask_coef[f] 450 | masked_image[:, :, f, :, :] = image_latent_padding.clone() 451 | mask = mask.to(device) 452 | else: 453 | shape = ( 454 | batch_size, 455 | num_channels_latents, 456 | video_length, 457 | height // self.vae_scale_factor, 458 | width // self.vae_scale_factor, 459 | ) 460 | add_noise = torch.zeros_like(latents).to(device) 461 | masked_image = add_noise 462 | mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device) 463 | 464 | # Prepare extra step kwargs. 
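# Besides the scheduler kwargs prepared on the next line, the mask and masked_image assembled
# above are PIA's image-conditioning inputs (the UNet takes them alongside the noisy latents).
# The mask holds one coefficient per frame from prepare_mask_coef, with the condition frame at
# index 0; frames near that frame get values close to 1.0 and are meant to stay closer to the
# input image, while distant frames are left freer to move. Both tensors are then repeated along
# the batch dimension so they match the CFG-doubled latents.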
465 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 466 | mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask 467 | masked_image = torch.cat([masked_image] * 2) if do_classifier_free_guidance else masked_image 468 | # Denoising loop 469 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 470 | with self.progress_bar(total=num_inference_steps) as progress_bar: 471 | for i, t in enumerate(timesteps): 472 | # expand the latents if we are doing classifier free guidance 473 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 474 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 475 | 476 | # predict the noise residual 477 | noise_pred = self.unet( 478 | latent_model_input, mask, masked_image, t, encoder_hidden_states=text_embeddings 479 | ).sample.to(dtype=latents_dtype) 480 | 481 | # perform guidance 482 | if do_classifier_free_guidance: 483 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 484 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 485 | 486 | # compute the previous noisy sample x_t -> x_t-1 487 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 488 | 489 | # call the callback, if provided 490 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 491 | progress_bar.update() 492 | if callback is not None and i % callback_steps == 0: 493 | callback(i, t, latents) 494 | 495 | # Post-processing 496 | video = self.decode_latents(latents) 497 | 498 | # Convert to tensor 499 | if output_type == "tensor": 500 | video = torch.from_numpy(video) 501 | 502 | if not return_dict: 503 | return video 504 | 505 | return AnimationPipelineOutput(videos=video) 506 | -------------------------------------------------------------------------------- /animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
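# Both converters in this file fold LoRA weights directly into the base model weights, i.e.
#   W <- W + alpha * (lora_up @ lora_down)
# following the merge ratio W = W0 + alpha * deltaW described by the CLI arguments below; 4D conv
# kernels are squeezed to 2D for the matrix product and expanded back afterwards.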
14 | 15 | """Conversion script for the LoRA's safetensors checkpoints.""" 16 | 17 | import argparse 18 | 19 | import torch 20 | 21 | 22 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 23 | # load base model 24 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 25 | 26 | # load LoRA weight from .safetensors 27 | # state_dict = load_file(checkpoint_path) 28 | 29 | visited = [] 30 | 31 | # directly update weight in diffusers model 32 | for key in state_dict: 33 | # it is suggested to print out the key, it usually will be something like below 34 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 35 | 36 | # as we have set the alpha beforehand, so just skip 37 | if ".alpha" in key or key in visited: 38 | continue 39 | 40 | if "text" in key: 41 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 42 | curr_layer = pipeline.text_encoder 43 | else: 44 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 45 | curr_layer = pipeline.unet 46 | 47 | # find the target layer 48 | temp_name = layer_infos.pop(0) 49 | while len(layer_infos) > -1: 50 | try: 51 | curr_layer = curr_layer.__getattr__(temp_name) 52 | if len(layer_infos) > 0: 53 | temp_name = layer_infos.pop(0) 54 | elif len(layer_infos) == 0: 55 | break 56 | except Exception: 57 | if len(temp_name) > 0: 58 | temp_name += "_" + layer_infos.pop(0) 59 | else: 60 | temp_name = layer_infos.pop(0) 61 | 62 | pair_keys = [] 63 | if "lora_down" in key: 64 | pair_keys.append(key.replace("lora_down", "lora_up")) 65 | pair_keys.append(key) 66 | else: 67 | pair_keys.append(key) 68 | pair_keys.append(key.replace("lora_up", "lora_down")) 69 | 70 | # update weight 71 | if len(state_dict[pair_keys[0]].shape) == 4: 72 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 73 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 74 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( 75 | curr_layer.weight.data.device 76 | ) 77 | else: 78 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 79 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 80 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 81 | 82 | # update visited list 83 | for item in pair_keys: 84 | visited.append(item) 85 | 86 | return pipeline 87 | 88 | 89 | def convert_lora_model_level( 90 | state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6 91 | ): 92 | """convert lora in model level instead of pipeline leval""" 93 | 94 | visited = [] 95 | 96 | # directly update weight in diffusers model 97 | for key in state_dict: 98 | # it is suggested to print out the key, it usually will be something like below 99 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 100 | 101 | # as we have set the alpha beforehand, so just skip 102 | if ".alpha" in key or key in visited: 103 | continue 104 | 105 | if "text" in key: 106 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 107 | assert text_encoder is not None, "text_encoder must be passed since lora contains text encoder layers" 108 | curr_layer = text_encoder 109 | else: 110 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 111 | curr_layer = unet 112 | 113 
| # find the target layer 114 | temp_name = layer_infos.pop(0) 115 | while len(layer_infos) > -1: 116 | try: 117 | curr_layer = curr_layer.__getattr__(temp_name) 118 | if len(layer_infos) > 0: 119 | temp_name = layer_infos.pop(0) 120 | elif len(layer_infos) == 0: 121 | break 122 | except Exception: 123 | if len(temp_name) > 0: 124 | temp_name += "_" + layer_infos.pop(0) 125 | else: 126 | temp_name = layer_infos.pop(0) 127 | 128 | pair_keys = [] 129 | if "lora_down" in key: 130 | pair_keys.append(key.replace("lora_down", "lora_up")) 131 | pair_keys.append(key) 132 | else: 133 | pair_keys.append(key) 134 | pair_keys.append(key.replace("lora_up", "lora_down")) 135 | 136 | # update weight 137 | # NOTE: load lycon, maybe have bugs :( 138 | if "conv_in" in pair_keys[0]: 139 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 140 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 141 | weight_up = weight_up.view(weight_up.size(0), -1) 142 | weight_down = weight_down.view(weight_down.size(0), -1) 143 | shape = list(curr_layer.weight.data.shape) 144 | shape[1] = 4 145 | curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape) 146 | elif "conv" in pair_keys[0]: 147 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 148 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 149 | weight_up = weight_up.view(weight_up.size(0), -1) 150 | weight_down = weight_down.view(weight_down.size(0), -1) 151 | shape = list(curr_layer.weight.data.shape) 152 | curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) 153 | elif len(state_dict[pair_keys[0]].shape) == 4: 154 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 155 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 156 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( 157 | curr_layer.weight.data.device 158 | ) 159 | else: 160 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 161 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 162 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 163 | 164 | # update visited list 165 | for item in pair_keys: 166 | visited.append(item) 167 | 168 | return unet, text_encoder 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | 174 | parser.add_argument( 175 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 176 | ) 177 | parser.add_argument( 178 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 179 | ) 180 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 181 | parser.add_argument( 182 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 183 | ) 184 | parser.add_argument( 185 | "--lora_prefix_text_encoder", 186 | default="lora_te", 187 | type=str, 188 | help="The prefix of text encoder weight in safetensors", 189 | ) 190 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 191 | parser.add_argument( 192 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 193 | ) 194 | parser.add_argument("--device", type=str, help="Device to use (e.g. 
cpu, cuda:0, cuda:1, etc.)") 195 | 196 | args = parser.parse_args() 197 | 198 | base_model_path = args.base_model_path 199 | checkpoint_path = args.checkpoint_path 200 | dump_path = args.dump_path 201 | lora_prefix_unet = args.lora_prefix_unet 202 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 203 | alpha = args.alpha 204 | 205 | pipe = convert_lora_model_level( 206 | base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha 207 | ) 208 | 209 | pipe = pipe.to(args.device) 210 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 211 | -------------------------------------------------------------------------------- /animatediff/utils/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Optional, Union 4 | 5 | import cv2 6 | import imageio 7 | import moviepy.editor as mpy 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | import torchvision 12 | from einops import rearrange 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | 17 | # We recommend to use the following affinity score(motion magnitude) 18 | # Also encourage to try to construct different score by yourself 19 | RANGE_LIST = [ 20 | [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion 21 | [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion 22 | [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion 23 | [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop 24 | [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop 25 | [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop 26 | [0.5, 0.2], # Style Transfer Large Motion 27 | [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion 28 | [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion 29 | ] 30 | 31 | 32 | def zero_rank_print(s): 33 | if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): 34 | print("### " + s) 35 | 36 | 37 | def save_videos_mp4(video: torch.Tensor, path: str, fps: int = 8): 38 | video = rearrange(video, "b c t h w -> t b c h w") 39 | num_frames, batch_size, channels, height, width = video.shape 40 | assert batch_size == 1, "Only support batch size == 1" 41 | video = video.squeeze(1) 42 | video = rearrange(video, "t c h w -> t h w c") 43 | 44 | def make_frame(t): 45 | frame_tensor = video[int(t * fps)] 46 | frame_np = (frame_tensor * 255).numpy().astype("uint8") 47 | return frame_np 48 | 49 | clip = mpy.VideoClip(make_frame, duration=num_frames / fps) 50 | clip.write_videofile(path, fps=fps, codec="libx264") 51 | 52 | 53 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 54 | videos = rearrange(videos, "b c t h w -> t b c h w") 55 | outputs = [] 56 | for x in videos: 57 | x = torchvision.utils.make_grid(x, nrow=n_rows) 58 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 59 | if rescale: 60 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 61 | x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8) 62 | outputs.append(x) 63 | 64 | os.makedirs(os.path.dirname(path), exist_ok=True) 65 | imageio.mimsave(path, outputs, fps=fps) 66 | 67 | 68 | # DDIM Inversion 69 | @torch.no_grad() 70 | def init_prompt(prompt, pipeline): 71 | uncond_input = pipeline.tokenizer( 72 | [""], padding="max_length", 
max_length=pipeline.tokenizer.model_max_length, return_tensors="pt" 73 | ) 74 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 75 | text_input = pipeline.tokenizer( 76 | [prompt], 77 | padding="max_length", 78 | max_length=pipeline.tokenizer.model_max_length, 79 | truncation=True, 80 | return_tensors="pt", 81 | ) 82 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 83 | context = torch.cat([uncond_embeddings, text_embeddings]) 84 | 85 | return context 86 | 87 | 88 | def next_step( 89 | model_output: Union[torch.FloatTensor, np.ndarray], 90 | timestep: int, 91 | sample: Union[torch.FloatTensor, np.ndarray], 92 | ddim_scheduler, 93 | ): 94 | timestep, next_timestep = ( 95 | min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), 96 | timestep, 97 | ) 98 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 99 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 100 | beta_prod_t = 1 - alpha_prod_t 101 | next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 102 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 103 | next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction 104 | return next_sample 105 | 106 | 107 | def get_noise_pred_single(latents, t, context, unet): 108 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 109 | return noise_pred 110 | 111 | 112 | @torch.no_grad() 113 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 114 | context = init_prompt(prompt, pipeline) 115 | uncond_embeddings, cond_embeddings = context.chunk(2) 116 | all_latent = [latent] 117 | latent = latent.clone().detach() 118 | for i in tqdm(range(num_inv_steps)): 119 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 120 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 121 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 122 | all_latent.append(latent) 123 | return all_latent 124 | 125 | 126 | @torch.no_grad() 127 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 128 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 129 | return ddim_latents 130 | 131 | 132 | def prepare_mask_coef(video_length: int, cond_frame: int, sim_range: list = [0.2, 1.0]): 133 | assert len(sim_range) == 2, "sim_range should has the length of 2, including the min and max similarity" 134 | 135 | assert video_length > 1, "video_length should be greater than 1" 136 | 137 | assert video_length > cond_frame, "video_length should be greater than cond_frame" 138 | 139 | diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1) 140 | coef = [1.0] * video_length 141 | for f in range(video_length): 142 | f_diff = diff * abs(cond_frame - f) 143 | f_diff = 1 - f_diff 144 | coef[f] *= f_diff 145 | 146 | return coef 147 | 148 | 149 | def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int): 150 | assert video_length > 0, "video_length should be greater than 0" 151 | 152 | assert video_length > cond_frame, "video_length should be greater than cond_frame" 153 | 154 | range_list = RANGE_LIST 155 | 156 | assert sim_range < len(range_list), f"sim_range type{sim_range} not implemented" 157 | 158 | coef = range_list[sim_range] 159 | coef = coef + ([coef[-1]] * 
(video_length - len(coef))) 160 | 161 | order = [abs(i - cond_frame) for i in range(video_length)] 162 | coef = [coef[order[i]] for i in range(video_length)] 163 | 164 | return coef 165 | 166 | 167 | def prepare_mask_coef_multi_cond(video_length: int, cond_frames: list, sim_range: list = [0.2, 1.0]): 168 | assert len(sim_range) == 2, "sim_range should has the length of 2, including the min and max similarity" 169 | 170 | assert video_length > 1, "video_length should be greater than 1" 171 | 172 | assert isinstance(cond_frames, list), "cond_frames should be a list" 173 | 174 | assert video_length > max(cond_frames), "video_length should be greater than cond_frame" 175 | 176 | if max(sim_range) == min(sim_range): 177 | cond_coefs = [sim_range[0]] * video_length 178 | return cond_coefs 179 | 180 | cond_coefs = [] 181 | 182 | for cond_frame in cond_frames: 183 | cond_coef = prepare_mask_coef(video_length, cond_frame, sim_range) 184 | cond_coefs.append(cond_coef) 185 | 186 | mixed_coef = [0] * video_length 187 | for conds in range(len(cond_frames)): 188 | for f in range(video_length): 189 | mixed_coef[f] = abs(cond_coefs[conds][f] - mixed_coef[f]) 190 | 191 | if conds > 0: 192 | min_num = min(mixed_coef) 193 | max_num = max(mixed_coef) 194 | 195 | for f in range(video_length): 196 | mixed_coef[f] = (mixed_coef[f] - min_num) / (max_num - min_num) 197 | 198 | mixed_max = max(mixed_coef) 199 | mixed_min = min(mixed_coef) 200 | for f in range(video_length): 201 | mixed_coef[f] = (max(sim_range) - min(sim_range)) * (mixed_coef[f] - mixed_min) / ( 202 | mixed_max - mixed_min 203 | ) + min(sim_range) 204 | 205 | mixed_coef = [ 206 | x if min(sim_range) <= x <= max(sim_range) else min(sim_range) if x < min(sim_range) else max(sim_range) 207 | for x in mixed_coef 208 | ] 209 | 210 | return mixed_coef 211 | 212 | 213 | def prepare_masked_latent_cond(video_length: int, cond_frames: list): 214 | for cond_frame in cond_frames: 215 | assert cond_frame < video_length, "cond_frame should be smaller than video_length" 216 | assert cond_frame > -1, f"cond_frame should be in the range of [0, {video_length}]" 217 | 218 | cond_frames.sort() 219 | nearest = [cond_frames[0]] * video_length 220 | for f in range(video_length): 221 | for cond_frame in cond_frames: 222 | if abs(nearest[f] - f) > abs(cond_frame - f): 223 | nearest[f] = cond_frame 224 | 225 | maked_latent_cond = nearest 226 | 227 | return maked_latent_cond 228 | 229 | 230 | def estimated_kernel_size(frame_width: int, frame_height: int) -> int: 231 | """Estimate kernel size based on video resolution.""" 232 | # TODO: This equation is based on manual estimation from a few videos. 233 | # Create a more comprehensive test suite to optimize against. 234 | size: int = 4 + round(math.sqrt(frame_width * frame_height) / 192) 235 | if size % 2 == 0: 236 | size += 1 237 | return size 238 | 239 | 240 | def detect_edges(lum: np.ndarray) -> np.ndarray: 241 | """Detect edges using the luma channel of a frame. 242 | 243 | Arguments: 244 | lum: 2D 8-bit image representing the luma channel of a frame. 245 | 246 | Returns: 247 | 2D 8-bit image of the same size as the input, where pixels with values of 255 248 | represent edges, and all other pixels are 0. 249 | """ 250 | # Initialize kernel. 251 | kernel_size = estimated_kernel_size(lum.shape[1], lum.shape[0]) 252 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 253 | 254 | # Estimate levels for thresholding. 255 | # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc. 
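# Median-based automatic thresholding (a common Canny heuristic): with sigma = 1/3 the low and
# high thresholds are set to (1 - sigma) and (1 + sigma) times the median luma value, clamped to
# the valid [0, 255] range, before being passed to cv2.Canny below.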
256 | sigma: float = 1.0 / 3.0 257 | median = np.median(lum) 258 | low = int(max(0, (1.0 - sigma) * median)) 259 | high = int(min(255, (1.0 + sigma) * median)) 260 | 261 | # Calculate edges using Canny algorithm, and reduce noise by dilating the edges. 262 | # This increases edge overlap leading to improved robustness against noise and slow 263 | # camera movement. Note that very large kernel sizes can negatively affect accuracy. 264 | edges = cv2.Canny(lum, low, high) 265 | return cv2.dilate(edges, kernel) 266 | 267 | 268 | def prepare_mask_coef_by_score( 269 | video_shape: list, 270 | cond_frame_idx: list, 271 | sim_range: list = [0.2, 1.0], 272 | statistic: list = [1, 100], 273 | coef_max: int = 0.98, 274 | score: Optional[torch.Tensor] = None, 275 | ): 276 | """ 277 | the shape of video_data is (b f c h w) 278 | cond_frame_idx is a list, with length of batch_size 279 | the shape of statistic is (f 2) 280 | the shape of score is (b f) 281 | the shape of coef is (b f) 282 | """ 283 | assert ( 284 | len(video_shape) == 2 285 | ), f"the shape of video_shape should be (b f c h w), but now get {len(video_shape.shape)} channels" 286 | 287 | batch_size, frame_num = video_shape[0], video_shape[1] 288 | 289 | score = score.permute(0, 2, 1).squeeze(0) 290 | 291 | # list -> b 1 292 | cond_fram_mat = torch.tensor(cond_frame_idx).unsqueeze(-1) 293 | 294 | statistic = torch.tensor(statistic) 295 | # (f 2) -> (b f 2) 296 | statistic = statistic.repeat(batch_size, 1, 1) 297 | 298 | # shape of order (b f), shape of cond_mat (b f) 299 | order = torch.arange(0, frame_num, 1) 300 | order = order.repeat(batch_size, 1) 301 | cond_mat = torch.ones((batch_size, frame_num)) * cond_fram_mat 302 | order = abs(order - cond_mat) 303 | 304 | statistic = statistic[:, order.to(torch.long)][0, :, :, :] 305 | 306 | # score (b f) max_s (b f 1) 307 | max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype) 308 | min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype) 309 | 310 | score[score > max_stats] = max_stats[score > max_stats] * 0.95 311 | score[score < min_stats] = min_stats[score < min_stats] 312 | 313 | eps = 1e-10 314 | coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range))) 315 | 316 | indices = torch.arange(coef.shape[0]).unsqueeze(1) 317 | coef[indices, cond_fram_mat] = 1.0 318 | 319 | return coef 320 | 321 | 322 | def preprocess_img(img_path, max_size: int = 512): 323 | ori_image = Image.open(img_path).convert("RGB") 324 | 325 | width, height = ori_image.size 326 | 327 | long_edge = max(width, height) 328 | if long_edge > max_size: 329 | scale_factor = max_size / long_edge 330 | else: 331 | scale_factor = 1 332 | width = int(width * scale_factor) 333 | height = int(height * scale_factor) 334 | ori_image = ori_image.resize((width, height)) 335 | 336 | if (width % 8 != 0) or (height % 8 != 0): 337 | in_width = (width // 8) * 8 338 | in_height = (height // 8) * 8 339 | else: 340 | in_width = width 341 | in_height = height 342 | in_image = ori_image 343 | 344 | in_image = ori_image.resize((in_width, in_height)) 345 | # in_image = ori_image.resize((512, 512)) 346 | in_image_np = np.array(in_image) 347 | return in_image_np, in_height, in_width 348 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import random 5 | from argparse import ArgumentParser 6 | from datetime import datetime 7 
| from glob import glob 8 | 9 | import gradio as gr 10 | import numpy as np 11 | import torch 12 | from omegaconf import OmegaConf 13 | from PIL import Image 14 | 15 | from animatediff.pipelines import I2VPipeline 16 | from animatediff.utils.util import save_videos_grid 17 | from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler 18 | 19 | 20 | sample_idx = 0 21 | scheduler_dict = { 22 | "DDIM": DDIMScheduler, 23 | "Euler": EulerDiscreteScheduler, 24 | "PNDM": PNDMScheduler, 25 | } 26 | 27 | css = """ 28 | .toolbutton { 29 | margin-buttom: 0em 0em 0em 0em; 30 | max-width: 2.5em; 31 | min-width: 2.5em !important; 32 | height: 2.5em; 33 | } 34 | """ 35 | 36 | parser = ArgumentParser() 37 | parser.add_argument("--config", type=str, default="example/config/base.yaml") 38 | parser.add_argument("--server-name", type=str, default="0.0.0.0") 39 | parser.add_argument("--port", type=int, default=7860) 40 | parser.add_argument("--share", action="store_true") 41 | 42 | parser.add_argument("--save-path", default="samples") 43 | 44 | args = parser.parse_args() 45 | 46 | 47 | N_PROMPT = ( 48 | "wrong white balance, dark, sketches,worst quality,low quality, " 49 | "deformed, distorted, disfigured, bad eyes, wrong lips, " 50 | "weird mouth, bad teeth, mutated hands and fingers, bad anatomy," 51 | "wrong anatomy, amputation, extra limb, missing limb, " 52 | "floating,limbs, disconnected limbs, mutation, ugly, disgusting, " 53 | "bad_pictures, negative_hand-neg" 54 | ) 55 | 56 | 57 | def preprocess_img(img_np, max_size: int = 512): 58 | ori_image = Image.fromarray(img_np).convert("RGB") 59 | 60 | width, height = ori_image.size 61 | 62 | long_edge = max(width, height) 63 | if long_edge > max_size: 64 | scale_factor = max_size / long_edge 65 | else: 66 | scale_factor = 1 67 | width = int(width * scale_factor) 68 | height = int(height * scale_factor) 69 | ori_image = ori_image.resize((width, height)) 70 | 71 | if (width % 8 != 0) or (height % 8 != 0): 72 | in_width = (width // 8) * 8 73 | in_height = (height // 8) * 8 74 | else: 75 | in_width = width 76 | in_height = height 77 | in_image = ori_image 78 | 79 | in_image = ori_image.resize((in_width, in_height)) 80 | in_image_np = np.array(in_image) 81 | return in_image_np, in_height, in_width 82 | 83 | 84 | class AnimateController: 85 | def __init__(self): 86 | # config dirs 87 | self.basedir = os.getcwd() 88 | self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA") 89 | self.ip_adapter_dir = os.path.join(self.basedir, "models", "IP_Adapter") 90 | self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) 91 | self.savedir_sample = os.path.join(self.savedir, "sample") 92 | os.makedirs(self.savedir, exist_ok=True) 93 | 94 | self.stable_diffusion_list = [] 95 | self.motion_module_list = [] 96 | self.personalized_model_list = [] 97 | 98 | self.refresh_personalized_model() 99 | 100 | self.pipeline = None 101 | 102 | self.inference_config = OmegaConf.load(args.config) 103 | self.stable_diffusion_dir = self.inference_config.pretrained_model_path 104 | self.pia_path = self.inference_config.generate.model_path 105 | self.loaded = False 106 | 107 | def refresh_personalized_model(self): 108 | personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) 109 | self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] 110 | 111 | def get_ip_apdater_folder(self): 112 | file_list = os.listdir(self.ip_adapter_dir) 113 | if 
not file_list: 114 | return False 115 | 116 | if "ip-adapter_sd15.bin" not in file_list: 117 | print('Cannot find "ip-adapter_sd15.bin" ' f"under {self.ip_adapter_dir}") 118 | return False 119 | if "image_encoder" not in file_list: 120 | print(f'Cannot find "image_encoder" under {self.ip_adapter_dir}') 121 | return False 122 | 123 | return True 124 | 125 | def load_model(self, dreambooth_path=None, lora_path=None, lora_alpha=1.0, enable_ip_adapter=True): 126 | gr.Info("Start Load Models...") 127 | print("Start Load Models...") 128 | 129 | if lora_path and lora_path.upper() != "NONE": 130 | lora_path = osp.join(self.personalized_model_dir, lora_path) 131 | else: 132 | lora_path = None 133 | 134 | if dreambooth_path and dreambooth_path.upper() != "NONE": 135 | dreambooth_path = osp.join(self.personalized_model_dir, dreambooth_path) 136 | else: 137 | dreambooth_path = None 138 | 139 | if enable_ip_adapter: 140 | if not self.get_ip_apdater_folder(): 141 | print("Load IP-Adapter from remote.") 142 | ip_adapter_path = "h94/IP-Adapter" 143 | else: 144 | ip_adapter_path = self.ip_adapter_dir 145 | else: 146 | ip_adapter_path = None 147 | 148 | self.pipeline = I2VPipeline.build_pipeline( 149 | self.inference_config, 150 | self.stable_diffusion_dir, 151 | unet_path=self.pia_path, 152 | dreambooth_path=dreambooth_path, 153 | lora_path=lora_path, 154 | lora_alpha=lora_alpha, 155 | ip_adapter_path=ip_adapter_path, 156 | ) 157 | gr.Info("Load Finish!") 158 | print("Load Finish!") 159 | self.loaded = True 160 | 161 | return "Load" 162 | 163 | def animate( 164 | self, 165 | init_img, 166 | motion_scale, 167 | prompt_textbox, 168 | negative_prompt_textbox, 169 | sampler_dropdown, 170 | sample_step_slider, 171 | length_slider, 172 | cfg_scale_slider, 173 | seed_textbox, 174 | ip_adapter_scale, 175 | max_size, 176 | progress=gr.Progress(), 177 | ): 178 | if not self.loaded: 179 | raise gr.Error("Please load model first!") 180 | 181 | if seed_textbox != -1 and seed_textbox != "": 182 | torch.manual_seed(int(seed_textbox)) 183 | else: 184 | torch.seed() 185 | seed = torch.initial_seed() 186 | init_img, h, w = preprocess_img(init_img, max_size) 187 | sample = self.pipeline( 188 | image=init_img, 189 | prompt=prompt_textbox, 190 | negative_prompt=negative_prompt_textbox, 191 | num_inference_steps=sample_step_slider, 192 | guidance_scale=cfg_scale_slider, 193 | width=w, 194 | height=h, 195 | video_length=16, 196 | mask_sim_template_idx=motion_scale, 197 | ip_adapter_scale=ip_adapter_scale, 198 | progress_fn=progress, 199 | ).videos 200 | 201 | save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4") 202 | save_videos_grid(sample, save_sample_path) 203 | 204 | sample_config = { 205 | "prompt": prompt_textbox, 206 | "n_prompt": negative_prompt_textbox, 207 | "sampler": sampler_dropdown, 208 | "num_inference_steps": sample_step_slider, 209 | "guidance_scale": cfg_scale_slider, 210 | "width": w, 211 | "height": h, 212 | "video_length": length_slider, 213 | "seed": seed, 214 | "motion": motion_scale, 215 | } 216 | json_str = json.dumps(sample_config, indent=4) 217 | with open(os.path.join(self.savedir, "logs.json"), "a") as f: 218 | f.write(json_str) 219 | f.write("\n\n") 220 | 221 | return save_sample_path 222 | 223 | 224 | controller = AnimateController() 225 | 226 | 227 | def ui(): 228 | with gr.Blocks(css=css) as demo: 229 | motion_idx = gr.State(0) 230 | gr.HTML( 231 | "
    Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models
    " 232 | ) 233 | with gr.Row(): 234 | gr.Markdown( 235 | "
    Project Page  " # noqa 236 | "Paper  " 237 | "Code  " # noqa 238 | "Demo
    " # noqa 239 | ) 240 | 241 | with gr.Column(variant="panel"): 242 | gr.Markdown( 243 | """ 244 | ### 1. Model checkpoints (select pretrained model path first). 245 | """ 246 | ) 247 | with gr.Row(): 248 | base_model_dropdown = gr.Dropdown( 249 | label="Select base Dreambooth model", 250 | choices=["none"] + controller.personalized_model_list, 251 | value="none", 252 | interactive=True, 253 | ) 254 | 255 | lora_model_dropdown = gr.Dropdown( 256 | label="Select LoRA model (optional)", 257 | choices=["none"] + controller.personalized_model_list, 258 | value="none", 259 | interactive=True, 260 | ) 261 | 262 | lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0, minimum=0, maximum=2, interactive=True) 263 | 264 | personalized_refresh_button = gr.Button(value="\U0001f503", elem_classes="toolbutton") 265 | 266 | def update_personalized_model(): 267 | controller.refresh_personalized_model() 268 | return [controller.personalized_model_list, ["none"] + controller.personalized_model_list] 269 | 270 | personalized_refresh_button.click( 271 | fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown] 272 | ) 273 | 274 | load_model_button = gr.Button(value="Load") 275 | load_model_button.click( 276 | fn=controller.load_model, 277 | inputs=[ 278 | base_model_dropdown, 279 | lora_model_dropdown, 280 | lora_alpha_slider, 281 | ], 282 | outputs=[load_model_button], 283 | ) 284 | 285 | with gr.Column(variant="panel"): 286 | gr.Markdown( 287 | """ 288 | ### 2. Configs for PIA. 289 | """ 290 | ) 291 | 292 | prompt_textbox = gr.Textbox(label="Prompt", lines=2) 293 | negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1) 294 | 295 | with gr.Row(equal_height=False): 296 | with gr.Column(): 297 | with gr.Row(): 298 | init_img = gr.Image(label="Input Image") 299 | 300 | with gr.Row(): 301 | sampler_dropdown = gr.Dropdown( 302 | label="Sampling method", 303 | choices=list(scheduler_dict.keys()), 304 | value=list(scheduler_dict.keys())[0], 305 | ) 306 | sample_step_slider = gr.Slider( 307 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1 308 | ) 309 | 310 | max_size_slider = gr.Slider( 311 | label="Max size (The long edge of the input image will be resized to this value, larger value means slower inference speed)", 312 | value=512, 313 | step=64, 314 | minimum=512, 315 | maximum=1024, 316 | ) 317 | 318 | length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1) 319 | cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20) 320 | motion_scale_silder = gr.Slider( 321 | label="Motion Scale", value=motion_idx.value, step=1, minimum=0, maximum=2 322 | ) 323 | ip_adapter_scale = gr.Slider(label="IP-Apdater Scale", value=0.0, minimum=0, maximum=1) 324 | 325 | def GenerationMode(motion_scale_silder, option): 326 | if option == "Animation": 327 | motion_idx = motion_scale_silder 328 | elif option == "Style Transfer": 329 | motion_idx = motion_scale_silder * -1 - 1 330 | elif option == "Loop Video": 331 | motion_idx = motion_scale_silder + 3 332 | return motion_idx 333 | 334 | with gr.Row(): 335 | style_selection = gr.Radio( 336 | ["Animation", "Style Transfer", "Loop Video"], label="Generation Mode", value="Animation" 337 | ) 338 | style_selection.change( 339 | fn=GenerationMode, inputs=[motion_scale_silder, style_selection], outputs=[motion_idx] 340 | ) 341 | motion_scale_silder.change( 342 | fn=GenerationMode, inputs=[motion_scale_silder, style_selection], 
outputs=[motion_idx] 343 | ) 344 | 345 | with gr.Row(): 346 | seed_textbox = gr.Textbox(label="Seed", value=-1) 347 | seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton") 348 | seed_button.click(fn=lambda x: random.randint(1, 1e8), outputs=[seed_textbox], queue=False) 349 | 350 | generate_button = gr.Button(value="Generate", variant="primary") 351 | 352 | result_video = gr.Video(label="Generated Animation", interactive=False) 353 | 354 | generate_button.click( 355 | fn=controller.animate, 356 | inputs=[ 357 | init_img, 358 | motion_idx, 359 | prompt_textbox, 360 | negative_prompt_textbox, 361 | sampler_dropdown, 362 | sample_step_slider, 363 | length_slider, 364 | cfg_scale_slider, 365 | seed_textbox, 366 | ip_adapter_scale, 367 | max_size_slider, 368 | ], 369 | outputs=[result_video], 370 | ) 371 | 372 | return demo 373 | 374 | 375 | if __name__ == "__main__": 376 | demo = ui() 377 | demo.queue(3) 378 | demo.launch(server_name=args.server_name, server_port=args.port, share=args.share, allowed_paths=["pia.png"]) 379 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | python_version: "3.11" 7 | system_packages: 8 | - "libgl1-mesa-glx" 9 | - "libglib2.0-0" 10 | python_packages: 11 | - torch==2.0.1 12 | - torchvision==0.15.2 13 | - diffusers==0.24.0 14 | - transformers==4.36.0 15 | - accelerate==0.25.0 16 | - imageio==2.27.0 17 | - decord==0.6.0 18 | - einops==0.7.0 19 | - omegaconf==2.3.0 20 | - safetensors==0.4.1 21 | - opencv-python==4.8.1.78 22 | - moviepy==1.0.3 23 | run: 24 | - pip install xformers 25 | predict: "predict.py:Predictor" 26 | -------------------------------------------------------------------------------- /download_bashscripts/1-RealisticVision.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors https://huggingface.co/frankjoshua/realisticVisionV51_v51VAE/resolve/main/realisticVisionV51_v51VAE.safetensors?download=true --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /download_bashscripts/2-RcnzCartoon.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors https://civitai.com/api/download/models/71009 --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /download_bashscripts/3-MajicMix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors https://civitai.com/api/download/models/82446 --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /environment-pt2.yaml: -------------------------------------------------------------------------------- 1 | name: pia 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.10 7 | - pytorch=2.0.0 8 | - torchvision=0.15.0 9 | - pytorch-cuda=11.8 10 | - pip 11 | - pip: 12 | - diffusers==0.24.0 13 | - transformers==4.25.1 14 | - xformers 15 | - imageio==2.33.1 16 | - decord==0.6.0 
17 | - gdown 18 | - einops 19 | - omegaconf 20 | - safetensors 21 | - gradio 22 | - wandb 23 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: pia 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.10 7 | - pytorch=1.13.1 8 | - torchvision=0.14.1 9 | - torchaudio=0.13.1 10 | - pytorch-cuda=11.7 11 | - pip 12 | - pip: 13 | - diffusers==0.24.0 14 | - transformers==4.25.1 15 | - xformers==0.0.16 16 | - imageio==2.27.0 17 | - decord==0.6.0 18 | - gdown 19 | - einops 20 | - omegaconf 21 | - safetensors 22 | - gradio 23 | - wandb 24 | -------------------------------------------------------------------------------- /example/config/anya.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - 1girl smiling 4 | - 1girl open mouth 5 | - 1girl crying, pout 6 | n_prompt: 7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 8 | validation_data: 9 | input_name: 'anya' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [-1] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 10201304011203481429 17 | lora_path: "models/DreamBooth_LoRA/cyberpunk.safetensors" 18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/base.yaml: -------------------------------------------------------------------------------- 1 | generate: 2 | model_path: "models/PIA/pia.ckpt" 3 | use_image: true 4 | use_video: false 5 | sample_width: 512 6 | sample_height: 512 7 | video_length: 16 8 | 9 | validation_data: 10 | mask_sim_range: [0, 1] 11 | cond_frame: 0 12 | num_inference_steps: 25 13 | 14 | img_mask: '' 15 | 16 | noise_scheduler_kwargs: 17 | num_train_timesteps: 1000 18 | beta_start: 0.00085 19 | beta_end: 0.012 20 | beta_schedule: "linear" 21 | steps_offset: 1 22 | clip_sample: false 23 | 24 | pretrained_model_path: "models/StableDiffusion/" 25 | unet_additional_kwargs: 26 | use_motion_module : true 27 | motion_module_resolutions : [ 1,2,4,8 ] 28 | unet_use_cross_frame_attention : false 29 | unet_use_temporal_attention : false 30 | 31 | motion_module_type: Vanilla 32 | motion_module_kwargs: 33 | num_attention_heads : 8 34 | num_transformer_block : 1 35 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 36 | temporal_position_encoding : true 37 | temporal_position_encoding_max_len : 32 38 | temporal_attention_dim_div : 1 39 | zero_initialize : true 40 | -------------------------------------------------------------------------------- /example/config/bear.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - 1bear walking in a shop, best quality, 4k 4 | n_prompt: 5 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, 
floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 6 | validation_data: 7 | input_name: 'bear' 8 | validation_input_path: 'example/img' 9 | save_path: 'example/result' 10 | mask_sim_range: [0, 1, 2] 11 | generate: 12 | use_lora: false 13 | use_db: true 14 | global_seed: 10201034102130841429 15 | lora_path: "models/DreamBooth_LoRA/cyberpunk.safetensors" 16 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 17 | lora_alpha: 0.4 18 | -------------------------------------------------------------------------------- /example/config/concert.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - 1man is smiling, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose 4 | - 1man is crying, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose 5 | - 1man is singing, masterpiece, best quality, 1boy, afro, dark skin, playing guitar, concert, upper body, sweat, stage lights, oversized hawaiian shirt, intricate, print, pattern, happy, necklace, bokeh, jeans, drummer, dynamic pose 6 | n_prompt: 7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 8 | validation_data: 9 | input_name: 'concert' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [-3] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 4292543217695451092 # To get 3d style shown in github, you can use seed: 4292543217695451088 17 | lora_path: "" 18 | db_path: "models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/genshin.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k 4 | n_prompt: 5 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 6 | validation_data: 7 | input_name: 'genshin' 8 | validation_input_path: 'example/img' 9 | save_path: 'example/result' 10 | mask_sim_range: [0, 1, 2] 11 | generate: 12 | use_lora: false 13 | use_db: true 14 | sample_width: 512 15 | sample_height: 768 16 | global_seed: 10041042941301238026 17 | lora_path: "models/DreamBooth_LoRA/genshin.safetensors" 18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 19 | lora_alpha: 0.4 20 | -------------------------------------------------------------------------------- /example/config/harry.yaml: -------------------------------------------------------------------------------- 1 | 
base: 'example/config/base.yaml' 2 | prompts: 3 | - - 1boy smiling 4 | - 1boy playing magic fire 5 | - 1boy is waving hands 6 | n_prompt: 7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 8 | validation_data: 9 | input_name: 'harry' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [1] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 10201403011320481249 17 | lora_path: "" 18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/labrador.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - a golden labrador jump 4 | - a golden labrador walking 5 | - a golden labrador is running 6 | n_prompt: 7 | - 'collar, leashes, collars, wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly' 8 | validation_data: 9 | input_name: 'labrador' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [0, 1, 2] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 4292543217695451000 17 | lora_path: "" 18 | db_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/lighthouse.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - lightning, lighthouse 4 | - sun rising, lighthouse 5 | - fireworks, lighthouse 6 | n_prompt: 7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 8 | validation_data: 9 | input_name: 'lighthouse' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [0] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 5658137986800322011 17 | lora_path: "" 18 | db_path: "models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/majic_girl.yaml: -------------------------------------------------------------------------------- 1 | base: 'example/config/base.yaml' 2 | prompts: 3 | - - 1girl is smiling, lowres,watermark 4 | - 1girl is crying, lowres,watermark 5 | - 1girl, snowing dark night 6 | n_prompt: 7 | - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected 
limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 8 | validation_data: 9 | input_name: 'majic_girl' 10 | validation_input_path: 'example/img' 11 | save_path: 'example/result' 12 | mask_sim_range: [1] 13 | generate: 14 | use_lora: false 15 | use_db: true 16 | global_seed: 10021403011302841249 17 | lora_path: "" 18 | db_path: "models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors" 19 | lora_alpha: 0.8 20 | -------------------------------------------------------------------------------- /example/config/train.yaml: -------------------------------------------------------------------------------- 1 | image_finetune: false 2 | 3 | output_dir: "outputs" 4 | pretrained_model_path: "./models/StableDiffusion/" 5 | pretrained_motion_module_path: './models/Motion_Module/mm_sd_v15_v2.ckpt' 6 | 7 | 8 | unet_additional_kwargs: 9 | use_motion_module : true 10 | motion_module_resolutions : [ 1,2,4,8 ] 11 | unet_use_cross_frame_attention : false 12 | unet_use_temporal_attention : false 13 | 14 | motion_module_type: Vanilla 15 | motion_module_kwargs: 16 | num_attention_heads : 8 17 | num_transformer_block : 1 18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 19 | temporal_position_encoding : true 20 | temporal_position_encoding_max_len : 32 21 | temporal_attention_dim_div : 1 22 | zero_initialize : true 23 | 24 | mask_sim_range: [0.2, 1.0] 25 | 26 | noise_scheduler_kwargs: 27 | num_train_timesteps: 1000 28 | beta_start: 0.00085 29 | beta_end: 0.012 30 | beta_schedule: "linear" 31 | steps_offset: 1 32 | clip_sample: false 33 | 34 | train_data: 35 | csv_path: "./results_10M_train.csv" 36 | video_folder: "data/WebVid10M/" # local path: replace it with yours 37 | # video_folder: "webvideo:s3://WebVid10M/" # petreloss path: replace it with yours 38 | sample_size: 256 39 | sample_stride: 4 40 | sample_n_frames: 16 41 | use_petreloss: false #set this as true if you want to use petreloss path 42 | conf_path: "~/petreloss.conf" 43 | 44 | validation_data: 45 | prompts: 46 | - "waves, ocean flows, sand, clean sea, breath-taking beautiful beach, tropicaHl beach." 47 | - "1girl walking on the street" 48 | - "Robot dancing in times square." 49 | - "Pacific coast, carmel by the sea ocean and waves." 50 | num_inference_steps: 25 51 | guidance_scale: 8. 52 | mask_sim_range: [0.2, 1.0] 53 | 54 | trainable_modules: 55 | - 'conv_in.' 
56 | - 'motion_modules' 57 | 58 | # set the path to the finetuned unet's image layers 59 | # according to 60 | # https://github.com/guoyww/AnimateDiff/blob/main/__assets__/docs/animatediff.md#training 61 | unet_checkpoint_path: "models/mm_sd_v15_v2_full.ckpt" 62 | 63 | learning_rate: 1.e-4 64 | train_batch_size: 4 65 | gradient_accumulation_steps: 16 66 | 67 | max_train_epoch: -1 68 | max_train_steps: 500000 69 | checkpointing_epochs: -1 70 | checkpointing_steps: 60 71 | 72 | validation_steps: 3000 73 | validation_steps_tuple: [2, 50, 1000] 74 | 75 | global_seed: 42 76 | mixed_precision_training: true 77 | enable_xformers_memory_efficient_attention: True 78 | 79 | is_debug: False 80 | 81 | # precalculated statistics 82 | statistic: [[0., 0.], 83 | [0.3535855, 24.23687346], 84 | [0.91609545, 30.65091947], 85 | [1.41165152, 34.40093286], 86 | [1.56943881, 36.99639585], 87 | [1.73182842, 39.42044163], 88 | [1.82733002, 40.94703526], 89 | [1.88060527, 42.66233244], 90 | [1.96208071, 43.73070788], 91 | [2.02723091, 44.25965378], 92 | [2.10820894, 45.66120213], 93 | [2.21115041, 46.29561324], 94 | [2.23412351, 47.08810863], 95 | [2.29430165, 47.9515062], 96 | [2.32986362, 48.69085638], 97 | [2.37310751, 49.19931439]] 98 | -------------------------------------------------------------------------------- /example/img/anya.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/anya.jpg -------------------------------------------------------------------------------- /example/img/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/bear.jpg -------------------------------------------------------------------------------- /example/img/concert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/concert.png -------------------------------------------------------------------------------- /example/img/genshin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/genshin.jpg -------------------------------------------------------------------------------- /example/img/harry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/harry.png -------------------------------------------------------------------------------- /example/img/labrador.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/labrador.png -------------------------------------------------------------------------------- /example/img/lighthouse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/lighthouse.jpg -------------------------------------------------------------------------------- /example/img/majic_girl.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/example/img/majic_girl.jpg -------------------------------------------------------------------------------- /example/openxlab/1-realistic.yaml: -------------------------------------------------------------------------------- 1 | dreambooth: 'realisticVisionV51_v51VAE.safetensors' 2 | 3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 4 | 5 | guidance_scale: 7 6 | -------------------------------------------------------------------------------- /example/openxlab/3-3d.yaml: -------------------------------------------------------------------------------- 1 | dreambooth: 'rcnzCartoon3d_v10.safetensors' 2 | 3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 4 | prefix: '' 5 | 6 | guidance_scale: 7 7 | 8 | ip_adapter_scale: 0.0 9 | -------------------------------------------------------------------------------- /example/replicate/1-realistic.yaml: -------------------------------------------------------------------------------- 1 | dreambooth: 'realisticVisionV51_v51VAE.safetensors' 2 | 3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 4 | 5 | guidance_scale: 7 6 | -------------------------------------------------------------------------------- /example/replicate/3-3d.yaml: -------------------------------------------------------------------------------- 1 | dreambooth: 'rcnzCartoon3d_v10.safetensors' 2 | 3 | n_prompt: 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg' 4 | prefix: '' 5 | 6 | guidance_scale: 7 7 | 8 | ip_adapter_scale: 0.0 9 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | import argparse 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from omegaconf import OmegaConf 8 | 9 | from animatediff.pipelines import I2VPipeline 10 | from animatediff.utils.util import preprocess_img, save_videos_grid 11 | 12 | 13 | def seed_everything(seed): 14 | import random 15 | 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | np.random.seed(seed % (2**32)) 19 | random.seed(seed) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | functional_group = 
parser.add_mutually_exclusive_group() 25 | parser.add_argument("--config", type=str, default="configs/test.yaml") 26 | parser.add_argument( 27 | "--magnitude", type=int, default=None, choices=[0, 1, 2, -1, -2, -3] 28 | ) # negative is for style transfer 29 | functional_group.add_argument("--loop", action="store_true") 30 | functional_group.add_argument("--style_transfer", action="store_true") 31 | args = parser.parse_args() 32 | 33 | config = OmegaConf.load(args.config) 34 | base_config = OmegaConf.load(config.base) 35 | config = OmegaConf.merge(base_config, config) 36 | 37 | if args.magnitude is not None: 38 | config.validation_data.mask_sim_range = [args.magnitude] 39 | 40 | if args.style_transfer: 41 | config.validation_data.mask_sim_range = [ 42 | -1 * magnitude - 1 if magnitude >= 0 else magnitude for magnitude in config.validation_data.mask_sim_range 43 | ] 44 | elif args.loop: 45 | config.validation_data.mask_sim_range = [ 46 | magnitude + 3 if magnitude >= 0 else magnitude for magnitude in config.validation_data.mask_sim_range 47 | ] 48 | 49 | os.makedirs(config.validation_data.save_path, exist_ok=True) 50 | folder_num = len(os.listdir(config.validation_data.save_path)) 51 | target_dir = f"{config.validation_data.save_path}/{folder_num}/" 52 | 53 | # prepare paths and pipeline 54 | base_model_path = config.pretrained_model_path 55 | unet_path = config.generate.model_path 56 | dreambooth_path = config.generate.db_path 57 | if config.generate.use_lora: 58 | lora_path = config.generate.get("lora_path", None) 59 | lora_alpha = config.generate.get("lora_alpha", 0) 60 | else: 61 | lora_path = None 62 | lora_alpha = 0 63 | validation_pipeline = I2VPipeline.build_pipeline( 64 | config, 65 | base_model_path, 66 | unet_path, 67 | dreambooth_path, 68 | lora_path, 69 | lora_alpha, 70 | ) 71 | generator = torch.Generator(device="cuda") 72 | generator.manual_seed(config.generate.global_seed) 73 | 74 | global_inf_num = 0 75 | 76 | # if not os.path.exists(target_dir): 77 | os.makedirs(target_dir, exist_ok=True) 78 | 79 | # print(" >>> Begin test >>>") 80 | print(f"using unet : {unet_path}") 81 | print(f"using DreamBooth: {dreambooth_path}") 82 | print(f"using Lora : {lora_path}") 83 | 84 | sim_ranges = config.validation_data.mask_sim_range 85 | if isinstance(sim_ranges, int): 86 | sim_ranges = [sim_ranges] 87 | 88 | OmegaConf.save(config, os.path.join(target_dir, "config.yaml")) 89 | generator.manual_seed(config.generate.global_seed) 90 | seed_everything(config.generate.global_seed) 91 | 92 | # load image 93 | img_root = config.validation_data.validation_input_path 94 | input_name = config.validation_data.input_name 95 | if os.path.exists(os.path.join(img_root, f"{input_name}.jpg")): 96 | image_name = os.path.join(img_root, f"{input_name}.jpg") 97 | elif os.path.exists(os.path.join(img_root, f"{input_name}.png")): 98 | image_name = os.path.join(img_root, f"{input_name}.png") 99 | else: 100 | raise ValueError("image_name should be .jpg or .png") 101 | # image = np.array(Image.open(image_name)) 102 | image, gen_height, gen_width = preprocess_img(image_name) 103 | config.generate.sample_height = gen_height 104 | config.generate.sample_width = gen_width 105 | 106 | for sim_range in sim_ranges: 107 | print(f"using sim_range : {sim_range}") 108 | config.validation_data.mask_sim_range = sim_range 109 | prompt_num = 0 110 | for prompt, n_prompt in zip(config.prompts, config.n_prompt): 111 | print(f"using n_prompt : {n_prompt}") 112 | prompt_num += 1 113 | for single_prompt in prompt: 114 | print(f" >>> 
Begin test {global_inf_num} >>>") 115 | global_inf_num += 1 116 | image_path = "" 117 | sample = validation_pipeline( 118 | image=image, 119 | prompt=single_prompt, 120 | generator=generator, 121 | # global_inf_num = global_inf_num, 122 | video_length=config.generate.video_length, 123 | height=config.generate.sample_height, 124 | width=config.generate.sample_width, 125 | negative_prompt=n_prompt, 126 | mask_sim_template_idx=config.validation_data.mask_sim_range, 127 | **config.validation_data, 128 | ).videos 129 | save_videos_grid(sample, target_dir + f"{global_inf_num}_sim_{sim_range}.gif") 130 | print(f" <<< test {global_inf_num} Done <<<") 131 | print(" <<< Test Done <<<") 132 | -------------------------------------------------------------------------------- /models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt -------------------------------------------------------------------------------- /models/IP_Adapter/Put IP-Adapter checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/IP_Adapter/Put IP-Adapter checkpoints here.txt -------------------------------------------------------------------------------- /models/Motion_Module/Put motion module checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/Motion_Module/Put motion module checkpoints here.txt -------------------------------------------------------------------------------- /models/VAE/Put VAE checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/models/VAE/Put VAE checkpoints here.txt -------------------------------------------------------------------------------- /pia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/PIA/73f06d741532295046807ac8e4042aa9ce7d4f27/pia.png -------------------------------------------------------------------------------- /pia.yml: -------------------------------------------------------------------------------- 1 | name: pia 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - bzip2=1.0.8=h5eee18b_6 10 | - ca-certificates=2024.3.11=h06a4308_0 11 | - certifi=2024.6.2=py310h06a4308_0 12 | - cuda-cudart=11.8.89=0 13 | - cuda-cupti=11.8.87=0 14 | - cuda-libraries=11.8.0=0 15 | - cuda-nvrtc=11.8.89=0 16 | - cuda-nvtx=11.8.86=0 17 | - cuda-runtime=11.8.0=0 18 | - cuda-version=12.4=hbda6634_3 19 | - ld_impl_linux-64=2.38=h1181459_1 20 | - libcublas=11.11.3.6=0 21 | - libcufft=10.9.0.58=0 22 | - libcufile=1.9.1.3=h99ab3db_1 23 | - libcurand=10.3.5.147=h99ab3db_1 24 | - libcusolver=11.4.1.48=0 25 | - libcusparse=11.7.5.86=0 26 | - libffi=3.4.4=h6a678d5_1 27 | - libgcc-ng=11.2.0=h1234567_1 28 | - libgomp=11.2.0=h1234567_1 29 | - libnpp=11.8.0.86=0 30 | - libnvjpeg=11.9.0.86=0 31 | - libstdcxx-ng=11.2.0=h1234567_1 32 | - libuuid=1.41.5=h5eee18b_0 33 | - ncurses=6.4=h6a678d5_0 34 | - 
openssl=3.0.13=h7f8727e_2 35 | - pip=24.0=py310h06a4308_0 36 | - python=3.10.14=h955ad1f_1 37 | - pytorch-cuda=11.8=h7e8668a_5 38 | - readline=8.2=h5eee18b_0 39 | - setuptools=69.5.1=py310h06a4308_0 40 | - sqlite=3.45.3=h5eee18b_0 41 | - tk=8.6.14=h39e8969_0 42 | - wheel=0.43.0=py310h06a4308_0 43 | - xz=5.4.6=h5eee18b_1 44 | - zlib=1.2.13=h5eee18b_1 45 | - pip: 46 | - accelerate==0.31.0 47 | - aiofiles==23.2.1 48 | - altair==5.3.0 49 | - annotated-types==0.7.0 50 | - antlr4-python3-runtime==4.9.3 51 | - anyio==4.4.0 52 | - attrs==23.2.0 53 | - beautifulsoup4==4.12.3 54 | - blessed==1.20.0 55 | - boto3==1.34.125 56 | - botocore==1.34.125 57 | - cfgv==3.4.0 58 | - charset-normalizer==3.3.2 59 | - click==8.1.7 60 | - coloredlogs==15.0.1 61 | - contourpy==1.2.1 62 | - cycler==0.12.1 63 | - decorator==4.4.2 64 | - decord==0.6.0 65 | - diffusers==0.24.0 66 | - distlib==0.3.8 67 | - dnspython==2.6.1 68 | - docker-pycreds==0.4.0 69 | - einops==0.8.0 70 | - email-validator==2.1.1 71 | - environs==11.0.0 72 | - exceptiongroup==1.2.1 73 | - fastapi==0.111.0 74 | - fastapi-cli==0.0.4 75 | - ffmpy==0.3.2 76 | - filelock==3.14.0 77 | - fonttools==4.53.0 78 | - fsspec==2024.6.0 79 | - gdown==5.2.0 80 | - git-lfs==1.6 81 | - gitdb==4.0.11 82 | - gitpython==3.1.43 83 | - gpustat==1.1.1 84 | - gradio==4.36.0 85 | - gradio-client==1.0.1 86 | - h11==0.14.0 87 | - httpcore==1.0.5 88 | - httptools==0.6.1 89 | - httpx==0.27.0 90 | - huggingface-hub==0.23.3 91 | - humanfriendly==10.0 92 | - humanize==4.9.0 93 | - identify==2.5.36 94 | - idna==3.7 95 | - imageio==2.33.1 96 | - imageio-ffmpeg==0.5.1 97 | - importlib-metadata==7.1.0 98 | - importlib-resources==6.4.0 99 | - jinja2==3.1.3 100 | - jmespath==1.0.1 101 | - jsonschema==4.22.0 102 | - jsonschema-specifications==2023.12.1 103 | - kiwisolver==1.4.5 104 | - markdown-it-py==3.0.0 105 | - markupsafe==2.1.5 106 | - marshmallow==3.21.3 107 | - matplotlib==3.9.0 108 | - mdurl==0.1.2 109 | - moviepy==1.0.3 110 | - mpmath==1.3.0 111 | - multiprocessing-logging==0.3.4 112 | - networkx==3.2.1 113 | - nodeenv==1.9.1 114 | - numpy==1.26.4 115 | - nvidia-cublas-cu11==11.11.3.6 116 | - nvidia-cublas-cu12==12.1.3.1 117 | - nvidia-cuda-cupti-cu11==11.8.87 118 | - nvidia-cuda-cupti-cu12==12.1.105 119 | - nvidia-cuda-nvrtc-cu11==11.8.89 120 | - nvidia-cuda-nvrtc-cu12==12.1.105 121 | - nvidia-cuda-runtime-cu11==11.8.89 122 | - nvidia-cuda-runtime-cu12==12.1.105 123 | - nvidia-cudnn-cu11==8.7.0.84 124 | - nvidia-cudnn-cu12==8.9.2.26 125 | - nvidia-cufft-cu11==10.9.0.58 126 | - nvidia-cufft-cu12==11.0.2.54 127 | - nvidia-curand-cu11==10.3.0.86 128 | - nvidia-curand-cu12==10.3.2.106 129 | - nvidia-cusolver-cu11==11.4.1.48 130 | - nvidia-cusolver-cu12==11.4.5.107 131 | - nvidia-cusparse-cu11==11.7.5.86 132 | - nvidia-cusparse-cu12==12.1.0.106 133 | - nvidia-ml-py==12.555.43 134 | - nvidia-nccl-cu11==2.20.5 135 | - nvidia-nccl-cu12==2.20.5 136 | - nvidia-nvjitlink-cu12==12.5.40 137 | - nvidia-nvtx-cu11==11.8.86 138 | - nvidia-nvtx-cu12==12.1.105 139 | - omegaconf==2.3.0 140 | - opencv-python==4.10.0.82 141 | - orjson==3.10.3 142 | - packaging==24.0 143 | - pandas==2.2.2 144 | - pillow==10.2.0 145 | - platformdirs==4.2.2 146 | - pre-commit==3.7.1 147 | - proglog==0.1.10 148 | - protobuf==5.27.1 149 | - psutil==5.9.8 150 | - pydantic==2.7.3 151 | - pydantic-core==2.18.4 152 | - pydub==0.25.1 153 | - pygments==2.18.0 154 | - pyparsing==3.1.2 155 | - python-dateutil==2.9.0.post0 156 | - python-dotenv==1.0.1 157 | - python-multipart==0.0.9 158 | - pytz==2024.1 159 | - pyyaml==6.0.1 160 
| - referencing==0.35.1 161 | - regex==2024.5.15 162 | - requests==2.32.3 163 | - rich==13.7.1 164 | - rpds-py==0.18.1 165 | - ruff==0.4.8 166 | - s3transfer==0.10.1 167 | - safetensors==0.4.3 168 | - semantic-version==2.10.0 169 | - sentry-sdk==2.5.0 170 | - setproctitle==1.3.3 171 | - shellingham==1.5.4 172 | - six==1.16.0 173 | - smmap==5.0.1 174 | - sniffio==1.3.1 175 | - soupsieve==2.5 176 | - starlette==0.37.2 177 | - sympy==1.12 178 | - tokenizers==0.19.1 179 | - tomlkit==0.12.0 180 | - toolz==0.12.1 181 | - torch==2.3.1+cu118 182 | - torchaudio==2.3.1+cu118 183 | - torchvision==0.18.1+cu118 184 | - tqdm==4.66.4 185 | - transformers==4.41.2 186 | - triton==2.3.1 187 | - typer==0.12.3 188 | - typing-extensions==4.9.0 189 | - tzdata==2024.1 190 | - ujson==5.10.0 191 | - urllib3==2.2.1 192 | - uvicorn==0.30.1 193 | - uvloop==0.19.0 194 | - virtualenv==20.26.2 195 | - wandb==0.17.1 196 | - watchfiles==0.22.0 197 | - wcwidth==0.2.13 198 | - websockets==11.0.3 199 | - xformers==0.0.26.post1 200 | - zipp==3.19.2 201 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | import os.path as osp 5 | 6 | import numpy as np 7 | import torch 8 | from cog import BasePredictor, Input, Path 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | 12 | from animatediff.pipelines import I2VPipeline 13 | from animatediff.utils.util import save_videos_grid 14 | 15 | 16 | N_PROMPT = ( 17 | "wrong white balance, dark, sketches,worst quality,low quality, " 18 | "deformed, distorted, disfigured, bad eyes, wrong lips, " 19 | "weird mouth, bad teeth, mutated hands and fingers, bad anatomy," 20 | "wrong anatomy, amputation, extra limb, missing limb, " 21 | "floating,limbs, disconnected limbs, mutation, ugly, disgusting, " 22 | "bad_pictures, negative_hand-neg" 23 | ) 24 | 25 | 26 | BASE_CONFIG = "example/config/base.yaml" 27 | STYLE_CONFIG_LIST = { 28 | "realistic": "example/replicate/1-realistic.yaml", 29 | "3d_cartoon": "example/replicate/3-3d.yaml", 30 | } 31 | 32 | 33 | PIA_PATH = "models/PIA" 34 | VAE_PATH = "models/VAE" 35 | DreamBooth_LoRA_PATH = "models/DreamBooth_LoRA" 36 | STABLE_DIFFUSION_PATH = "models/StableDiffusion" 37 | 38 | 39 | class Predictor(BasePredictor): 40 | def setup(self) -> None: 41 | """Load the model into memory to make running multiple predictions efficient""" 42 | 43 | self.ip_adapter_dir = "models/IP_Adapter/h94/IP-Adapter/models" # cached h94/IP-Adapter 44 | 45 | self.inference_config = OmegaConf.load("example/config/base.yaml") 46 | self.stable_diffusion_dir = self.inference_config.pretrained_model_path 47 | self.pia_path = self.inference_config.generate.model_path 48 | self.style_configs = {k: OmegaConf.load(v) for k, v in STYLE_CONFIG_LIST.items()} 49 | self.pipeline_dict = self.load_model_list() 50 | 51 | def load_model_list(self): 52 | pipeline_dict = {} 53 | for style, cfg in self.style_configs.items(): 54 | print(f"Loading {style}") 55 | dreambooth_path = cfg.get("dreambooth", "none") 56 | if dreambooth_path and dreambooth_path.upper() != "NONE": 57 | dreambooth_path = osp.join(DreamBooth_LoRA_PATH, dreambooth_path) 58 | lora_path = cfg.get("lora", None) 59 | if lora_path is not None: 60 | lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path) 61 | lora_alpha = cfg.get("lora_alpha", 0.0) 62 | vae_path = cfg.get("vae", None) 63 | if 
vae_path is not None: 64 | vae_path = osp.join(VAE_PATH, vae_path) 65 | 66 | pipeline_dict[style] = I2VPipeline.build_pipeline( 67 | self.inference_config, 68 | STABLE_DIFFUSION_PATH, 69 | unet_path=osp.join(PIA_PATH, "pia.ckpt"), 70 | dreambooth_path=dreambooth_path, 71 | lora_path=lora_path, 72 | lora_alpha=lora_alpha, 73 | vae_path=vae_path, 74 | ip_adapter_path=self.ip_adapter_dir, 75 | ip_adapter_scale=0.1, 76 | ) 77 | return pipeline_dict 78 | 79 | def predict( 80 | self, 81 | prompt: str = Input(description="Input prompt."), 82 | image: Path = Input(description="Input image"), 83 | negative_prompt: str = Input(description="Things that should not show up in the output.", default=N_PROMPT), 84 | style: str = Input( 85 | description="Choose a style", 86 | choices=["3d_cartoon", "realistic"], 87 | default="3d_cartoon", 88 | ), 89 | max_size: int = Input( 90 | description="Max size (The long edge of the input image will be resized to this value, " 91 | "larger value means slower inference speed)", 92 | default=512, 93 | choices=[512, 576, 640, 704, 768, 832, 896, 960, 1024], 94 | ), 95 | motion_scale: int = Input( 96 | description="Larger value means larger motion but less identity consistency.", 97 | ge=1, 98 | le=3, 99 | default=1, 100 | ), 101 | sampling_steps: int = Input(description="Number of denoising steps", ge=10, le=100, default=25), 102 | animation_length: int = Input(description="Length of the output (number of frames)", ge=8, le=24, default=16), 103 | guidance_scale: float = Input( 104 | description="Scale for classifier-free guidance", 105 | ge=1.0, 106 | le=20.0, 107 | default=7.5, 108 | ), 109 | ip_adapter_scale: float = Input( 110 | description="Scale for IP-Adapter image conditioning", 111 | ge=0.0, 112 | le=1.0, 113 | default=0.0, 114 | ), 115 | seed: int = Input(description="Random seed.
Leave blank to randomize the seed", default=None), 116 | ) -> Path: 117 | """Run a single prediction on the model""" 118 | if seed is None: 119 | torch.seed() 120 | seed = torch.initial_seed() 121 | else: 122 | torch.manual_seed(seed) 123 | print(f"Using seed: {seed}") 124 | 125 | pipeline = self.pipeline_dict[style] 126 | 127 | init_img, h, w = preprocess_img(str(image), max_size) 128 | 129 | sample = pipeline( 130 | image=init_img, 131 | prompt=prompt, 132 | negative_prompt=negative_prompt, 133 | num_inference_steps=sampling_steps, 134 | guidance_scale=guidance_scale, 135 | width=w, 136 | height=h, 137 | video_length=animation_length, 138 | mask_sim_template_idx=motion_scale, 139 | ip_adapter_scale=ip_adapter_scale, 140 | ).videos 141 | 142 | out_path = "/tmp/out.mp4" 143 | save_videos_grid(sample, out_path) 144 | return Path(out_path) 145 | 146 | 147 | def preprocess_img(img_np, max_size: int = 512): 148 | ori_image = Image.open(img_np).convert("RGB") 149 | 150 | width, height = ori_image.size 151 | 152 | long_edge = max(width, height) 153 | if long_edge > max_size: 154 | scale_factor = max_size / long_edge 155 | else: 156 | scale_factor = 1 157 | width = int(width * scale_factor) 158 | height = int(height * scale_factor) 159 | ori_image = ori_image.resize((width, height)) 160 | 161 | if (width % 8 != 0) or (height % 8 != 0): 162 | in_width = (width // 8) * 8 163 | in_height = (height // 8) * 8 164 | else: 165 | in_width = width 166 | in_height = height 167 | 168 | in_image = ori_image.resize((in_width, in_height)) 169 | in_image_np = np.array(in_image) 170 | return in_image_np, in_height, in_width 171 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Never enforce `E501` (line length violations). 3 | ignore = ["C901", "E501", "E741", "F402", "F823"] 4 | select = ["C", "E", "F", "I", "W"] 5 | line-length = 119 6 | 7 | # Ignore import violations in all `__init__.py` files. 8 | [tool.ruff.per-file-ignores] 9 | "__init__.py" = ["E402", "F401", "F403", "F811"] 10 | "src/diffusers/utils/dummy_*.py" = ["F401"] 11 | 12 | [tool.ruff.isort] 13 | lines-after-imports = 2 14 | known-first-party = ["diffusers"] 15 | 16 | [tool.ruff.format] 17 | # Like Black, use double quotes for strings. 18 | quote-style = "double" 19 | 20 | # Like Black, indent with spaces, rather than tabs. 21 | indent-style = "space" 22 | 23 | # Like Black, respect magic trailing commas. 24 | skip-magic-trailing-comma = false 25 | 26 | # Like Black, automatically detect the appropriate line ending. 27 | line-ending = "auto" 28 | --------------------------------------------------------------------------------
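
Note on the motion-magnitude convention used by inference.py above: --style_transfer remaps each non-negative mask_sim_range value m to -(m + 1) (negative indices select the style-transfer similarity templates), while --loop shifts each non-negative value by +3 (the loop templates); values that are already negative are left unchanged. The minimal standalone sketch below (not part of the repository; the helper name remap_mask_sim_range is hypothetical) restates that remapping so the template index selected by a given flag combination can be checked in isolation:

def remap_mask_sim_range(mask_sim_range, style_transfer=False, loop=False):
    # Mirrors the remapping in inference.py: style transfer uses negative
    # template indices, looping shifts non-negative indices by +3.
    if style_transfer:
        return [-m - 1 if m >= 0 else m for m in mask_sim_range]
    if loop:
        return [m + 3 if m >= 0 else m for m in mask_sim_range]
    return mask_sim_range


# e.g. --magnitude 1 --style_transfer selects template -2,
# and --magnitude 1 --loop selects template 4.
assert remap_mask_sim_range([1], style_transfer=True) == [-2]
assert remap_mask_sim_range([1], loop=True) == [4]
assert remap_mask_sim_range([-1], style_transfer=True) == [-1]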