├── packages.txt
├── inputs
│   └── applications
│       ├── source_image
│       │   ├── 0002.png
│       │   ├── demo4.png
│       │   ├── dalle2.jpeg
│       │   ├── dalle8.jpeg
│       │   ├── monalisa.png
│       │   └── multi1_source.png
│       └── driving
│           └── densepose
│               ├── demo4.mp4
│               ├── dancing2.mp4
│               ├── running.mp4
│               ├── running2.mp4
│               ├── multi_dancing.mp4
│               ├── .nfs006c000000039d6800000023
│               └── .nfs006c00000003a32d00000024
├── magicanimate
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── util.cpython-38.pyc
│   │   │   ├── dist_tools.cpython-38.pyc
│   │   │   └── videoreader.cpython-38.pyc
│   │   ├── dist_tools.py
│   │   ├── util.py
│   │   └── videoreader.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   ├── controlnet.cpython-38.pyc
│   │   │   ├── embeddings.cpython-38.pyc
│   │   │   ├── motion_module.cpython-38.pyc
│   │   │   ├── orig_attention.cpython-38.pyc
│   │   │   ├── unet_3d_blocks.cpython-38.pyc
│   │   │   ├── unet_controlnet.cpython-38.pyc
│   │   │   ├── appearance_encoder.cpython-38.pyc
│   │   │   ├── mutual_self_attention.cpython-38.pyc
│   │   │   └── stable_diffusion_controlnet_reference.cpython-38.pyc
│   │   ├── resnet.py
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── embeddings.py
│   │   ├── unet.py
│   │   ├── unet_controlnet.py
│   │   ├── controlnet.py
│   │   └── unet_3d_blocks.py
│   └── pipelines
│       ├── __pycache__
│       │   ├── context.cpython-38.pyc
│       │   ├── animation.cpython-37.pyc
│       │   ├── animation.cpython-38.pyc
│       │   ├── dist_animation.cpython-37.pyc
│       │   ├── dist_animation.cpython-38.pyc
│       │   └── pipeline_animation.cpython-38.pyc
│       ├── context.py
│       └── animation.py
├── samples
│   ├── animation-2023-12-05T00-24-12
│   │   └── videos
│   │       ├── 0002_demo4.mp4
│   │       ├── demo4_demo4.mp4
│   │       ├── 0002_demo4
│   │       │   └── grid.mp4
│   │       ├── demo4_demo4
│   │       │   └── grid.mp4
│   │       ├── monalisa_running.mp4
│   │       └── monalisa_running
│   │           └── grid.mp4
│   └── animation-2023-12-05T00-37-05
│       └── videos
│           ├── monalisa_running.mp4
│           └── monalisa_running
│               └── grid.mp4
├── README.md
├── configs
│   ├── inference
│   │   └── inference.yaml
│   └── prompts
│       └── animation.yaml
├── LICENSE
├── .gitattributes
├── requirements.txt
├── app.py
└── demo
    └── animate.py
/packages.txt: -------------------------------------------------------------------------------- 1 | ffmpeg 2 | -------------------------------------------------------------------------------- /inputs/applications/source_image/0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/0002.png -------------------------------------------------------------------------------- /inputs/applications/source_image/demo4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/demo4.png -------------------------------------------------------------------------------- /inputs/applications/source_image/dalle2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/dalle2.jpeg -------------------------------------------------------------------------------- /inputs/applications/source_image/dalle8.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/dalle8.jpeg -------------------------------------------------------------------------------- /inputs/applications/source_image/monalisa.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/monalisa.png -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/demo4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/demo4.mp4 -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/dancing2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/dancing2.mp4 -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/running.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/running.mp4 -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/running2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/running2.mp4 -------------------------------------------------------------------------------- /inputs/applications/source_image/multi1_source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/source_image/multi1_source.png -------------------------------------------------------------------------------- /magicanimate/utils/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/utils/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/multi_dancing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/multi_dancing.mp4 -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc 
-------------------------------------------------------------------------------- /magicanimate/models/__pycache__/controlnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/controlnet.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/embeddings.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/embeddings.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/context.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/context.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/utils/__pycache__/videoreader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/motion_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/motion_module.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/orig_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/animation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/animation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 
-------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/.nfs006c000000039d6800000023: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 -------------------------------------------------------------------------------- /inputs/applications/driving/densepose/.nfs006c00000003a32d00000024: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc -------------------------------------------------------------------------------- /magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 -------------------------------------------------------------------------------- /samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 -------------------------------------------------------------------------------- /magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camenduru/magicanimate-hf/HEAD/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: MagicAnimate 3 | emoji: 💃 4 | colorFrom: purple 5 | colorTo: purple 6 | sdk: gradio 7 | sdk_version: 4.4.0 8 | python_version: 3.8 9 | app_file: app.py 10 | models: 11 | - zcxu-eric/MagicAnimate 12 | - runwayml/stable-diffusion-v1-5 13 | - stabilityai/sd-vae-ft-mse 14 | pinned: false 15 | --- -------------------------------------------------------------------------------- /configs/inference/inference.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | unet_use_cross_frame_attention: false 3 | unet_use_temporal_attention: false 4 | use_motion_module: true 5 | motion_module_resolutions: 6 | - 1 7 | - 2 8 | - 4 9 | - 8 10 | motion_module_mid_block: false 11 | motion_module_decoder_only: false 12 | motion_module_type: Vanilla 13 | motion_module_kwargs: 14 | num_attention_heads: 8 15 | num_transformer_block: 1 16 | attention_block_types: 17 | - Temporal_Self 18 | - Temporal_Self 19 | temporal_position_encoding: true 20 
| temporal_position_encoding_max_len: 24 21 | temporal_attention_dim_div: 1 22 | 23 | noise_scheduler_kwargs: 24 | beta_start: 0.00085 25 | beta_end: 0.012 26 | beta_schedule: "linear" 27 | -------------------------------------------------------------------------------- /configs/prompts/animation.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "stable-diffusion-v1-5" 2 | pretrained_vae_path: "sd-vae-ft-mse" 3 | pretrained_controlnet_path: "MagicAnimate/densepose_controlnet" 4 | pretrained_appearance_encoder_path: "MagicAnimate/appearance_encoder" 5 | pretrained_unet_path: "" 6 | 7 | motion_module: "MagicAnimate/temporal_attention/temporal_attention.ckpt" 8 | 9 | savename: null 10 | 11 | fusion_blocks: "midup" 12 | 13 | seed: [1] 14 | steps: 25 15 | guidance_scale: 7.5 16 | 17 | source_image: 18 | - "inputs/applications/source_image/monalisa.png" 19 | - "inputs/applications/source_image/0002.png" 20 | - "inputs/applications/source_image/demo4.png" 21 | - "inputs/applications/source_image/dalle2.jpeg" 22 | - "inputs/applications/source_image/dalle8.jpeg" 23 | - "inputs/applications/source_image/multi1_source.png" 24 | video_path: 25 | - "inputs/applications/driving/densepose/running.mp4" 26 | - "inputs/applications/driving/densepose/demo4.mp4" 27 | - "inputs/applications/driving/densepose/demo4.mp4" 28 | - "inputs/applications/driving/densepose/running2.mp4" 29 | - "inputs/applications/driving/densepose/dancing2.mp4" 30 | - "inputs/applications/driving/densepose/multi_dancing.mp4" 31 | 32 | inference_config: "configs/inference/inference.yaml" 33 | size: 512 34 | L: 16 35 | S: 1 36 | I: 0 37 | clip: 0 38 | offset: 0 39 | max_length: null 40 | video_type: "condition" 41 | invert_video: false 42 | save_individual_videos: false 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright 2023 MagicAnimate Team All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /magicanimate/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 
5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main 8 | import numpy as np 9 | from typing import Callable, Optional, List 10 | 11 | 12 | def ordered_halving(val): 13 | bin_str = f"{val:064b}" 14 | bin_flip = bin_str[::-1] 15 | as_int = int(bin_flip, 2) 16 | 17 | return as_int / (1 << 64) 18 | 19 | 20 | def uniform( 21 | step: int = ..., 22 | num_steps: Optional[int] = None, 23 | num_frames: int = ..., 24 | context_size: Optional[int] = None, 25 | context_stride: int = 3, 26 | context_overlap: int = 4, 27 | closed_loop: bool = True, 28 | ): 29 | if num_frames <= context_size: 30 | yield list(range(num_frames)) 31 | return 32 | 33 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 34 | 35 | for context_step in 1 << np.arange(context_stride): 36 | pad = int(round(num_frames * ordered_halving(step))) 37 | for j in range( 38 | int(ordered_halving(step) * context_step) + pad, 39 | num_frames + pad + (0 if closed_loop else -context_overlap), 40 | (context_size * context_step - context_overlap), 41 | ): 42 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.22.0 3 | aiofiles==23.2.1 4 | aiohttp==3.8.5 5 | aiosignal==1.3.1 6 | altair==5.0.1 7 | annotated-types==0.5.0 8 | antlr4-python3-runtime==4.9.3 9 | anyio==3.7.1 10 | async-timeout==4.0.3 11 | attrs==23.1.0 12 | cachetools==5.3.1 13 | certifi==2023.7.22 14 | charset-normalizer==3.2.0 15 | click==8.1.7 16 | cmake==3.27.2 17 | contourpy==1.1.0 18 | cycler==0.11.0 19 | datasets==2.14.4 20 | dill==0.3.7 21 | einops==0.6.1 22 | exceptiongroup==1.1.3 23 | fastapi==0.103.0 24 | ffmpy==0.3.1 25 | filelock==3.12.2 26 | fonttools==4.42.1 27 | frozenlist==1.4.0 28 | fsspec==2023.6.0 29 | google-auth==2.22.0 30 | google-auth-oauthlib==1.0.0 31 | gradio==3.41.2 32 | gradio-client==0.5.0 33 | grpcio==1.57.0 34 | h11==0.14.0 35 | httpcore==0.17.3 36 | httpx==0.24.1 37 | huggingface-hub==0.16.4 38 | idna==3.4 39 | importlib-metadata==6.8.0 40 | importlib-resources==6.0.1 41 | jinja2==3.1.2 42 | joblib==1.3.2 43 | jsonschema==4.19.0 44 | jsonschema-specifications==2023.7.1 45 | kiwisolver==1.4.5 46 | lightning-utilities==0.9.0 47 | lit==16.0.6 48 | markdown==3.4.4 49 | markupsafe==2.1.3 50 | matplotlib==3.7.2 51 | mpmath==1.3.0 52 | multidict==6.0.4 53 | multiprocess==0.70.15 54 | networkx==3.1 55 | numpy==1.24.4 56 | nvidia-cublas-cu11==11.10.3.66 57 | nvidia-cuda-cupti-cu11==11.7.101 58 | nvidia-cuda-nvrtc-cu11==11.7.99 59 | nvidia-cuda-runtime-cu11==11.7.99 60 | 
nvidia-cudnn-cu11==8.5.0.96 61 | nvidia-cufft-cu11==10.9.0.58 62 | nvidia-curand-cu11==10.2.10.91 63 | nvidia-cusolver-cu11==11.4.0.1 64 | nvidia-cusparse-cu11==11.7.4.91 65 | nvidia-nccl-cu11==2.14.3 66 | nvidia-nvtx-cu11==11.7.91 67 | oauthlib==3.2.2 68 | omegaconf==2.3.0 69 | opencv-python==4.8.0.76 70 | orjson==3.9.5 71 | pandas==2.0.3 72 | pillow==9.5.0 73 | pkgutil-resolve-name==1.3.10 74 | protobuf==4.24.2 75 | psutil==5.9.5 76 | pyarrow==13.0.0 77 | pyasn1==0.5.0 78 | pyasn1-modules==0.3.0 79 | pydantic==2.3.0 80 | pydantic-core==2.6.3 81 | pydub==0.25.1 82 | pyparsing==3.0.9 83 | python-multipart==0.0.6 84 | pytorch-lightning==2.0.7 85 | pytz==2023.3 86 | pyyaml==6.0.1 87 | referencing==0.30.2 88 | regex==2023.8.8 89 | requests==2.31.0 90 | requests-oauthlib==1.3.1 91 | rpds-py==0.9.2 92 | rsa==4.9 93 | safetensors==0.3.3 94 | semantic-version==2.10.0 95 | sniffio==1.3.0 96 | starlette==0.27.0 97 | sympy==1.12 98 | tensorboard==2.14.0 99 | tensorboard-data-server==0.7.1 100 | tokenizers==0.13.3 101 | toolz==0.12.0 102 | torchmetrics==1.1.0 103 | tqdm==4.66.1 104 | transformers==4.32.0 105 | triton==2.0.0 106 | tzdata==2023.3 107 | urllib3==1.26.16 108 | uvicorn==0.23.2 109 | websockets==11.0.3 110 | werkzeug==2.3.7 111 | xxhash==3.3.0 112 | yarl==1.9.2 113 | zipp==3.16.2 114 | decord 115 | imageio==2.9.0 116 | imageio-ffmpeg==0.4.3 117 | timm 118 | scipy 119 | scikit-image 120 | av 121 | imgaug 122 | lpips 123 | ffmpeg-python 124 | torch==2.0.1 125 | torchvision==0.15.2 126 | xformers==0.0.22 127 | diffusers==0.21.4 128 | -------------------------------------------------------------------------------- /magicanimate/utils/dist_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 
11 | import os 12 | import socket 13 | import warnings 14 | import torch 15 | from torch import distributed as dist 16 | 17 | 18 | def distributed_init(args): 19 | 20 | if dist.is_initialized(): 21 | warnings.warn("Distributed is already initialized, cannot initialize twice!") 22 | args.rank = dist.get_rank() 23 | else: 24 | print( 25 | f"Distributed Init (Rank {args.rank}): " 26 | f"{args.init_method}" 27 | ) 28 | dist.init_process_group( 29 | backend='nccl', 30 | init_method=args.init_method, 31 | world_size=args.world_size, 32 | rank=args.rank, 33 | ) 34 | print( 35 | f"Initialized Host {socket.gethostname()} as Rank " 36 | f"{args.rank}" 37 | ) 38 | 39 | if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ: 40 | # Set for onboxdataloader support 41 | split = args.init_method.split("//") 42 | assert len(split) == 2, ( 43 | "host url for distributed should be split by '//' " 44 | + "into exactly two elements" 45 | ) 46 | 47 | split = split[1].split(":") 48 | assert ( 49 | len(split) == 2 50 | ), "host url should be of the form :" 51 | os.environ["MASTER_ADDR"] = split[0] 52 | os.environ["MASTER_PORT"] = split[1] 53 | 54 | # perform a dummy all-reduce to initialize the NCCL communicator 55 | dist.all_reduce(torch.zeros(1).cuda()) 56 | 57 | suppress_output(is_master()) 58 | args.rank = dist.get_rank() 59 | return args.rank 60 | 61 | 62 | def get_rank(): 63 | if not dist.is_available(): 64 | return 0 65 | if not dist.is_nccl_available(): 66 | return 0 67 | if not dist.is_initialized(): 68 | return 0 69 | return dist.get_rank() 70 | 71 | 72 | def is_master(): 73 | return get_rank() == 0 74 | 75 | 76 | def synchronize(): 77 | if dist.is_initialized(): 78 | dist.barrier() 79 | 80 | 81 | def suppress_output(is_master): 82 | """Suppress printing on the current device. Force printing with `force=True`.""" 83 | import builtins as __builtin__ 84 | 85 | builtin_print = __builtin__.print 86 | 87 | def print(*args, **kwargs): 88 | force = kwargs.pop("force", False) 89 | if is_master or force: 90 | builtin_print(*args, **kwargs) 91 | 92 | __builtin__.print = print 93 | 94 | import warnings 95 | 96 | builtin_warn = warnings.warn 97 | 98 | def warn(*args, **kwargs): 99 | force = kwargs.pop("force", False) 100 | if is_master or force: 101 | builtin_warn(*args, **kwargs) 102 | 103 | # Log warnings only once 104 | warnings.warn = warn 105 | warnings.simplefilter("once", UserWarning) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 
11 | import argparse 12 | import imageio 13 | import numpy as np 14 | import gradio as gr 15 | from PIL import Image 16 | from subprocess import PIPE, run 17 | 18 | from demo.animate import MagicAnimate 19 | 20 | from huggingface_hub import snapshot_download 21 | 22 | snapshot_download(repo_id="runwayml/stable-diffusion-v1-5", local_dir="./stable-diffusion-v1-5") 23 | snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./sd-vae-ft-mse") 24 | snapshot_download(repo_id="zcxu-eric/MagicAnimate", local_dir="./MagicAnimate") 25 | 26 | animator = MagicAnimate() 27 | 28 | def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale): 29 | return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale) 30 | 31 | with gr.Blocks() as demo: 32 | 33 | gr.HTML( 34 | """ 35 |
36 | 37 | MagicAnimate: Temporally Consistent Human Image Animation 38 | 39 | 40 | 41 | Project page | 42 | GitHub | 43 | arXiv 44 | 45 | 46 | """) 47 | animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) 48 | 49 | with gr.Row(): 50 | reference_image = gr.Image(label="Reference Image") 51 | motion_sequence = gr.Video(format="mp4", label="Motion Sequence") 52 | 53 | with gr.Column(): 54 | random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") 55 | sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") 56 | guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") 57 | submit = gr.Button("Animate") 58 | 59 | def read_video(video, size=512): 60 | size = int(size) 61 | reader = imageio.get_reader(video) 62 | fps = reader.get_meta_data()['fps'] 63 | assert fps == 25.0, f'Expected video fps: 25, but {fps} fps found' 64 | return video 65 | 66 | def read_image(image, size=512): 67 | return np.array(Image.fromarray(image).resize((size, size))) 68 | 69 | # when user uploads a new video 70 | motion_sequence.upload( 71 | read_video, 72 | motion_sequence, 73 | motion_sequence 74 | ) 75 | # when `first_frame` is updated 76 | reference_image.upload( 77 | read_image, 78 | reference_image, 79 | reference_image 80 | ) 81 | # when the `submit` button is clicked 82 | submit.click( 83 | animate, 84 | [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], 85 | animation 86 | ) 87 | 88 | # Examples 89 | gr.Markdown("## Examples") 90 | gr.Examples( 91 | examples=[ 92 | ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"], 93 | ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"], 94 | ["inputs/applications/source_image/0002.png", "inputs/applications/driving/densepose/demo4.mp4"], 95 | ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"], 96 | ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"], 97 | ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"], 98 | ], 99 | inputs=[reference_image, motion_sequence], 100 | outputs=animation 101 | ) 102 | 103 | 104 | demo.launch(share=True) -------------------------------------------------------------------------------- /magicanimate/utils/util.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc..
5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | import os 9 | import imageio 10 | import numpy as np 11 | 12 | import torch 13 | import torchvision 14 | 15 | from PIL import Image 16 | from typing import Union 17 | from tqdm import tqdm 18 | from einops import rearrange 19 | 20 | 21 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25): 22 | videos = rearrange(videos, "b c t h w -> t b c h w") 23 | outputs = [] 24 | for x in videos: 25 | x = torchvision.utils.make_grid(x, nrow=n_rows) 26 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 27 | if rescale: 28 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 29 | x = (x * 255).numpy().astype(np.uint8) 30 | outputs.append(x) 31 | 32 | os.makedirs(os.path.dirname(path), exist_ok=True) 33 | imageio.mimsave(path, outputs, fps=fps) 34 | 35 | def save_images_grid(images: torch.Tensor, path: str): 36 | assert images.shape[2] == 1 # no time dimension 37 | images = images.squeeze(2) 38 | grid = torchvision.utils.make_grid(images) 39 | grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8) 40 | os.makedirs(os.path.dirname(path), exist_ok=True) 41 | Image.fromarray(grid).save(path) 42 | 43 | # DDIM Inversion 44 | @torch.no_grad() 45 | def init_prompt(prompt, pipeline): 46 | uncond_input = pipeline.tokenizer( 47 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 48 | return_tensors="pt" 49 | ) 50 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 51 | text_input = pipeline.tokenizer( 52 | [prompt], 53 | padding="max_length", 54 | max_length=pipeline.tokenizer.model_max_length, 55 | truncation=True, 56 | return_tensors="pt", 57 | ) 58 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 59 | context = torch.cat([uncond_embeddings, text_embeddings]) 60 | 61 | return context 62 | 63 | 64 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 65 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 66 | timestep, next_timestep = min( 67 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 68 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 69 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 70 | beta_prod_t = 1 - alpha_prod_t 71 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 72 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 73 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 74 | return next_sample 75 | 76 | 77 | def get_noise_pred_single(latents, t, context, unet): 78 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 79 | return noise_pred 80 | 81 | 82 | @torch.no_grad() 83 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 84 | context = init_prompt(prompt, pipeline) 85 | uncond_embeddings, cond_embeddings = context.chunk(2) 86 | all_latent = [latent] 87 | latent = latent.clone().detach() 88 | for i in tqdm(range(num_inv_steps)): 89 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 90 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 91 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 92 | all_latent.append(latent) 93 | return all_latent 94 | 
95 | 96 | @torch.no_grad() 97 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 98 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 99 | return ddim_latents 100 | 101 | 102 | def video2images(path, step=4, length=16, start=0): 103 | reader = imageio.get_reader(path) 104 | frames = [] 105 | for frame in reader: 106 | frames.append(np.array(frame)) 107 | frames = frames[start::step][:length] 108 | return frames 109 | 110 | 111 | def images2video(video, path, fps=8): 112 | imageio.mimsave(path, video, fps=fps) 113 | return 114 | 115 | 116 | tensor_interpolation = None 117 | 118 | def get_tensor_interpolation_method(): 119 | return tensor_interpolation 120 | 121 | def set_tensor_interpolation_method(is_slerp): 122 | global tensor_interpolation 123 | tensor_interpolation = slerp if is_slerp else linear 124 | 125 | def linear(v1, v2, t): 126 | return (1.0 - t) * v1 + t * v2 127 | 128 | def slerp( 129 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 130 | ) -> torch.Tensor: 131 | u0 = v0 / v0.norm() 132 | u1 = v1 / v1.norm() 133 | dot = (u0 * u1).sum() 134 | if dot.abs() > DOT_THRESHOLD: 135 | #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 136 | return (1.0 - t) * v0 + t * v1 137 | omega = dot.acos() 138 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() -------------------------------------------------------------------------------- /magicanimate/utils/videoreader.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2022 ByteDance and/or its affiliates. 8 | # 9 | # Copyright (2022) PV3D Authors 10 | # 11 | # ByteDance, its affiliates and licensors retain all intellectual 12 | # property and proprietary rights in and to this material, related 13 | # documentation and any modifications thereto. Any use, reproduction, 14 | # disclosure or distribution of this material and related documentation 15 | # without an express license agreement from ByteDance or 16 | # its affiliates is strictly prohibited. 17 | import av, gc 18 | import torch 19 | import warnings 20 | import numpy as np 21 | 22 | 23 | _CALLED_TIMES = 0 24 | _GC_COLLECTION_INTERVAL = 20 25 | 26 | 27 | # remove warnings 28 | av.logging.set_level(av.logging.ERROR) 29 | 30 | 31 | class VideoReader(): 32 | """ 33 | Simple wrapper around PyAV that exposes a few useful functions for 34 | dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries. 
35 | Acknowledgement: Codes are borrowed from Bruno Korbar 36 | """ 37 | def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False): 38 | """ 39 | Arguments: 40 | video_path (str): path or byte of the video to be loaded 41 | """ 42 | self.container = av.open(video) 43 | self.num_frames = num_frames 44 | self.bi_frame = bi_frame 45 | 46 | self.resampler = None 47 | if audio_resample_rate is not None: 48 | self.resampler = av.AudioResampler(rate=audio_resample_rate) 49 | 50 | if self.container.streams.video: 51 | # enable multi-threaded video decoding 52 | if decode_lossy: 53 | warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) 54 | self.container.streams.video[0].thread_type = 'AUTO' 55 | self.video_stream = self.container.streams.video[0] 56 | else: 57 | self.video_stream = None 58 | 59 | self.fps = self._get_video_frame_rate() 60 | 61 | def seek(self, pts, backward=True, any_frame=False): 62 | stream = self.video_stream 63 | self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream) 64 | 65 | def _occasional_gc(self): 66 | # there are a lot of reference cycles in PyAV, so need to manually call 67 | # the garbage collector from time to time 68 | global _CALLED_TIMES, _GC_COLLECTION_INTERVAL 69 | _CALLED_TIMES += 1 70 | if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: 71 | gc.collect() 72 | 73 | def _read_video(self, offset): 74 | self._occasional_gc() 75 | 76 | pts = self.container.duration * offset 77 | time_ = pts / float(av.time_base) 78 | self.container.seek(int(pts)) 79 | 80 | video_frames = [] 81 | count = 0 82 | for _, frame in enumerate(self._iter_frames()): 83 | if frame.pts * frame.time_base >= time_: 84 | video_frames.append(frame) 85 | if count >= self.num_frames - 1: 86 | break 87 | count += 1 88 | return video_frames 89 | 90 | def _iter_frames(self): 91 | for packet in self.container.demux(self.video_stream): 92 | for frame in packet.decode(): 93 | yield frame 94 | 95 | def _compute_video_stats(self): 96 | if self.video_stream is None or self.container is None: 97 | return 0 98 | num_of_frames = self.container.streams.video[0].frames 99 | if num_of_frames == 0: 100 | num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base) 101 | self.seek(0, backward=False) 102 | count = 0 103 | time_base = 512 104 | for p in self.container.decode(video=0): 105 | count = count + 1 106 | if count == 1: 107 | start_pts = p.pts 108 | elif count == 2: 109 | time_base = p.pts - start_pts 110 | break 111 | return start_pts, time_base, num_of_frames 112 | 113 | def _get_video_frame_rate(self): 114 | return float(self.container.streams.video[0].guessed_rate) 115 | 116 | def sample(self, debug=False): 117 | 118 | if self.container is None: 119 | raise RuntimeError('video stream not found') 120 | sample = dict() 121 | _, _, total_num_frames = self._compute_video_stats() 122 | offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item() 123 | video_frames = self._read_video(offset/total_num_frames) 124 | video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) 125 | sample["frames"] = video_frames 126 | sample["frame_idx"] = [offset] 127 | 128 | if self.bi_frame: 129 | frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)] 130 | frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)] 131 | frames.sort() 132 | video_frames = 
np.array([video_frames[min(frames)], video_frames[max(frames)]]) 133 | Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)] 134 | sample["frames"] = video_frames 135 | sample["real_t"] = torch.tensor(Ts, dtype=torch.float32) 136 | sample["frame_idx"] = [offset+min(frames), offset+max(frames)] 137 | return sample 138 | 139 | return sample 140 | 141 | def read_frames(self, frame_indices): 142 | self.num_frames = frame_indices[1] - frame_indices[0] 143 | video_frames = self._read_video(frame_indices[0]/self.get_num_frames()) 144 | video_frames = np.array([ 145 | np.uint8(video_frames[0].to_rgb().to_ndarray()), 146 | np.uint8(video_frames[-1].to_rgb().to_ndarray()) 147 | ]) 148 | return video_frames 149 | 150 | def read(self): 151 | video_frames = self._read_video(0) 152 | video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) 153 | return video_frames 154 | 155 | def get_num_frames(self): 156 | _, _, total_num_frames = self._compute_video_stats() 157 | return total_num_frames -------------------------------------------------------------------------------- /magicanimate/models/resnet.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | 9 | # Copyright 2023 The HuggingFace Team. All rights reserved. 10 | # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 
23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | from einops import rearrange 28 | 29 | 30 | class InflatedConv3d(nn.Conv2d): 31 | def forward(self, x): 32 | video_length = x.shape[2] 33 | 34 | x = rearrange(x, "b c f h w -> (b f) c h w") 35 | x = super().forward(x) 36 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 37 | 38 | return x 39 | 40 | 41 | class Upsample3D(nn.Module): 42 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 43 | super().__init__() 44 | self.channels = channels 45 | self.out_channels = out_channels or channels 46 | self.use_conv = use_conv 47 | self.use_conv_transpose = use_conv_transpose 48 | self.name = name 49 | 50 | conv = None 51 | if use_conv_transpose: 52 | raise NotImplementedError 53 | elif use_conv: 54 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 55 | 56 | def forward(self, hidden_states, output_size=None): 57 | assert hidden_states.shape[1] == self.channels 58 | 59 | if self.use_conv_transpose: 60 | raise NotImplementedError 61 | 62 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 63 | dtype = hidden_states.dtype 64 | if dtype == torch.bfloat16: 65 | hidden_states = hidden_states.to(torch.float32) 66 | 67 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 68 | if hidden_states.shape[0] >= 64: 69 | hidden_states = hidden_states.contiguous() 70 | 71 | # if `output_size` is passed we force the interpolation output 72 | # size and do not make use of `scale_factor=2` 73 | if output_size is None: 74 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 75 | else: 76 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 77 | 78 | # If the input is bfloat16, we cast back to bfloat16 79 | if dtype == torch.bfloat16: 80 | hidden_states = hidden_states.to(dtype) 81 | 82 | hidden_states = self.conv(hidden_states) 83 | 84 | return hidden_states 85 | 86 | 87 | class Downsample3D(nn.Module): 88 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 89 | super().__init__() 90 | self.channels = channels 91 | self.out_channels = out_channels or channels 92 | self.use_conv = use_conv 93 | self.padding = padding 94 | stride = 2 95 | self.name = name 96 | 97 | if use_conv: 98 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 99 | else: 100 | raise NotImplementedError 101 | 102 | def forward(self, hidden_states): 103 | assert hidden_states.shape[1] == self.channels 104 | if self.use_conv and self.padding == 0: 105 | raise NotImplementedError 106 | 107 | assert hidden_states.shape[1] == self.channels 108 | hidden_states = self.conv(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | class ResnetBlock3D(nn.Module): 114 | def __init__( 115 | self, 116 | *, 117 | in_channels, 118 | out_channels=None, 119 | conv_shortcut=False, 120 | dropout=0.0, 121 | temb_channels=512, 122 | groups=32, 123 | groups_out=None, 124 | pre_norm=True, 125 | eps=1e-6, 126 | non_linearity="swish", 127 | time_embedding_norm="default", 128 | output_scale_factor=1.0, 129 | use_in_shortcut=None, 130 | ): 131 | super().__init__() 132 | self.pre_norm = pre_norm 133 | self.pre_norm = True 134 | self.in_channels = in_channels 135 | out_channels = in_channels if out_channels is None else out_channels 136 | 
self.out_channels = out_channels 137 | self.use_conv_shortcut = conv_shortcut 138 | self.time_embedding_norm = time_embedding_norm 139 | self.output_scale_factor = output_scale_factor 140 | 141 | if groups_out is None: 142 | groups_out = groups 143 | 144 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 145 | 146 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 147 | 148 | if temb_channels is not None: 149 | if self.time_embedding_norm == "default": 150 | time_emb_proj_out_channels = out_channels 151 | elif self.time_embedding_norm == "scale_shift": 152 | time_emb_proj_out_channels = out_channels * 2 153 | else: 154 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 155 | 156 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 157 | else: 158 | self.time_emb_proj = None 159 | 160 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 161 | self.dropout = torch.nn.Dropout(dropout) 162 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 163 | 164 | if non_linearity == "swish": 165 | self.nonlinearity = lambda x: F.silu(x) 166 | elif non_linearity == "mish": 167 | self.nonlinearity = Mish() 168 | elif non_linearity == "silu": 169 | self.nonlinearity = nn.SiLU() 170 | 171 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 172 | 173 | self.conv_shortcut = None 174 | if self.use_in_shortcut: 175 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 176 | 177 | def forward(self, input_tensor, temb): 178 | hidden_states = input_tensor 179 | 180 | hidden_states = self.norm1(hidden_states) 181 | hidden_states = self.nonlinearity(hidden_states) 182 | 183 | hidden_states = self.conv1(hidden_states) 184 | 185 | if temb is not None: 186 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 187 | 188 | if temb is not None and self.time_embedding_norm == "default": 189 | hidden_states = hidden_states + temb 190 | 191 | hidden_states = self.norm2(hidden_states) 192 | 193 | if temb is not None and self.time_embedding_norm == "scale_shift": 194 | scale, shift = torch.chunk(temb, 2, dim=1) 195 | hidden_states = hidden_states * (1 + scale) + shift 196 | 197 | hidden_states = self.nonlinearity(hidden_states) 198 | 199 | hidden_states = self.dropout(hidden_states) 200 | hidden_states = self.conv2(hidden_states) 201 | 202 | if self.conv_shortcut is not None: 203 | input_tensor = self.conv_shortcut(input_tensor) 204 | 205 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 206 | 207 | return output_tensor 208 | 209 | 210 | class Mish(torch.nn.Module): 211 | def forward(self, hidden_states): 212 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /demo/animate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. 
Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 11 | import argparse 12 | import argparse 13 | import datetime 14 | import inspect 15 | import os 16 | import numpy as np 17 | from PIL import Image 18 | from omegaconf import OmegaConf 19 | from collections import OrderedDict 20 | 21 | import torch 22 | 23 | from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler 24 | 25 | from tqdm import tqdm 26 | from transformers import CLIPTextModel, CLIPTokenizer 27 | 28 | from magicanimate.models.unet_controlnet import UNet3DConditionModel 29 | from magicanimate.models.controlnet import ControlNetModel 30 | from magicanimate.models.appearance_encoder import AppearanceEncoderModel 31 | from magicanimate.models.mutual_self_attention import ReferenceAttentionControl 32 | from magicanimate.pipelines.pipeline_animation import AnimationPipeline 33 | from magicanimate.utils.util import save_videos_grid 34 | from accelerate.utils import set_seed 35 | 36 | from magicanimate.utils.videoreader import VideoReader 37 | 38 | from einops import rearrange, repeat 39 | 40 | import csv, pdb, glob 41 | from safetensors import safe_open 42 | import math 43 | from pathlib import Path 44 | 45 | class MagicAnimate(): 46 | def __init__(self, config="configs/prompts/animation.yaml") -> None: 47 | print("Initializing MagicAnimate Pipeline...") 48 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 49 | func_args = dict(func_args) 50 | 51 | config = OmegaConf.load(config) 52 | 53 | inference_config = OmegaConf.load(config.inference_config) 54 | 55 | motion_module = config.motion_module 56 | 57 | ### >>> create animation pipeline >>> ### 58 | tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") 59 | text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") 60 | if config.pretrained_unet_path: 61 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 62 | else: 63 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 64 | self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda() 65 | self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) 66 | self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) 67 | if config.pretrained_vae_path is not None: 68 | vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) 69 | else: 70 | vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") 71 | 72 | ### Load controlnet 73 | controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) 74 | 75 | vae.to(torch.float16) 76 | unet.to(torch.float16) 77 | text_encoder.to(torch.float16) 78 | controlnet.to(torch.float16) 79 | self.appearance_encoder.to(torch.float16) 80 | 81 | unet.enable_xformers_memory_efficient_attention() 82 | self.appearance_encoder.enable_xformers_memory_efficient_attention() 83 | 
controlnet.enable_xformers_memory_efficient_attention() 84 | 85 | self.pipeline = AnimationPipeline( 86 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, 87 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), 88 | # NOTE: UniPCMultistepScheduler 89 | ).to("cuda") 90 | 91 | # 1. unet ckpt 92 | # 1.1 motion module 93 | motion_module_state_dict = torch.load(motion_module, map_location="cpu") 94 | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) 95 | motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict 96 | try: 97 | # extra steps for self-trained models 98 | state_dict = OrderedDict() 99 | for key in motion_module_state_dict.keys(): 100 | if key.startswith("module."): 101 | _key = key.split("module.")[-1] 102 | state_dict[_key] = motion_module_state_dict[key] 103 | else: 104 | state_dict[key] = motion_module_state_dict[key] 105 | motion_module_state_dict = state_dict 106 | del state_dict 107 | missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) 108 | assert len(unexpected) == 0 109 | except: 110 | _tmp_ = OrderedDict() 111 | for key in motion_module_state_dict.keys(): 112 | if "motion_modules" in key: 113 | if key.startswith("unet."): 114 | _key = key.split('unet.')[-1] 115 | _tmp_[_key] = motion_module_state_dict[key] 116 | else: 117 | _tmp_[key] = motion_module_state_dict[key] 118 | missing, unexpected = unet.load_state_dict(_tmp_, strict=False) 119 | assert len(unexpected) == 0 120 | del _tmp_ 121 | del motion_module_state_dict 122 | 123 | self.pipeline.to("cuda") 124 | self.L = config.L 125 | 126 | print("Initialization Done!") 127 | 128 | def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): 129 | prompt = n_prompt = "" 130 | random_seed = int(random_seed) 131 | step = int(step) 132 | guidance_scale = float(guidance_scale) 133 | samples_per_video = [] 134 | # manually set random seed for reproduction 135 | if random_seed != -1: 136 | torch.manual_seed(random_seed) 137 | set_seed(random_seed) 138 | else: 139 | torch.seed() 140 | 141 | if motion_sequence.endswith('.mp4'): 142 | control = VideoReader(motion_sequence).read() 143 | if control[0].shape[0] != size: 144 | control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] 145 | control = np.array(control) 146 | 147 | if source_image.shape[0] != size: 148 | source_image = np.array(Image.fromarray(source_image).resize((size, size))) 149 | H, W, C = source_image.shape 150 | 151 | init_latents = None 152 | original_length = control.shape[0] 153 | if control.shape[0] % self.L > 0: 154 | control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') 155 | generator = torch.Generator(device=torch.device("cuda:0")) 156 | generator.manual_seed(torch.initial_seed()) 157 | sample = self.pipeline( 158 | prompt, 159 | negative_prompt = n_prompt, 160 | num_inference_steps = step, 161 | guidance_scale = guidance_scale, 162 | width = W, 163 | height = H, 164 | video_length = len(control), 165 | controlnet_condition = control, 166 | init_latents = init_latents, 167 | generator = generator, 168 | appearance_encoder = self.appearance_encoder, 169 | reference_control_writer = self.reference_control_writer, 170 | reference_control_reader = self.reference_control_reader, 171 
| source_image = source_image, 172 | ).videos 173 | 174 | source_images = np.array([source_image] * original_length) 175 | source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 176 | samples_per_video.append(source_images) 177 | 178 | control = control / 255.0 179 | control = rearrange(control, "t h w c -> 1 c t h w") 180 | control = torch.from_numpy(control) 181 | samples_per_video.append(control[:, :, :original_length]) 182 | 183 | samples_per_video.append(sample[:, :, :original_length]) 184 | 185 | samples_per_video = torch.cat(samples_per_video) 186 | 187 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 188 | savedir = f"demo/outputs" 189 | animation_path = f"{savedir}/{time_str}.mp4" 190 | 191 | os.makedirs(savedir, exist_ok=True) 192 | save_videos_grid(samples_per_video, animation_path) 193 | 194 | return animation_path 195 | -------------------------------------------------------------------------------- /magicanimate/pipelines/animation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 11 | import argparse 12 | import datetime 13 | import inspect 14 | import os 15 | import random 16 | import numpy as np 17 | 18 | from PIL import Image 19 | from omegaconf import OmegaConf 20 | from collections import OrderedDict 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler 26 | 27 | from tqdm import tqdm 28 | from transformers import CLIPTextModel, CLIPTokenizer 29 | 30 | from magicanimate.models.unet_controlnet import UNet3DConditionModel 31 | from magicanimate.models.controlnet import ControlNetModel 32 | from magicanimate.models.appearance_encoder import AppearanceEncoderModel 33 | from magicanimate.models.mutual_self_attention import ReferenceAttentionControl 34 | from magicanimate.pipelines.pipeline_animation import AnimationPipeline 35 | from magicanimate.utils.util import save_videos_grid 36 | from magicanimate.utils.dist_tools import distributed_init 37 | from accelerate.utils import set_seed 38 | 39 | from magicanimate.utils.videoreader import VideoReader 40 | 41 | from einops import rearrange 42 | 43 | from pathlib import Path 44 | 45 | 46 | def main(args): 47 | 48 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 49 | func_args = dict(func_args) 50 | 51 | config = OmegaConf.load(args.config) 52 | 53 | # Initialize distributed training 54 | device = torch.device(f"cuda:{args.rank}") 55 | dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} 56 | 57 | if config.savename is None: 58 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 59 | savedir = f"samples/{Path(args.config).stem}-{time_str}" 60 | else: 61 | savedir = f"samples/{config.savename}" 62 | 63 | if args.dist: 64 | dist.broadcast_object_list([savedir], 0) 65 | dist.barrier() 66 | 67 | if args.rank == 0: 68 | os.makedirs(savedir, exist_ok=True) 69 | 70 | inference_config = 
OmegaConf.load(config.inference_config) 71 | 72 | motion_module = config.motion_module 73 | 74 | ### >>> create animation pipeline >>> ### 75 | tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") 76 | text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") 77 | if config.pretrained_unet_path: 78 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 79 | else: 80 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 81 | appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) 82 | reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) 83 | reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) 84 | if config.pretrained_vae_path is not None: 85 | vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) 86 | else: 87 | vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") 88 | 89 | ### Load controlnet 90 | controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) 91 | 92 | unet.enable_xformers_memory_efficient_attention() 93 | appearance_encoder.enable_xformers_memory_efficient_attention() 94 | controlnet.enable_xformers_memory_efficient_attention() 95 | 96 | vae.to(torch.float16) 97 | unet.to(torch.float16) 98 | text_encoder.to(torch.float16) 99 | appearance_encoder.to(torch.float16) 100 | controlnet.to(torch.float16) 101 | 102 | pipeline = AnimationPipeline( 103 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, 104 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), 105 | # NOTE: UniPCMultistepScheduler 106 | ) 107 | 108 | # 1. 
unet ckpt 109 | # 1.1 motion module 110 | motion_module_state_dict = torch.load(motion_module, map_location="cpu") 111 | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) 112 | motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict 113 | try: 114 | # extra steps for self-trained models 115 | state_dict = OrderedDict() 116 | for key in motion_module_state_dict.keys(): 117 | if key.startswith("module."): 118 | _key = key.split("module.")[-1] 119 | state_dict[_key] = motion_module_state_dict[key] 120 | else: 121 | state_dict[key] = motion_module_state_dict[key] 122 | motion_module_state_dict = state_dict 123 | del state_dict 124 | missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) 125 | assert len(unexpected) == 0 126 | except: 127 | _tmp_ = OrderedDict() 128 | for key in motion_module_state_dict.keys(): 129 | if "motion_modules" in key: 130 | if key.startswith("unet."): 131 | _key = key.split('unet.')[-1] 132 | _tmp_[_key] = motion_module_state_dict[key] 133 | else: 134 | _tmp_[key] = motion_module_state_dict[key] 135 | missing, unexpected = unet.load_state_dict(_tmp_, strict=False) 136 | assert len(unexpected) == 0 137 | del _tmp_ 138 | del motion_module_state_dict 139 | 140 | pipeline.to(device) 141 | ### <<< create validation pipeline <<< ### 142 | 143 | random_seeds = config.get("seed", [-1]) 144 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 145 | random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds 146 | 147 | # input test videos (either source video/ conditions) 148 | 149 | test_videos = config.video_path 150 | source_images = config.source_image 151 | num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps) 152 | 153 | # read size, step from yaml file 154 | sizes = [config.size] * len(test_videos) 155 | steps = [config.S] * len(test_videos) 156 | 157 | config.random_seed = [] 158 | prompt = n_prompt = "" 159 | for idx, (source_image, test_video, random_seed, size, step) in tqdm( 160 | enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), 161 | total=len(test_videos), 162 | disable=(args.rank!=0) 163 | ): 164 | samples_per_video = [] 165 | samples_per_clip = [] 166 | # manually set random seed for reproduction 167 | if random_seed != -1: 168 | torch.manual_seed(random_seed) 169 | set_seed(random_seed) 170 | else: 171 | torch.seed() 172 | config.random_seed.append(torch.initial_seed()) 173 | 174 | if test_video.endswith('.mp4'): 175 | control = VideoReader(test_video).read() 176 | if control[0].shape[0] != size: 177 | control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] 178 | if config.max_length is not None: 179 | control = control[config.offset: (config.offset+config.max_length)] 180 | control = np.array(control) 181 | 182 | if source_image.endswith(".mp4"): 183 | source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size))) 184 | else: 185 | source_image = np.array(Image.open(source_image).resize((size, size))) 186 | H, W, C = source_image.shape 187 | 188 | print(f"current seed: {torch.initial_seed()}") 189 | init_latents = None 190 | 191 | # print(f"sampling {prompt} ...") 192 | original_length = control.shape[0] 193 | if control.shape[0] % config.L > 0: 194 | control = np.pad(control, ((0, 
config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge') 195 | generator = torch.Generator(device=torch.device("cuda:0")) 196 | generator.manual_seed(torch.initial_seed()) 197 | sample = pipeline( 198 | prompt, 199 | negative_prompt = n_prompt, 200 | num_inference_steps = config.steps, 201 | guidance_scale = config.guidance_scale, 202 | width = W, 203 | height = H, 204 | video_length = len(control), 205 | controlnet_condition = control, 206 | init_latents = init_latents, 207 | generator = generator, 208 | num_actual_inference_steps = num_actual_inference_steps, 209 | appearance_encoder = appearance_encoder, 210 | reference_control_writer = reference_control_writer, 211 | reference_control_reader = reference_control_reader, 212 | source_image = source_image, 213 | **dist_kwargs, 214 | ).videos 215 | 216 | if args.rank == 0: 217 | source_images = np.array([source_image] * original_length) 218 | source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 219 | samples_per_video.append(source_images) 220 | 221 | control = control / 255.0 222 | control = rearrange(control, "t h w c -> 1 c t h w") 223 | control = torch.from_numpy(control) 224 | samples_per_video.append(control[:, :, :original_length]) 225 | 226 | samples_per_video.append(sample[:, :, :original_length]) 227 | 228 | samples_per_video = torch.cat(samples_per_video) 229 | 230 | video_name = os.path.basename(test_video)[:-4] 231 | source_name = os.path.basename(config.source_image[idx]).split(".")[0] 232 | save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4") 233 | save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4") 234 | 235 | if config.save_individual_videos: 236 | save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4") 237 | save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4") 238 | 239 | if args.dist: 240 | dist.barrier() 241 | 242 | if args.rank == 0: 243 | OmegaConf.save(config, f"{savedir}/config.yaml") 244 | 245 | 246 | def distributed_main(device_id, args): 247 | args.rank = device_id 248 | args.device_id = device_id 249 | if torch.cuda.is_available(): 250 | torch.cuda.set_device(args.device_id) 251 | torch.cuda.init() 252 | distributed_init(args) 253 | main(args) 254 | 255 | 256 | def run(args): 257 | 258 | if args.dist: 259 | args.world_size = max(1, torch.cuda.device_count()) 260 | assert args.world_size <= torch.cuda.device_count() 261 | 262 | if args.world_size > 0 and torch.cuda.device_count() > 1: 263 | port = random.randint(10000, 20000) 264 | args.init_method = f"tcp://localhost:{port}" 265 | torch.multiprocessing.spawn( 266 | fn=distributed_main, 267 | args=(args,), 268 | nprocs=args.world_size, 269 | ) 270 | else: 271 | main(args) 272 | 273 | 274 | if __name__ == "__main__": 275 | parser = argparse.ArgumentParser() 276 | parser.add_argument("--config", type=str, required=True) 277 | parser.add_argument("--dist", action="store_true", required=False) 278 | parser.add_argument("--rank", type=int, default=0, required=False) 279 | parser.add_argument("--world_size", type=int, default=1, required=False) 280 | 281 | args = parser.parse_args() 282 | run(args) 283 | -------------------------------------------------------------------------------- /magicanimate/models/attention.py: -------------------------------------------------------------------------------- 1 | # 
************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | from dataclasses import dataclass 21 | from typing import Optional 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | from torch import nn 26 | 27 | from diffusers.configuration_utils import ConfigMixin, register_to_config 28 | from diffusers.models.modeling_utils import ModelMixin 29 | from diffusers.utils import BaseOutput 30 | from diffusers.utils.import_utils import is_xformers_available 31 | from diffusers.models.attention import FeedForward, AdaLayerNorm 32 | from diffusers.models.attention import Attention as CrossAttention 33 | 34 | from einops import rearrange, repeat 35 | 36 | @dataclass 37 | class Transformer3DModelOutput(BaseOutput): 38 | sample: torch.FloatTensor 39 | 40 | 41 | if is_xformers_available(): 42 | import xformers 43 | import xformers.ops 44 | else: 45 | xformers = None 46 | 47 | 48 | class Transformer3DModel(ModelMixin, ConfigMixin): 49 | @register_to_config 50 | def __init__( 51 | self, 52 | num_attention_heads: int = 16, 53 | attention_head_dim: int = 88, 54 | in_channels: Optional[int] = None, 55 | num_layers: int = 1, 56 | dropout: float = 0.0, 57 | norm_num_groups: int = 32, 58 | cross_attention_dim: Optional[int] = None, 59 | attention_bias: bool = False, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | use_linear_projection: bool = False, 63 | only_cross_attention: bool = False, 64 | upcast_attention: bool = False, 65 | 66 | unet_use_cross_frame_attention=None, 67 | unet_use_temporal_attention=None, 68 | ): 69 | super().__init__() 70 | self.use_linear_projection = use_linear_projection 71 | self.num_attention_heads = num_attention_heads 72 | self.attention_head_dim = attention_head_dim 73 | inner_dim = num_attention_heads * attention_head_dim 74 | 75 | # Define input layers 76 | self.in_channels = in_channels 77 | 78 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 79 | if use_linear_projection: 80 | self.proj_in = nn.Linear(in_channels, inner_dim) 81 | else: 82 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 83 | 84 | # Define transformers blocks 85 | self.transformer_blocks = nn.ModuleList( 86 | [ 87 | BasicTransformerBlock( 88 | inner_dim, 89 | num_attention_heads, 90 | attention_head_dim, 91 | dropout=dropout, 92 | cross_attention_dim=cross_attention_dim, 93 | activation_fn=activation_fn, 94 | num_embeds_ada_norm=num_embeds_ada_norm, 95 | attention_bias=attention_bias, 96 | 
only_cross_attention=only_cross_attention, 97 | upcast_attention=upcast_attention, 98 | 99 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 100 | unet_use_temporal_attention=unet_use_temporal_attention, 101 | ) 102 | for d in range(num_layers) 103 | ] 104 | ) 105 | 106 | # 4. Define output layers 107 | if use_linear_projection: 108 | self.proj_out = nn.Linear(in_channels, inner_dim) 109 | else: 110 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 111 | 112 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 113 | # Input 114 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 115 | video_length = hidden_states.shape[2] 116 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 117 | # JH: need not repeat when a list of prompts are given 118 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 119 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 129 | else: 130 | inner_dim = hidden_states.shape[1] 131 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 132 | hidden_states = self.proj_in(hidden_states) 133 | 134 | # Blocks 135 | for block in self.transformer_blocks: 136 | hidden_states = block( 137 | hidden_states, 138 | encoder_hidden_states=encoder_hidden_states, 139 | timestep=timestep, 140 | video_length=video_length 141 | ) 142 | 143 | # Output 144 | if not self.use_linear_projection: 145 | hidden_states = ( 146 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 147 | ) 148 | hidden_states = self.proj_out(hidden_states) 149 | else: 150 | hidden_states = self.proj_out(hidden_states) 151 | hidden_states = ( 152 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 153 | ) 154 | 155 | output = hidden_states + residual 156 | 157 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 158 | if not return_dict: 159 | return (output,) 160 | 161 | return Transformer3DModelOutput(sample=output) 162 | 163 | 164 | class BasicTransformerBlock(nn.Module): 165 | def __init__( 166 | self, 167 | dim: int, 168 | num_attention_heads: int, 169 | attention_head_dim: int, 170 | dropout=0.0, 171 | cross_attention_dim: Optional[int] = None, 172 | activation_fn: str = "geglu", 173 | num_embeds_ada_norm: Optional[int] = None, 174 | attention_bias: bool = False, 175 | only_cross_attention: bool = False, 176 | upcast_attention: bool = False, 177 | 178 | unet_use_cross_frame_attention = None, 179 | unet_use_temporal_attention = None, 180 | ): 181 | super().__init__() 182 | self.only_cross_attention = only_cross_attention 183 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 184 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 185 | self.unet_use_temporal_attention = unet_use_temporal_attention 186 | 187 | # SC-Attn 188 | assert unet_use_cross_frame_attention is not None 189 | if unet_use_cross_frame_attention: 190 | self.attn1 = 
SparseCausalAttention2D( 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 197 | upcast_attention=upcast_attention, 198 | ) 199 | else: 200 | self.attn1 = CrossAttention( 201 | query_dim=dim, 202 | heads=num_attention_heads, 203 | dim_head=attention_head_dim, 204 | dropout=dropout, 205 | bias=attention_bias, 206 | upcast_attention=upcast_attention, 207 | ) 208 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 209 | 210 | # Cross-Attn 211 | if cross_attention_dim is not None: 212 | self.attn2 = CrossAttention( 213 | query_dim=dim, 214 | cross_attention_dim=cross_attention_dim, 215 | heads=num_attention_heads, 216 | dim_head=attention_head_dim, 217 | dropout=dropout, 218 | bias=attention_bias, 219 | upcast_attention=upcast_attention, 220 | ) 221 | else: 222 | self.attn2 = None 223 | 224 | if cross_attention_dim is not None: 225 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 226 | else: 227 | self.norm2 = None 228 | 229 | # Feed-forward 230 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 231 | self.norm3 = nn.LayerNorm(dim) 232 | self.use_ada_layer_norm_zero = False 233 | 234 | # Temp-Attn 235 | assert unet_use_temporal_attention is not None 236 | if unet_use_temporal_attention: 237 | self.attn_temp = CrossAttention( 238 | query_dim=dim, 239 | heads=num_attention_heads, 240 | dim_head=attention_head_dim, 241 | dropout=dropout, 242 | bias=attention_bias, 243 | upcast_attention=upcast_attention, 244 | ) 245 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 246 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 247 | 248 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): 249 | if not is_xformers_available(): 250 | print("Here is how to install it") 251 | raise ModuleNotFoundError( 252 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 253 | " xformers", 254 | name="xformers", 255 | ) 256 | elif not torch.cuda.is_available(): 257 | raise ValueError( 258 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" 259 | " available for GPU " 260 | ) 261 | else: 262 | try: 263 | # Make sure we can run the memory efficient attention 264 | _ = xformers.ops.memory_efficient_attention( 265 | torch.randn((1, 2, 40), device="cuda"), 266 | torch.randn((1, 2, 40), device="cuda"), 267 | torch.randn((1, 2, 40), device="cuda"), 268 | ) 269 | except Exception as e: 270 | raise e 271 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 272 | if self.attn2 is not None: 273 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 274 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 275 | 276 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 277 | # SparseCausal-Attention 278 | norm_hidden_states = ( 279 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 280 | ) 281 | 282 | # if self.only_cross_attention: 283 | # hidden_states = ( 284 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 285 | # ) 286 | # else: 287 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 288 | 289 | # pdb.set_trace() 290 | if self.unet_use_cross_frame_attention: 291 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 292 | else: 293 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 294 | 295 | if self.attn2 is not None: 296 | # Cross-Attention 297 | norm_hidden_states = ( 298 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 299 | ) 300 | hidden_states = ( 301 | self.attn2( 302 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 303 | ) 304 | + hidden_states 305 | ) 306 | 307 | # Feed-forward 308 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 309 | 310 | # Temporal-Attention 311 | if self.unet_use_temporal_attention: 312 | d = hidden_states.shape[1] 313 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 314 | norm_hidden_states = ( 315 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 316 | ) 317 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 318 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 319 | 320 | return hidden_states 321 | -------------------------------------------------------------------------------- /magicanimate/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 
5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | from dataclasses import dataclass 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from diffusers.utils import BaseOutput 15 | from diffusers.utils.import_utils import is_xformers_available 16 | from diffusers.models.attention import FeedForward 17 | from magicanimate.models.orig_attention import CrossAttention 18 | 19 | from einops import rearrange, repeat 20 | import math 21 | 22 | 23 | def zero_module(module): 24 | # Zero out the parameters of a module and return it. 25 | for p in module.parameters(): 26 | p.detach().zero_() 27 | return module 28 | 29 | 30 | @dataclass 31 | class TemporalTransformer3DModelOutput(BaseOutput): 32 | sample: torch.FloatTensor 33 | 34 | 35 | if is_xformers_available(): 36 | import xformers 37 | import xformers.ops 38 | else: 39 | xformers = None 40 | 41 | 42 | def get_motion_module( 43 | in_channels, 44 | motion_module_type: str, 45 | motion_module_kwargs: dict 46 | ): 47 | if motion_module_type == "Vanilla": 48 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 49 | else: 50 | raise ValueError 51 | 52 | 53 | class VanillaTemporalModule(nn.Module): 54 | def __init__( 55 | self, 56 | in_channels, 57 | num_attention_heads = 8, 58 | num_transformer_block = 2, 59 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 60 | cross_frame_attention_mode = None, 61 | temporal_position_encoding = False, 62 | temporal_position_encoding_max_len = 24, 63 | temporal_attention_dim_div = 1, 64 | zero_initialize = True, 65 | ): 66 | super().__init__() 67 | 68 | self.temporal_transformer = TemporalTransformer3DModel( 69 | in_channels=in_channels, 70 | num_attention_heads=num_attention_heads, 71 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 72 | num_layers=num_transformer_block, 73 | attention_block_types=attention_block_types, 74 | cross_frame_attention_mode=cross_frame_attention_mode, 75 | temporal_position_encoding=temporal_position_encoding, 76 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 77 | ) 78 | 79 | if zero_initialize: 80 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 81 | 82 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 83 | hidden_states = input_tensor 84 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 85 | 86 | output = hidden_states 87 | return output 88 | 89 | 90 | class TemporalTransformer3DModel(nn.Module): 91 | def __init__( 92 | self, 93 | in_channels, 94 | num_attention_heads, 95 | attention_head_dim, 96 | 97 | num_layers, 98 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 99 | dropout = 0.0, 100 | norm_num_groups = 32, 101 | cross_attention_dim = 768, 102 | activation_fn = "geglu", 103 | attention_bias = False, 104 | upcast_attention = False, 105 | 106 | cross_frame_attention_mode = None, 107 | temporal_position_encoding = False, 108 | temporal_position_encoding_max_len = 24, 109 | ): 110 | super().__init__() 111 | 112 | inner_dim = num_attention_heads * attention_head_dim 113 | 114 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 115 | self.proj_in = nn.Linear(in_channels, inner_dim) 116 | 117 | self.transformer_blocks = nn.ModuleList( 
118 | [ 119 | TemporalTransformerBlock( 120 | dim=inner_dim, 121 | num_attention_heads=num_attention_heads, 122 | attention_head_dim=attention_head_dim, 123 | attention_block_types=attention_block_types, 124 | dropout=dropout, 125 | norm_num_groups=norm_num_groups, 126 | cross_attention_dim=cross_attention_dim, 127 | activation_fn=activation_fn, 128 | attention_bias=attention_bias, 129 | upcast_attention=upcast_attention, 130 | cross_frame_attention_mode=cross_frame_attention_mode, 131 | temporal_position_encoding=temporal_position_encoding, 132 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 133 | ) 134 | for d in range(num_layers) 135 | ] 136 | ) 137 | self.proj_out = nn.Linear(inner_dim, in_channels) 138 | 139 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 140 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 141 | video_length = hidden_states.shape[2] 142 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 143 | 144 | batch, channel, height, weight = hidden_states.shape 145 | residual = hidden_states 146 | 147 | hidden_states = self.norm(hidden_states) 148 | inner_dim = hidden_states.shape[1] 149 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 150 | hidden_states = self.proj_in(hidden_states) 151 | 152 | # Transformer Blocks 153 | for block in self.transformer_blocks: 154 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 155 | 156 | # output 157 | hidden_states = self.proj_out(hidden_states) 158 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 159 | 160 | output = hidden_states + residual 161 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 162 | 163 | return output 164 | 165 | 166 | class TemporalTransformerBlock(nn.Module): 167 | def __init__( 168 | self, 169 | dim, 170 | num_attention_heads, 171 | attention_head_dim, 172 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 173 | dropout = 0.0, 174 | norm_num_groups = 32, 175 | cross_attention_dim = 768, 176 | activation_fn = "geglu", 177 | attention_bias = False, 178 | upcast_attention = False, 179 | cross_frame_attention_mode = None, 180 | temporal_position_encoding = False, 181 | temporal_position_encoding_max_len = 24, 182 | ): 183 | super().__init__() 184 | 185 | attention_blocks = [] 186 | norms = [] 187 | 188 | for block_name in attention_block_types: 189 | attention_blocks.append( 190 | VersatileAttention( 191 | attention_mode=block_name.split("_")[0], 192 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 193 | 194 | query_dim=dim, 195 | heads=num_attention_heads, 196 | dim_head=attention_head_dim, 197 | dropout=dropout, 198 | bias=attention_bias, 199 | upcast_attention=upcast_attention, 200 | 201 | cross_frame_attention_mode=cross_frame_attention_mode, 202 | temporal_position_encoding=temporal_position_encoding, 203 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 204 | ) 205 | ) 206 | norms.append(nn.LayerNorm(dim)) 207 | 208 | self.attention_blocks = nn.ModuleList(attention_blocks) 209 | self.norms = nn.ModuleList(norms) 210 | 211 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 212 | self.ff_norm = nn.LayerNorm(dim) 213 | 214 | 215 | def forward(self, hidden_states, encoder_hidden_states=None, 
attention_mask=None, video_length=None): 216 | for attention_block, norm in zip(self.attention_blocks, self.norms): 217 | norm_hidden_states = norm(hidden_states) 218 | hidden_states = attention_block( 219 | norm_hidden_states, 220 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 221 | video_length=video_length, 222 | ) + hidden_states 223 | 224 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 225 | 226 | output = hidden_states 227 | return output 228 | 229 | 230 | class PositionalEncoding(nn.Module): 231 | def __init__( 232 | self, 233 | d_model, 234 | dropout = 0., 235 | max_len = 24 236 | ): 237 | super().__init__() 238 | self.dropout = nn.Dropout(p=dropout) 239 | position = torch.arange(max_len).unsqueeze(1) 240 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 241 | pe = torch.zeros(1, max_len, d_model) 242 | pe[0, :, 0::2] = torch.sin(position * div_term) 243 | pe[0, :, 1::2] = torch.cos(position * div_term) 244 | self.register_buffer('pe', pe) 245 | 246 | def forward(self, x): 247 | x = x + self.pe[:, :x.size(1)] 248 | return self.dropout(x) 249 | 250 | 251 | class VersatileAttention(CrossAttention): 252 | def __init__( 253 | self, 254 | attention_mode = None, 255 | cross_frame_attention_mode = None, 256 | temporal_position_encoding = False, 257 | temporal_position_encoding_max_len = 24, 258 | *args, **kwargs 259 | ): 260 | super().__init__(*args, **kwargs) 261 | assert attention_mode == "Temporal" 262 | 263 | self.attention_mode = attention_mode 264 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 265 | 266 | self.pos_encoder = PositionalEncoding( 267 | kwargs["query_dim"], 268 | dropout=0., 269 | max_len=temporal_position_encoding_max_len 270 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 271 | 272 | def extra_repr(self): 273 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 274 | 275 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 276 | batch_size, sequence_length, _ = hidden_states.shape 277 | 278 | if self.attention_mode == "Temporal": 279 | d = hidden_states.shape[1] 280 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 281 | 282 | if self.pos_encoder is not None: 283 | hidden_states = self.pos_encoder(hidden_states) 284 | 285 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 286 | else: 287 | raise NotImplementedError 288 | 289 | encoder_hidden_states = encoder_hidden_states 290 | 291 | if self.group_norm is not None: 292 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 293 | 294 | query = self.to_q(hidden_states) 295 | dim = query.shape[-1] 296 | query = self.reshape_heads_to_batch_dim(query) 297 | 298 | if self.added_kv_proj_dim is not None: 299 | raise NotImplementedError 300 | 301 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 302 | key = self.to_k(encoder_hidden_states) 303 | value = self.to_v(encoder_hidden_states) 304 | 305 | key = self.reshape_heads_to_batch_dim(key) 306 | value = self.reshape_heads_to_batch_dim(value) 307 | 308 | if attention_mask is not None: 309 | if attention_mask.shape[-1] != query.shape[1]: 310 | target_length = query.shape[1] 311 | attention_mask = 
F.pad(attention_mask, (0, target_length), value=0.0) 312 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 313 | 314 | # attention, what we cannot get enough of 315 | if self._use_memory_efficient_attention_xformers: 316 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 317 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 318 | hidden_states = hidden_states.to(query.dtype) 319 | else: 320 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 321 | hidden_states = self._attention(query, key, value, attention_mask) 322 | else: 323 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 324 | 325 | # linear proj 326 | hidden_states = self.to_out[0](hidden_states) 327 | 328 | # dropout 329 | hidden_states = self.to_out[1](hidden_states) 330 | 331 | if self.attention_mode == "Temporal": 332 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 333 | 334 | return hidden_states 335 | -------------------------------------------------------------------------------- /magicanimate/models/embeddings.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import math 21 | from typing import Optional 22 | 23 | import numpy as np 24 | import torch 25 | from torch import nn 26 | 27 | 28 | def get_timestep_embedding( 29 | timesteps: torch.Tensor, 30 | embedding_dim: int, 31 | flip_sin_to_cos: bool = False, 32 | downscale_freq_shift: float = 1, 33 | scale: float = 1, 34 | max_period: int = 10000, 35 | ): 36 | """ 37 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 38 | 39 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 40 | These may be fractional. 41 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 42 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
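    Example (a minimal illustrative sketch; the timestep values and embedding size are
    arbitrary and only the output shape is checked):

        t = torch.arange(4)                               # 4 timestep indices
        emb = get_timestep_embedding(t, embedding_dim=8)  # half_dim = 4
        # emb.shape == torch.Size([4, 8]); with the default flip_sin_to_cos=False the
        # first 4 channels are sine terms and the last 4 are cosine terms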
43 | """ 44 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 45 | 46 | half_dim = embedding_dim // 2 47 | exponent = -math.log(max_period) * torch.arange( 48 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 49 | ) 50 | exponent = exponent / (half_dim - downscale_freq_shift) 51 | 52 | emb = torch.exp(exponent) 53 | emb = timesteps[:, None].float() * emb[None, :] 54 | 55 | # scale embeddings 56 | emb = scale * emb 57 | 58 | # concat sine and cosine embeddings 59 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 60 | 61 | # flip sine and cosine embeddings 62 | if flip_sin_to_cos: 63 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 64 | 65 | # zero pad 66 | if embedding_dim % 2 == 1: 67 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 68 | return emb 69 | 70 | 71 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 72 | """ 73 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 74 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 75 | """ 76 | grid_h = np.arange(grid_size, dtype=np.float32) 77 | grid_w = np.arange(grid_size, dtype=np.float32) 78 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 79 | grid = np.stack(grid, axis=0) 80 | 81 | grid = grid.reshape([2, 1, grid_size, grid_size]) 82 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 83 | if cls_token and extra_tokens > 0: 84 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 85 | return pos_embed 86 | 87 | 88 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 89 | if embed_dim % 2 != 0: 90 | raise ValueError("embed_dim must be divisible by 2") 91 | 92 | # use half of dimensions to encode grid_h 93 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 94 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 95 | 96 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 97 | return emb 98 | 99 | 100 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 101 | """ 102 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 103 | """ 104 | if embed_dim % 2 != 0: 105 | raise ValueError("embed_dim must be divisible by 2") 106 | 107 | omega = np.arange(embed_dim // 2, dtype=np.float64) 108 | omega /= embed_dim / 2.0 109 | omega = 1.0 / 10000**omega # (D/2,) 110 | 111 | pos = pos.reshape(-1) # (M,) 112 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 113 | 114 | emb_sin = np.sin(out) # (M, D/2) 115 | emb_cos = np.cos(out) # (M, D/2) 116 | 117 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 118 | return emb 119 | 120 | 121 | class PatchEmbed(nn.Module): 122 | """2D Image to Patch Embedding""" 123 | 124 | def __init__( 125 | self, 126 | height=224, 127 | width=224, 128 | patch_size=16, 129 | in_channels=3, 130 | embed_dim=768, 131 | layer_norm=False, 132 | flatten=True, 133 | bias=True, 134 | ): 135 | super().__init__() 136 | 137 | num_patches = (height // patch_size) * (width // patch_size) 138 | self.flatten = flatten 139 | self.layer_norm = layer_norm 140 | 141 | self.proj = nn.Conv2d( 142 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 143 | ) 144 | if layer_norm: 145 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 146 | else: 147 | self.norm = None 148 | 149 | pos_embed = get_2d_sincos_pos_embed(embed_dim, 
int(num_patches**0.5)) 150 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 151 | 152 | def forward(self, latent): 153 | latent = self.proj(latent) 154 | if self.flatten: 155 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 156 | if self.layer_norm: 157 | latent = self.norm(latent) 158 | return latent + self.pos_embed 159 | 160 | 161 | class TimestepEmbedding(nn.Module): 162 | def __init__( 163 | self, 164 | in_channels: int, 165 | time_embed_dim: int, 166 | act_fn: str = "silu", 167 | out_dim: int = None, 168 | post_act_fn: Optional[str] = None, 169 | cond_proj_dim=None, 170 | ): 171 | super().__init__() 172 | 173 | self.linear_1 = nn.Linear(in_channels, time_embed_dim) 174 | 175 | if cond_proj_dim is not None: 176 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) 177 | else: 178 | self.cond_proj = None 179 | 180 | if act_fn == "silu": 181 | self.act = nn.SiLU() 182 | elif act_fn == "mish": 183 | self.act = nn.Mish() 184 | elif act_fn == "gelu": 185 | self.act = nn.GELU() 186 | else: 187 | raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") 188 | 189 | if out_dim is not None: 190 | time_embed_dim_out = out_dim 191 | else: 192 | time_embed_dim_out = time_embed_dim 193 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) 194 | 195 | if post_act_fn is None: 196 | self.post_act = None 197 | elif post_act_fn == "silu": 198 | self.post_act = nn.SiLU() 199 | elif post_act_fn == "mish": 200 | self.post_act = nn.Mish() 201 | elif post_act_fn == "gelu": 202 | self.post_act = nn.GELU() 203 | else: 204 | raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") 205 | 206 | def forward(self, sample, condition=None): 207 | if condition is not None: 208 | sample = sample + self.cond_proj(condition) 209 | sample = self.linear_1(sample) 210 | 211 | if self.act is not None: 212 | sample = self.act(sample) 213 | 214 | sample = self.linear_2(sample) 215 | 216 | if self.post_act is not None: 217 | sample = self.post_act(sample) 218 | return sample 219 | 220 | 221 | class Timesteps(nn.Module): 222 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 223 | super().__init__() 224 | self.num_channels = num_channels 225 | self.flip_sin_to_cos = flip_sin_to_cos 226 | self.downscale_freq_shift = downscale_freq_shift 227 | 228 | def forward(self, timesteps): 229 | t_emb = get_timestep_embedding( 230 | timesteps, 231 | self.num_channels, 232 | flip_sin_to_cos=self.flip_sin_to_cos, 233 | downscale_freq_shift=self.downscale_freq_shift, 234 | ) 235 | return t_emb 236 | 237 | 238 | class GaussianFourierProjection(nn.Module): 239 | """Gaussian Fourier embeddings for noise levels.""" 240 | 241 | def __init__( 242 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False 243 | ): 244 | super().__init__() 245 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 246 | self.log = log 247 | self.flip_sin_to_cos = flip_sin_to_cos 248 | 249 | if set_W_to_weight: 250 | # to delete later 251 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 252 | 253 | self.weight = self.W 254 | 255 | def forward(self, x): 256 | if self.log: 257 | x = torch.log(x) 258 | 259 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi 260 | 261 | if self.flip_sin_to_cos: 262 | out = torch.cat([torch.cos(x_proj), 
torch.sin(x_proj)], dim=-1) 263 | else: 264 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 265 | return out 266 | 267 | 268 | class ImagePositionalEmbeddings(nn.Module): 269 | """ 270 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the 271 | height and width of the latent space. 272 | 273 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 274 | 275 | For VQ-diffusion: 276 | 277 | Output vector embeddings are used as input for the transformer. 278 | 279 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. 280 | 281 | Args: 282 | num_embed (`int`): 283 | Number of embeddings for the latent pixels embeddings. 284 | height (`int`): 285 | Height of the latent image i.e. the number of height embeddings. 286 | width (`int`): 287 | Width of the latent image i.e. the number of width embeddings. 288 | embed_dim (`int`): 289 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 290 | """ 291 | 292 | def __init__( 293 | self, 294 | num_embed: int, 295 | height: int, 296 | width: int, 297 | embed_dim: int, 298 | ): 299 | super().__init__() 300 | 301 | self.height = height 302 | self.width = width 303 | self.num_embed = num_embed 304 | self.embed_dim = embed_dim 305 | 306 | self.emb = nn.Embedding(self.num_embed, embed_dim) 307 | self.height_emb = nn.Embedding(self.height, embed_dim) 308 | self.width_emb = nn.Embedding(self.width, embed_dim) 309 | 310 | def forward(self, index): 311 | emb = self.emb(index) 312 | 313 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) 314 | 315 | # 1 x H x D -> 1 x H x 1 x D 316 | height_emb = height_emb.unsqueeze(2) 317 | 318 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) 319 | 320 | # 1 x W x D -> 1 x 1 x W x D 321 | width_emb = width_emb.unsqueeze(1) 322 | 323 | pos_emb = height_emb + width_emb 324 | 325 | # 1 x H x W x D -> 1 x L xD 326 | pos_emb = pos_emb.view(1, self.height * self.width, -1) 327 | 328 | emb = emb + pos_emb[:, : emb.shape[1], :] 329 | 330 | return emb 331 | 332 | 333 | class LabelEmbedding(nn.Module): 334 | """ 335 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 336 | 337 | Args: 338 | num_classes (`int`): The number of classes. 339 | hidden_size (`int`): The size of the vector embeddings. 340 | dropout_prob (`float`): The probability of dropping a label. 341 | """ 342 | 343 | def __init__(self, num_classes, hidden_size, dropout_prob): 344 | super().__init__() 345 | use_cfg_embedding = dropout_prob > 0 346 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 347 | self.num_classes = num_classes 348 | self.dropout_prob = dropout_prob 349 | 350 | def token_drop(self, labels, force_drop_ids=None): 351 | """ 352 | Drops labels to enable classifier-free guidance. 
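        Example (a minimal sketch; the class count, hidden size and labels are illustrative):

            embedder = LabelEmbedding(num_classes=10, hidden_size=32, dropout_prob=0.1)
            labels = torch.tensor([3, 7])
            # entries flagged with 1 in force_drop_ids are remapped to the extra
            # "null" class index (num_classes), i.e. the classifier-free branch
            dropped = embedder.token_drop(labels, force_drop_ids=torch.tensor([1, 0]))
            # dropped -> tensor([10, 7])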
353 | """ 354 | if force_drop_ids is None: 355 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 356 | else: 357 | drop_ids = torch.tensor(force_drop_ids == 1) 358 | labels = torch.where(drop_ids, self.num_classes, labels) 359 | return labels 360 | 361 | def forward(self, labels, force_drop_ids=None): 362 | use_dropout = self.dropout_prob > 0 363 | if (self.training and use_dropout) or (force_drop_ids is not None): 364 | labels = self.token_drop(labels, force_drop_ids) 365 | embeddings = self.embedding_table(labels) 366 | return embeddings 367 | 368 | 369 | class CombinedTimestepLabelEmbeddings(nn.Module): 370 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): 371 | super().__init__() 372 | 373 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) 374 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 375 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) 376 | 377 | def forward(self, timestep, class_labels, hidden_dtype=None): 378 | timesteps_proj = self.time_proj(timestep) 379 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) 380 | 381 | class_labels = self.class_embedder(class_labels) # (N, D) 382 | 383 | conditioning = timesteps_emb + class_labels # (N, D) 384 | 385 | return conditioning -------------------------------------------------------------------------------- /magicanimate/models/unet.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | 9 | # Copyright 2023 The HuggingFace Team. All rights reserved. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 
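# Usage sketch (illustrative; mirrors the call made in magicanimate/pipelines/animation.py,
# and the weight path / config objects are assumptions, not fixed values):
#
#     unet = UNet3DConditionModel.from_pretrained_2d(
#         "path/to/stable-diffusion-base",   # pretrained 2D Stable Diffusion weights
#         subfolder="unet",
#         unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
#     )
#
# from_pretrained_2d loads the 2D UNet weights into this inflated 3D variant so that
# the temporal motion modules configured below can attend across video frames.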
22 | from dataclasses import dataclass 23 | from typing import List, Optional, Tuple, Union 24 | 25 | import os 26 | import json 27 | import pdb 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.utils.checkpoint 32 | 33 | from diffusers.configuration_utils import ConfigMixin, register_to_config 34 | from diffusers.models.modeling_utils import ModelMixin 35 | from diffusers.utils import BaseOutput, logging 36 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 37 | from .unet_3d_blocks import ( 38 | CrossAttnDownBlock3D, 39 | CrossAttnUpBlock3D, 40 | DownBlock3D, 41 | UNetMidBlock3DCrossAttn, 42 | UpBlock3D, 43 | get_down_block, 44 | get_up_block, 45 | ) 46 | from .resnet import InflatedConv3d 47 | 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | 52 | @dataclass 53 | class UNet3DConditionOutput(BaseOutput): 54 | sample: torch.FloatTensor 55 | 56 | 57 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 58 | _supports_gradient_checkpointing = True 59 | 60 | @register_to_config 61 | def __init__( 62 | self, 63 | sample_size: Optional[int] = None, 64 | in_channels: int = 4, 65 | out_channels: int = 4, 66 | center_input_sample: bool = False, 67 | flip_sin_to_cos: bool = True, 68 | freq_shift: int = 0, 69 | down_block_types: Tuple[str] = ( 70 | "CrossAttnDownBlock3D", 71 | "CrossAttnDownBlock3D", 72 | "CrossAttnDownBlock3D", 73 | "DownBlock3D", 74 | ), 75 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 76 | up_block_types: Tuple[str] = ( 77 | "UpBlock3D", 78 | "CrossAttnUpBlock3D", 79 | "CrossAttnUpBlock3D", 80 | "CrossAttnUpBlock3D" 81 | ), 82 | only_cross_attention: Union[bool, Tuple[bool]] = False, 83 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 84 | layers_per_block: int = 2, 85 | downsample_padding: int = 1, 86 | mid_block_scale_factor: float = 1, 87 | act_fn: str = "silu", 88 | norm_num_groups: int = 32, 89 | norm_eps: float = 1e-5, 90 | cross_attention_dim: int = 1280, 91 | attention_head_dim: Union[int, Tuple[int]] = 8, 92 | dual_cross_attention: bool = False, 93 | use_linear_projection: bool = False, 94 | class_embed_type: Optional[str] = None, 95 | num_class_embeds: Optional[int] = None, 96 | upcast_attention: bool = False, 97 | resnet_time_scale_shift: str = "default", 98 | 99 | # Additional 100 | use_motion_module = False, 101 | motion_module_resolutions = ( 1,2,4,8 ), 102 | motion_module_mid_block = False, 103 | motion_module_decoder_only = False, 104 | motion_module_type = None, 105 | motion_module_kwargs = {}, 106 | unet_use_cross_frame_attention = None, 107 | unet_use_temporal_attention = None, 108 | ): 109 | super().__init__() 110 | 111 | self.sample_size = sample_size 112 | time_embed_dim = block_out_channels[0] * 4 113 | 114 | # input 115 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 116 | 117 | # time 118 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 119 | timestep_input_dim = block_out_channels[0] 120 | 121 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 122 | 123 | # class embedding 124 | if class_embed_type is None and num_class_embeds is not None: 125 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 126 | elif class_embed_type == "timestep": 127 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 128 | elif class_embed_type == "identity": 129 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 
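        # any other combination falls through to the `else` below, which disables
        # class conditioning entirely (self.class_embedding = None)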
130 | else: 131 | self.class_embedding = None 132 | 133 | self.down_blocks = nn.ModuleList([]) 134 | self.mid_block = None 135 | self.up_blocks = nn.ModuleList([]) 136 | 137 | if isinstance(only_cross_attention, bool): 138 | only_cross_attention = [only_cross_attention] * len(down_block_types) 139 | 140 | if isinstance(attention_head_dim, int): 141 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 142 | 143 | # down 144 | output_channel = block_out_channels[0] 145 | for i, down_block_type in enumerate(down_block_types): 146 | res = 2 ** i 147 | input_channel = output_channel 148 | output_channel = block_out_channels[i] 149 | is_final_block = i == len(block_out_channels) - 1 150 | 151 | down_block = get_down_block( 152 | down_block_type, 153 | num_layers=layers_per_block, 154 | in_channels=input_channel, 155 | out_channels=output_channel, 156 | temb_channels=time_embed_dim, 157 | add_downsample=not is_final_block, 158 | resnet_eps=norm_eps, 159 | resnet_act_fn=act_fn, 160 | resnet_groups=norm_num_groups, 161 | cross_attention_dim=cross_attention_dim, 162 | attn_num_head_channels=attention_head_dim[i], 163 | downsample_padding=downsample_padding, 164 | dual_cross_attention=dual_cross_attention, 165 | use_linear_projection=use_linear_projection, 166 | only_cross_attention=only_cross_attention[i], 167 | upcast_attention=upcast_attention, 168 | resnet_time_scale_shift=resnet_time_scale_shift, 169 | 170 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 171 | unet_use_temporal_attention=unet_use_temporal_attention, 172 | 173 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 174 | motion_module_type=motion_module_type, 175 | motion_module_kwargs=motion_module_kwargs, 176 | ) 177 | self.down_blocks.append(down_block) 178 | 179 | # mid 180 | if mid_block_type == "UNetMidBlock3DCrossAttn": 181 | self.mid_block = UNetMidBlock3DCrossAttn( 182 | in_channels=block_out_channels[-1], 183 | temb_channels=time_embed_dim, 184 | resnet_eps=norm_eps, 185 | resnet_act_fn=act_fn, 186 | output_scale_factor=mid_block_scale_factor, 187 | resnet_time_scale_shift=resnet_time_scale_shift, 188 | cross_attention_dim=cross_attention_dim, 189 | attn_num_head_channels=attention_head_dim[-1], 190 | resnet_groups=norm_num_groups, 191 | dual_cross_attention=dual_cross_attention, 192 | use_linear_projection=use_linear_projection, 193 | upcast_attention=upcast_attention, 194 | 195 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 196 | unet_use_temporal_attention=unet_use_temporal_attention, 197 | 198 | use_motion_module=use_motion_module and motion_module_mid_block, 199 | motion_module_type=motion_module_type, 200 | motion_module_kwargs=motion_module_kwargs, 201 | ) 202 | else: 203 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 204 | 205 | # count how many layers upsample the videos 206 | self.num_upsamplers = 0 207 | 208 | # up 209 | reversed_block_out_channels = list(reversed(block_out_channels)) 210 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 211 | only_cross_attention = list(reversed(only_cross_attention)) 212 | output_channel = reversed_block_out_channels[0] 213 | for i, up_block_type in enumerate(up_block_types): 214 | res = 2 ** (3 - i) 215 | is_final_block = i == len(block_out_channels) - 1 216 | 217 | prev_output_channel = output_channel 218 | output_channel = reversed_block_out_channels[i] 219 | input_channel = reversed_block_out_channels[min(i + 1, 
len(block_out_channels) - 1)] 220 | 221 | # add upsample block for all BUT final layer 222 | if not is_final_block: 223 | add_upsample = True 224 | self.num_upsamplers += 1 225 | else: 226 | add_upsample = False 227 | 228 | up_block = get_up_block( 229 | up_block_type, 230 | num_layers=layers_per_block + 1, 231 | in_channels=input_channel, 232 | out_channels=output_channel, 233 | prev_output_channel=prev_output_channel, 234 | temb_channels=time_embed_dim, 235 | add_upsample=add_upsample, 236 | resnet_eps=norm_eps, 237 | resnet_act_fn=act_fn, 238 | resnet_groups=norm_num_groups, 239 | cross_attention_dim=cross_attention_dim, 240 | attn_num_head_channels=reversed_attention_head_dim[i], 241 | dual_cross_attention=dual_cross_attention, 242 | use_linear_projection=use_linear_projection, 243 | only_cross_attention=only_cross_attention[i], 244 | upcast_attention=upcast_attention, 245 | resnet_time_scale_shift=resnet_time_scale_shift, 246 | 247 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 248 | unet_use_temporal_attention=unet_use_temporal_attention, 249 | 250 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 251 | motion_module_type=motion_module_type, 252 | motion_module_kwargs=motion_module_kwargs, 253 | ) 254 | self.up_blocks.append(up_block) 255 | prev_output_channel = output_channel 256 | 257 | # out 258 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 259 | self.conv_act = nn.SiLU() 260 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 261 | 262 | def set_attention_slice(self, slice_size): 263 | r""" 264 | Enable sliced attention computation. 265 | 266 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 267 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 268 | 269 | Args: 270 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 271 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 272 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 273 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 274 | must be a multiple of `slice_size`. 
275 | """ 276 | sliceable_head_dims = [] 277 | 278 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 279 | if hasattr(module, "set_attention_slice"): 280 | sliceable_head_dims.append(module.sliceable_head_dim) 281 | 282 | for child in module.children(): 283 | fn_recursive_retrieve_slicable_dims(child) 284 | 285 | # retrieve number of attention layers 286 | for module in self.children(): 287 | fn_recursive_retrieve_slicable_dims(module) 288 | 289 | num_slicable_layers = len(sliceable_head_dims) 290 | 291 | if slice_size == "auto": 292 | # half the attention head size is usually a good trade-off between 293 | # speed and memory 294 | slice_size = [dim // 2 for dim in sliceable_head_dims] 295 | elif slice_size == "max": 296 | # make smallest slice possible 297 | slice_size = num_slicable_layers * [1] 298 | 299 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 300 | 301 | if len(slice_size) != len(sliceable_head_dims): 302 | raise ValueError( 303 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 304 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 305 | ) 306 | 307 | for i in range(len(slice_size)): 308 | size = slice_size[i] 309 | dim = sliceable_head_dims[i] 310 | if size is not None and size > dim: 311 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 312 | 313 | # Recursively walk through all the children. 314 | # Any children which exposes the set_attention_slice method 315 | # gets the message 316 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 317 | if hasattr(module, "set_attention_slice"): 318 | module.set_attention_slice(slice_size.pop()) 319 | 320 | for child in module.children(): 321 | fn_recursive_set_attention_slice(child, slice_size) 322 | 323 | reversed_slice_size = list(reversed(slice_size)) 324 | for module in self.children(): 325 | fn_recursive_set_attention_slice(module, reversed_slice_size) 326 | 327 | def _set_gradient_checkpointing(self, module, value=False): 328 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 329 | module.gradient_checkpointing = value 330 | 331 | def forward( 332 | self, 333 | sample: torch.FloatTensor, 334 | timestep: Union[torch.Tensor, float, int], 335 | encoder_hidden_states: torch.Tensor, 336 | class_labels: Optional[torch.Tensor] = None, 337 | attention_mask: Optional[torch.Tensor] = None, 338 | return_dict: bool = True, 339 | ) -> Union[UNet3DConditionOutput, Tuple]: 340 | r""" 341 | Args: 342 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 343 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 344 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 345 | return_dict (`bool`, *optional*, defaults to `True`): 346 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 347 | 348 | Returns: 349 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 350 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 351 | returning a tuple, the first element is the sample tensor. 352 | """ 353 | # By default samples have to be AT least a multiple of the overall upsampling factor. 354 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 
355 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 356 | # on the fly if necessary. 357 | default_overall_up_factor = 2**self.num_upsamplers 358 | 359 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 360 | forward_upsample_size = False 361 | upsample_size = None 362 | 363 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 364 | logger.info("Forward upsample size to force interpolation output size.") 365 | forward_upsample_size = True 366 | 367 | # prepare attention_mask 368 | if attention_mask is not None: 369 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 370 | attention_mask = attention_mask.unsqueeze(1) 371 | 372 | # center input if necessary 373 | if self.config.center_input_sample: 374 | sample = 2 * sample - 1.0 375 | 376 | # time 377 | timesteps = timestep 378 | if not torch.is_tensor(timesteps): 379 | # This would be a good case for the `match` statement (Python 3.10+) 380 | is_mps = sample.device.type == "mps" 381 | if isinstance(timestep, float): 382 | dtype = torch.float32 if is_mps else torch.float64 383 | else: 384 | dtype = torch.int32 if is_mps else torch.int64 385 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 386 | elif len(timesteps.shape) == 0: 387 | timesteps = timesteps[None].to(sample.device) 388 | 389 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 390 | timesteps = timesteps.expand(sample.shape[0]) 391 | 392 | t_emb = self.time_proj(timesteps) 393 | 394 | # timesteps does not contain any weights and will always return f32 tensors 395 | # but time_embedding might actually be running in fp16. so we need to cast here. 396 | # there might be better ways to encapsulate this. 
397 | t_emb = t_emb.to(dtype=self.dtype) 398 | emb = self.time_embedding(t_emb) 399 | 400 | if self.class_embedding is not None: 401 | if class_labels is None: 402 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 403 | 404 | if self.config.class_embed_type == "timestep": 405 | class_labels = self.time_proj(class_labels) 406 | 407 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 408 | emb = emb + class_emb 409 | 410 | # pre-process 411 | sample = self.conv_in(sample) 412 | 413 | # down 414 | down_block_res_samples = (sample,) 415 | for downsample_block in self.down_blocks: 416 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 417 | sample, res_samples = downsample_block( 418 | hidden_states=sample, 419 | temb=emb, 420 | encoder_hidden_states=encoder_hidden_states, 421 | attention_mask=attention_mask, 422 | ) 423 | else: 424 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 425 | 426 | down_block_res_samples += res_samples 427 | 428 | # mid 429 | sample = self.mid_block( 430 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 431 | ) 432 | 433 | # up 434 | for i, upsample_block in enumerate(self.up_blocks): 435 | is_final_block = i == len(self.up_blocks) - 1 436 | 437 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 438 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 439 | 440 | # if we have not reached the final block and need to forward the 441 | # upsample size, we do it here 442 | if not is_final_block and forward_upsample_size: 443 | upsample_size = down_block_res_samples[-1].shape[2:] 444 | 445 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 446 | sample = upsample_block( 447 | hidden_states=sample, 448 | temb=emb, 449 | res_hidden_states_tuple=res_samples, 450 | encoder_hidden_states=encoder_hidden_states, 451 | upsample_size=upsample_size, 452 | attention_mask=attention_mask, 453 | ) 454 | else: 455 | sample = upsample_block( 456 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, 457 | ) 458 | 459 | # post-process 460 | sample = self.conv_norm_out(sample) 461 | sample = self.conv_act(sample) 462 | sample = self.conv_out(sample) 463 | 464 | if not return_dict: 465 | return (sample,) 466 | 467 | return UNet3DConditionOutput(sample=sample) 468 | 469 | @classmethod 470 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 471 | if subfolder is not None: 472 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 473 | print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") 474 | 475 | config_file = os.path.join(pretrained_model_path, 'config.json') 476 | if not os.path.isfile(config_file): 477 | raise RuntimeError(f"{config_file} does not exist") 478 | with open(config_file, "r") as f: 479 | config = json.load(f) 480 | config["_class_name"] = cls.__name__ 481 | config["down_block_types"] = [ 482 | "CrossAttnDownBlock3D", 483 | "CrossAttnDownBlock3D", 484 | "CrossAttnDownBlock3D", 485 | "DownBlock3D" 486 | ] 487 | config["up_block_types"] = [ 488 | "UpBlock3D", 489 | "CrossAttnUpBlock3D", 490 | "CrossAttnUpBlock3D", 491 | "CrossAttnUpBlock3D" 492 | ] 493 | 494 | from diffusers.utils import WEIGHTS_NAME 495 | 
model = cls.from_config(config, **unet_additional_kwargs) 496 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 497 | if not os.path.isfile(model_file): 498 | raise RuntimeError(f"{model_file} does not exist") 499 | state_dict = torch.load(model_file, map_location="cpu") 500 | 501 | m, u = model.load_state_dict(state_dict, strict=False) 502 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 503 | # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") 504 | 505 | params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] 506 | print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") 507 | 508 | return model 509 | -------------------------------------------------------------------------------- /magicanimate/models/unet_controlnet.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
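unet_controlnet.py below is nearly identical to unet.py; the difference is that its forward additionally accepts down_block_additional_residuals and mid_block_additional_residual produced by a ControlNet and adds them to the skip connections and the mid-block features. The zip-and-add pattern used there can be shown in isolation with dummy tensors (the shapes are arbitrary placeholders for 5D video features, not values from this repo):

import torch

# Placeholder skip features from the down blocks and matching ControlNet residuals.
down_block_res_samples = (
    torch.randn(1, 320, 8, 8, 8),
    torch.randn(1, 320, 8, 8, 8),
    torch.randn(1, 640, 8, 4, 4),
)
down_block_additional_residuals = tuple(torch.zeros_like(t) for t in down_block_res_samples)
sample = torch.randn(1, 1280, 8, 2, 2)
mid_block_additional_residual = torch.zeros_like(sample)

# Same element-wise addition as in the forward pass: skip features first, then the mid block.
down_block_res_samples = tuple(
    s + r for s, r in zip(down_block_res_samples, down_block_additional_residuals)
)
sample = sample + mid_block_additional_residual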
20 | from dataclasses import dataclass 21 | from typing import List, Optional, Tuple, Union 22 | 23 | import os 24 | import json 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.utils.checkpoint 29 | 30 | from diffusers.configuration_utils import ConfigMixin, register_to_config 31 | from diffusers.models.modeling_utils import ModelMixin 32 | from diffusers.utils import BaseOutput, logging 33 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 34 | from magicanimate.models.unet_3d_blocks import ( 35 | CrossAttnDownBlock3D, 36 | CrossAttnUpBlock3D, 37 | DownBlock3D, 38 | UNetMidBlock3DCrossAttn, 39 | UpBlock3D, 40 | get_down_block, 41 | get_up_block, 42 | ) 43 | from .resnet import InflatedConv3d 44 | 45 | 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 47 | 48 | 49 | @dataclass 50 | class UNet3DConditionOutput(BaseOutput): 51 | sample: torch.FloatTensor 52 | 53 | 54 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 55 | _supports_gradient_checkpointing = True 56 | 57 | @register_to_config 58 | def __init__( 59 | self, 60 | sample_size: Optional[int] = None, 61 | in_channels: int = 4, 62 | out_channels: int = 4, 63 | center_input_sample: bool = False, 64 | flip_sin_to_cos: bool = True, 65 | freq_shift: int = 0, 66 | down_block_types: Tuple[str] = ( 67 | "CrossAttnDownBlock3D", 68 | "CrossAttnDownBlock3D", 69 | "CrossAttnDownBlock3D", 70 | "DownBlock3D", 71 | ), 72 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 73 | up_block_types: Tuple[str] = ( 74 | "UpBlock3D", 75 | "CrossAttnUpBlock3D", 76 | "CrossAttnUpBlock3D", 77 | "CrossAttnUpBlock3D" 78 | ), 79 | only_cross_attention: Union[bool, Tuple[bool]] = False, 80 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 81 | layers_per_block: int = 2, 82 | downsample_padding: int = 1, 83 | mid_block_scale_factor: float = 1, 84 | act_fn: str = "silu", 85 | norm_num_groups: int = 32, 86 | norm_eps: float = 1e-5, 87 | cross_attention_dim: int = 1280, 88 | attention_head_dim: Union[int, Tuple[int]] = 8, 89 | dual_cross_attention: bool = False, 90 | use_linear_projection: bool = False, 91 | class_embed_type: Optional[str] = None, 92 | num_class_embeds: Optional[int] = None, 93 | upcast_attention: bool = False, 94 | resnet_time_scale_shift: str = "default", 95 | 96 | # Additional 97 | use_motion_module = False, 98 | motion_module_resolutions = ( 1,2,4,8 ), 99 | motion_module_mid_block = False, 100 | motion_module_decoder_only = False, 101 | motion_module_type = None, 102 | motion_module_kwargs = {}, 103 | unet_use_cross_frame_attention = None, 104 | unet_use_temporal_attention = None, 105 | ): 106 | super().__init__() 107 | 108 | self.sample_size = sample_size 109 | time_embed_dim = block_out_channels[0] * 4 110 | 111 | # input 112 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 113 | 114 | # time 115 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 116 | timestep_input_dim = block_out_channels[0] 117 | 118 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 119 | 120 | # class embedding 121 | if class_embed_type is None and num_class_embeds is not None: 122 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 123 | elif class_embed_type == "timestep": 124 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 125 | elif class_embed_type == "identity": 126 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 
127 | else: 128 | self.class_embedding = None 129 | 130 | self.down_blocks = nn.ModuleList([]) 131 | self.mid_block = None 132 | self.up_blocks = nn.ModuleList([]) 133 | 134 | if isinstance(only_cross_attention, bool): 135 | only_cross_attention = [only_cross_attention] * len(down_block_types) 136 | 137 | if isinstance(attention_head_dim, int): 138 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 139 | 140 | # down 141 | output_channel = block_out_channels[0] 142 | for i, down_block_type in enumerate(down_block_types): 143 | res = 2 ** i 144 | input_channel = output_channel 145 | output_channel = block_out_channels[i] 146 | is_final_block = i == len(block_out_channels) - 1 147 | 148 | down_block = get_down_block( 149 | down_block_type, 150 | num_layers=layers_per_block, 151 | in_channels=input_channel, 152 | out_channels=output_channel, 153 | temb_channels=time_embed_dim, 154 | add_downsample=not is_final_block, 155 | resnet_eps=norm_eps, 156 | resnet_act_fn=act_fn, 157 | resnet_groups=norm_num_groups, 158 | cross_attention_dim=cross_attention_dim, 159 | attn_num_head_channels=attention_head_dim[i], 160 | downsample_padding=downsample_padding, 161 | dual_cross_attention=dual_cross_attention, 162 | use_linear_projection=use_linear_projection, 163 | only_cross_attention=only_cross_attention[i], 164 | upcast_attention=upcast_attention, 165 | resnet_time_scale_shift=resnet_time_scale_shift, 166 | 167 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 168 | unet_use_temporal_attention=unet_use_temporal_attention, 169 | 170 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 171 | motion_module_type=motion_module_type, 172 | motion_module_kwargs=motion_module_kwargs, 173 | ) 174 | self.down_blocks.append(down_block) 175 | 176 | # mid 177 | if mid_block_type == "UNetMidBlock3DCrossAttn": 178 | self.mid_block = UNetMidBlock3DCrossAttn( 179 | in_channels=block_out_channels[-1], 180 | temb_channels=time_embed_dim, 181 | resnet_eps=norm_eps, 182 | resnet_act_fn=act_fn, 183 | output_scale_factor=mid_block_scale_factor, 184 | resnet_time_scale_shift=resnet_time_scale_shift, 185 | cross_attention_dim=cross_attention_dim, 186 | attn_num_head_channels=attention_head_dim[-1], 187 | resnet_groups=norm_num_groups, 188 | dual_cross_attention=dual_cross_attention, 189 | use_linear_projection=use_linear_projection, 190 | upcast_attention=upcast_attention, 191 | 192 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 193 | unet_use_temporal_attention=unet_use_temporal_attention, 194 | 195 | use_motion_module=use_motion_module and motion_module_mid_block, 196 | motion_module_type=motion_module_type, 197 | motion_module_kwargs=motion_module_kwargs, 198 | ) 199 | else: 200 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 201 | 202 | # count how many layers upsample the videos 203 | self.num_upsamplers = 0 204 | 205 | # up 206 | reversed_block_out_channels = list(reversed(block_out_channels)) 207 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 208 | only_cross_attention = list(reversed(only_cross_attention)) 209 | output_channel = reversed_block_out_channels[0] 210 | for i, up_block_type in enumerate(up_block_types): 211 | res = 2 ** (3 - i) 212 | is_final_block = i == len(block_out_channels) - 1 213 | 214 | prev_output_channel = output_channel 215 | output_channel = reversed_block_out_channels[i] 216 | input_channel = reversed_block_out_channels[min(i + 1, 
len(block_out_channels) - 1)] 217 | 218 | # add upsample block for all BUT final layer 219 | if not is_final_block: 220 | add_upsample = True 221 | self.num_upsamplers += 1 222 | else: 223 | add_upsample = False 224 | 225 | up_block = get_up_block( 226 | up_block_type, 227 | num_layers=layers_per_block + 1, 228 | in_channels=input_channel, 229 | out_channels=output_channel, 230 | prev_output_channel=prev_output_channel, 231 | temb_channels=time_embed_dim, 232 | add_upsample=add_upsample, 233 | resnet_eps=norm_eps, 234 | resnet_act_fn=act_fn, 235 | resnet_groups=norm_num_groups, 236 | cross_attention_dim=cross_attention_dim, 237 | attn_num_head_channels=reversed_attention_head_dim[i], 238 | dual_cross_attention=dual_cross_attention, 239 | use_linear_projection=use_linear_projection, 240 | only_cross_attention=only_cross_attention[i], 241 | upcast_attention=upcast_attention, 242 | resnet_time_scale_shift=resnet_time_scale_shift, 243 | 244 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 245 | unet_use_temporal_attention=unet_use_temporal_attention, 246 | 247 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 248 | motion_module_type=motion_module_type, 249 | motion_module_kwargs=motion_module_kwargs, 250 | ) 251 | self.up_blocks.append(up_block) 252 | prev_output_channel = output_channel 253 | 254 | # out 255 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 256 | self.conv_act = nn.SiLU() 257 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 258 | 259 | def set_attention_slice(self, slice_size): 260 | r""" 261 | Enable sliced attention computation. 262 | 263 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 264 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 265 | 266 | Args: 267 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 268 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 269 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 270 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 271 | must be a multiple of `slice_size`. 
272 | """ 273 | sliceable_head_dims = [] 274 | 275 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 276 | if hasattr(module, "set_attention_slice"): 277 | sliceable_head_dims.append(module.sliceable_head_dim) 278 | 279 | for child in module.children(): 280 | fn_recursive_retrieve_slicable_dims(child) 281 | 282 | # retrieve number of attention layers 283 | for module in self.children(): 284 | fn_recursive_retrieve_slicable_dims(module) 285 | 286 | num_slicable_layers = len(sliceable_head_dims) 287 | 288 | if slice_size == "auto": 289 | # half the attention head size is usually a good trade-off between 290 | # speed and memory 291 | slice_size = [dim // 2 for dim in sliceable_head_dims] 292 | elif slice_size == "max": 293 | # make smallest slice possible 294 | slice_size = num_slicable_layers * [1] 295 | 296 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 297 | 298 | if len(slice_size) != len(sliceable_head_dims): 299 | raise ValueError( 300 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 301 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 302 | ) 303 | 304 | for i in range(len(slice_size)): 305 | size = slice_size[i] 306 | dim = sliceable_head_dims[i] 307 | if size is not None and size > dim: 308 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 309 | 310 | # Recursively walk through all the children. 311 | # Any children which exposes the set_attention_slice method 312 | # gets the message 313 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 314 | if hasattr(module, "set_attention_slice"): 315 | module.set_attention_slice(slice_size.pop()) 316 | 317 | for child in module.children(): 318 | fn_recursive_set_attention_slice(child, slice_size) 319 | 320 | reversed_slice_size = list(reversed(slice_size)) 321 | for module in self.children(): 322 | fn_recursive_set_attention_slice(module, reversed_slice_size) 323 | 324 | def _set_gradient_checkpointing(self, module, value=False): 325 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 326 | module.gradient_checkpointing = value 327 | 328 | def forward( 329 | self, 330 | sample: torch.FloatTensor, 331 | timestep: Union[torch.Tensor, float, int], 332 | encoder_hidden_states: torch.Tensor, 333 | class_labels: Optional[torch.Tensor] = None, 334 | attention_mask: Optional[torch.Tensor] = None, 335 | # for controlnet 336 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 337 | mid_block_additional_residual: Optional[torch.Tensor] = None, 338 | return_dict: bool = True, 339 | ) -> Union[UNet3DConditionOutput, Tuple]: 340 | r""" 341 | Args: 342 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 343 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 344 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 345 | return_dict (`bool`, *optional*, defaults to `True`): 346 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 347 | 348 | Returns: 349 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 350 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 351 | returning a tuple, the first element is the sample tensor. 
352 | """ 353 | # By default samples have to be AT least a multiple of the overall upsampling factor. 354 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 355 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 356 | # on the fly if necessary. 357 | default_overall_up_factor = 2**self.num_upsamplers 358 | 359 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 360 | forward_upsample_size = False 361 | upsample_size = None 362 | 363 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 364 | logger.info("Forward upsample size to force interpolation output size.") 365 | forward_upsample_size = True 366 | 367 | # prepare attention_mask 368 | if attention_mask is not None: 369 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 370 | attention_mask = attention_mask.unsqueeze(1) 371 | 372 | # center input if necessary 373 | if self.config.center_input_sample: 374 | sample = 2 * sample - 1.0 375 | 376 | # time 377 | timesteps = timestep 378 | if not torch.is_tensor(timesteps): 379 | # This would be a good case for the `match` statement (Python 3.10+) 380 | is_mps = sample.device.type == "mps" 381 | if isinstance(timestep, float): 382 | dtype = torch.float32 if is_mps else torch.float64 383 | else: 384 | dtype = torch.int32 if is_mps else torch.int64 385 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 386 | elif len(timesteps.shape) == 0: 387 | timesteps = timesteps[None].to(sample.device) 388 | 389 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 390 | timesteps = timesteps.expand(sample.shape[0]) 391 | 392 | t_emb = self.time_proj(timesteps) 393 | 394 | # timesteps does not contain any weights and will always return f32 tensors 395 | # but time_embedding might actually be running in fp16. so we need to cast here. 396 | # there might be better ways to encapsulate this. 
397 | t_emb = t_emb.to(dtype=self.dtype) 398 | emb = self.time_embedding(t_emb) 399 | 400 | if self.class_embedding is not None: 401 | if class_labels is None: 402 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 403 | 404 | if self.config.class_embed_type == "timestep": 405 | class_labels = self.time_proj(class_labels) 406 | 407 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 408 | emb = emb + class_emb 409 | 410 | # pre-process 411 | sample = self.conv_in(sample) 412 | 413 | # down 414 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 415 | 416 | down_block_res_samples = (sample,) 417 | for downsample_block in self.down_blocks: 418 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 419 | sample, res_samples = downsample_block( 420 | hidden_states=sample, 421 | temb=emb, 422 | encoder_hidden_states=encoder_hidden_states, 423 | attention_mask=attention_mask, 424 | ) 425 | else: 426 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 427 | 428 | down_block_res_samples += res_samples 429 | 430 | if is_controlnet: 431 | new_down_block_res_samples = () 432 | 433 | for down_block_res_sample, down_block_additional_residual in zip( 434 | down_block_res_samples, down_block_additional_residuals 435 | ): 436 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 437 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 438 | 439 | down_block_res_samples = new_down_block_res_samples 440 | 441 | # mid 442 | sample = self.mid_block( 443 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 444 | ) 445 | 446 | if is_controlnet: 447 | sample = sample + mid_block_additional_residual 448 | 449 | # up 450 | for i, upsample_block in enumerate(self.up_blocks): 451 | is_final_block = i == len(self.up_blocks) - 1 452 | 453 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 454 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 455 | 456 | # if we have not reached the final block and need to forward the 457 | # upsample size, we do it here 458 | if not is_final_block and forward_upsample_size: 459 | upsample_size = down_block_res_samples[-1].shape[2:] 460 | 461 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 462 | sample = upsample_block( 463 | hidden_states=sample, 464 | temb=emb, 465 | res_hidden_states_tuple=res_samples, 466 | encoder_hidden_states=encoder_hidden_states, 467 | upsample_size=upsample_size, 468 | attention_mask=attention_mask, 469 | ) 470 | else: 471 | sample = upsample_block( 472 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, 473 | ) 474 | 475 | # post-process 476 | sample = self.conv_norm_out(sample) 477 | sample = self.conv_act(sample) 478 | sample = self.conv_out(sample) 479 | 480 | if not return_dict: 481 | return (sample,) 482 | 483 | return UNet3DConditionOutput(sample=sample) 484 | 485 | @classmethod 486 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 487 | if subfolder is not None: 488 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 489 | print(f"loaded temporal unet's pretrained weights from 
{pretrained_model_path} ...") 490 | 491 | config_file = os.path.join(pretrained_model_path, 'config.json') 492 | if not os.path.isfile(config_file): 493 | raise RuntimeError(f"{config_file} does not exist") 494 | with open(config_file, "r") as f: 495 | config = json.load(f) 496 | config["_class_name"] = cls.__name__ 497 | config["down_block_types"] = [ 498 | "CrossAttnDownBlock3D", 499 | "CrossAttnDownBlock3D", 500 | "CrossAttnDownBlock3D", 501 | "DownBlock3D" 502 | ] 503 | config["up_block_types"] = [ 504 | "UpBlock3D", 505 | "CrossAttnUpBlock3D", 506 | "CrossAttnUpBlock3D", 507 | "CrossAttnUpBlock3D" 508 | ] 509 | # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" 510 | 511 | from diffusers.utils import WEIGHTS_NAME 512 | model = cls.from_config(config, **unet_additional_kwargs) 513 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 514 | if not os.path.isfile(model_file): 515 | raise RuntimeError(f"{model_file} does not exist") 516 | state_dict = torch.load(model_file, map_location="cpu") 517 | 518 | m, u = model.load_state_dict(state_dict, strict=False) 519 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 520 | # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") 521 | 522 | params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] 523 | print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") 524 | 525 | return model 526 | -------------------------------------------------------------------------------- /magicanimate/models/controlnet.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
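controlnet.py below carries the 2D ControlNet that encodes the DensePose driving signal. ControlNetModel.from_unet copies the encoder weights from a 2D Stable Diffusion UNet, and every projection wrapped in zero_module starts with all-zero parameters, so the control branch contributes nothing until it is trained. A small self-contained check of that zero initialisation (the channel count is a placeholder; the helper is reproduced here only so the snippet runs on its own):

import torch
from torch import nn

def zero_module(module):
    # Same helper as at the bottom of this file: zero out every parameter in place.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))   # placeholder channel count
x = torch.randn(2, 320, 8, 8)
assert proj(x).abs().sum() == 0   # zero weights and bias, so the residual starts as a no-op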
20 | from dataclasses import dataclass 21 | from typing import Any, Dict, List, Optional, Tuple, Union 22 | 23 | import torch 24 | from torch import nn 25 | from torch.nn import functional as F 26 | 27 | from diffusers.configuration_utils import ConfigMixin, register_to_config 28 | from diffusers.utils import BaseOutput, logging 29 | from .embeddings import TimestepEmbedding, Timesteps 30 | from diffusers.models.modeling_utils import ModelMixin 31 | from diffusers.models.unet_2d_blocks import ( 32 | CrossAttnDownBlock2D, 33 | DownBlock2D, 34 | UNetMidBlock2DCrossAttn, 35 | get_down_block, 36 | ) 37 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @dataclass 44 | class ControlNetOutput(BaseOutput): 45 | down_block_res_samples: Tuple[torch.Tensor] 46 | mid_block_res_sample: torch.Tensor 47 | 48 | 49 | class ControlNetConditioningEmbedding(nn.Module): 50 | """ 51 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 52 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 53 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 54 | convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 55 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 56 | model) to encode image-space conditions ... into feature maps ..." 57 | """ 58 | 59 | def __init__( 60 | self, 61 | conditioning_embedding_channels: int, 62 | conditioning_channels: int = 3, 63 | block_out_channels: Tuple[int] = (16, 32, 96, 256), 64 | ): 65 | super().__init__() 66 | 67 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 68 | 69 | self.blocks = nn.ModuleList([]) 70 | 71 | for i in range(len(block_out_channels) - 1): 72 | channel_in = block_out_channels[i] 73 | channel_out = block_out_channels[i + 1] 74 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 75 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 76 | 77 | self.conv_out = zero_module( 78 | nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 79 | ) 80 | 81 | def forward(self, conditioning): 82 | embedding = self.conv_in(conditioning) 83 | embedding = F.silu(embedding) 84 | 85 | for block in self.blocks: 86 | embedding = block(embedding) 87 | embedding = F.silu(embedding) 88 | 89 | embedding = self.conv_out(embedding) 90 | 91 | return embedding 92 | 93 | 94 | class ControlNetModel(ModelMixin, ConfigMixin): 95 | _supports_gradient_checkpointing = True 96 | 97 | @register_to_config 98 | def __init__( 99 | self, 100 | in_channels: int = 4, 101 | flip_sin_to_cos: bool = True, 102 | freq_shift: int = 0, 103 | down_block_types: Tuple[str] = ( 104 | "CrossAttnDownBlock2D", 105 | "CrossAttnDownBlock2D", 106 | "CrossAttnDownBlock2D", 107 | "DownBlock2D", 108 | ), 109 | only_cross_attention: Union[bool, Tuple[bool]] = False, 110 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 111 | layers_per_block: int = 2, 112 | downsample_padding: int = 1, 113 | mid_block_scale_factor: float = 1, 114 | act_fn: str = "silu", 115 | norm_num_groups: Optional[int] = 32, 116 | norm_eps: float = 1e-5, 117 | cross_attention_dim: int = 
1280, 118 | attention_head_dim: Union[int, Tuple[int]] = 8, 119 | use_linear_projection: bool = False, 120 | class_embed_type: Optional[str] = None, 121 | num_class_embeds: Optional[int] = None, 122 | upcast_attention: bool = False, 123 | resnet_time_scale_shift: str = "default", 124 | projection_class_embeddings_input_dim: Optional[int] = None, 125 | controlnet_conditioning_channel_order: str = "rgb", 126 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 127 | ): 128 | super().__init__() 129 | 130 | # Check inputs 131 | if len(block_out_channels) != len(down_block_types): 132 | raise ValueError( 133 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 134 | ) 135 | 136 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 137 | raise ValueError( 138 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 139 | ) 140 | 141 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 142 | raise ValueError( 143 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 144 | ) 145 | 146 | # input 147 | conv_in_kernel = 3 148 | conv_in_padding = (conv_in_kernel - 1) // 2 149 | self.conv_in = nn.Conv2d( 150 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 151 | ) 152 | 153 | # time 154 | time_embed_dim = block_out_channels[0] * 4 155 | 156 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 157 | timestep_input_dim = block_out_channels[0] 158 | 159 | self.time_embedding = TimestepEmbedding( 160 | timestep_input_dim, 161 | time_embed_dim, 162 | act_fn=act_fn, 163 | ) 164 | 165 | # class embedding 166 | if class_embed_type is None and num_class_embeds is not None: 167 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 168 | elif class_embed_type == "timestep": 169 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 170 | elif class_embed_type == "identity": 171 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 172 | elif class_embed_type == "projection": 173 | if projection_class_embeddings_input_dim is None: 174 | raise ValueError( 175 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 176 | ) 177 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 178 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 179 | # 2. it projects from an arbitrary input dimension. 180 | # 181 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 182 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 183 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
184 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 185 | else: 186 | self.class_embedding = None 187 | 188 | # control net conditioning embedding 189 | self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 190 | conditioning_embedding_channels=block_out_channels[0], 191 | block_out_channels=conditioning_embedding_out_channels, 192 | ) 193 | 194 | self.down_blocks = nn.ModuleList([]) 195 | self.controlnet_down_blocks = nn.ModuleList([]) 196 | 197 | if isinstance(only_cross_attention, bool): 198 | only_cross_attention = [only_cross_attention] * len(down_block_types) 199 | 200 | if isinstance(attention_head_dim, int): 201 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 202 | 203 | # down 204 | output_channel = block_out_channels[0] 205 | 206 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 207 | controlnet_block = zero_module(controlnet_block) 208 | self.controlnet_down_blocks.append(controlnet_block) 209 | 210 | for i, down_block_type in enumerate(down_block_types): 211 | input_channel = output_channel 212 | output_channel = block_out_channels[i] 213 | is_final_block = i == len(block_out_channels) - 1 214 | 215 | down_block = get_down_block( 216 | down_block_type, 217 | num_layers=layers_per_block, 218 | in_channels=input_channel, 219 | out_channels=output_channel, 220 | temb_channels=time_embed_dim, 221 | add_downsample=not is_final_block, 222 | resnet_eps=norm_eps, 223 | resnet_act_fn=act_fn, 224 | resnet_groups=norm_num_groups, 225 | cross_attention_dim=cross_attention_dim, 226 | num_attention_heads=attention_head_dim[i], 227 | downsample_padding=downsample_padding, 228 | use_linear_projection=use_linear_projection, 229 | only_cross_attention=only_cross_attention[i], 230 | upcast_attention=upcast_attention, 231 | resnet_time_scale_shift=resnet_time_scale_shift, 232 | ) 233 | self.down_blocks.append(down_block) 234 | 235 | for _ in range(layers_per_block): 236 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 237 | controlnet_block = zero_module(controlnet_block) 238 | self.controlnet_down_blocks.append(controlnet_block) 239 | 240 | if not is_final_block: 241 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 242 | controlnet_block = zero_module(controlnet_block) 243 | self.controlnet_down_blocks.append(controlnet_block) 244 | 245 | # mid 246 | mid_block_channel = block_out_channels[-1] 247 | 248 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 249 | controlnet_block = zero_module(controlnet_block) 250 | self.controlnet_mid_block = controlnet_block 251 | 252 | self.mid_block = UNetMidBlock2DCrossAttn( 253 | in_channels=mid_block_channel, 254 | temb_channels=time_embed_dim, 255 | resnet_eps=norm_eps, 256 | resnet_act_fn=act_fn, 257 | output_scale_factor=mid_block_scale_factor, 258 | resnet_time_scale_shift=resnet_time_scale_shift, 259 | cross_attention_dim=cross_attention_dim, 260 | num_attention_heads=attention_head_dim[-1], 261 | resnet_groups=norm_num_groups, 262 | use_linear_projection=use_linear_projection, 263 | upcast_attention=upcast_attention, 264 | ) 265 | 266 | @classmethod 267 | def from_unet( 268 | cls, 269 | unet: UNet2DConditionModel, 270 | controlnet_conditioning_channel_order: str = "rgb", 271 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 272 | load_weights_from_unet: bool = True, 273 | ): 274 | r""" 275 | Instantiate Controlnet class from 
UNet2DConditionModel. 276 | 277 | Parameters: 278 | unet (`UNet2DConditionModel`): 279 | UNet model which weights are copied to the ControlNet. Note that all configuration options are also 280 | copied where applicable. 281 | """ 282 | controlnet = cls( 283 | in_channels=unet.config.in_channels, 284 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 285 | freq_shift=unet.config.freq_shift, 286 | down_block_types=unet.config.down_block_types, 287 | only_cross_attention=unet.config.only_cross_attention, 288 | block_out_channels=unet.config.block_out_channels, 289 | layers_per_block=unet.config.layers_per_block, 290 | downsample_padding=unet.config.downsample_padding, 291 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 292 | act_fn=unet.config.act_fn, 293 | norm_num_groups=unet.config.norm_num_groups, 294 | norm_eps=unet.config.norm_eps, 295 | cross_attention_dim=unet.config.cross_attention_dim, 296 | attention_head_dim=unet.config.attention_head_dim, 297 | use_linear_projection=unet.config.use_linear_projection, 298 | class_embed_type=unet.config.class_embed_type, 299 | num_class_embeds=unet.config.num_class_embeds, 300 | upcast_attention=unet.config.upcast_attention, 301 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 302 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 303 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 304 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 305 | ) 306 | 307 | if load_weights_from_unet: 308 | controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) 309 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 310 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 311 | 312 | if controlnet.class_embedding: 313 | controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) 314 | 315 | controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) 316 | controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) 317 | 318 | return controlnet 319 | 320 | # @property 321 | # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 322 | # def attn_processors(self) -> Dict[str, AttentionProcessor]: 323 | # r""" 324 | # Returns: 325 | # `dict` of attention processors: A dictionary containing all attention processors used in the model with 326 | # indexed by its weight name. 327 | # """ 328 | # # set recursively 329 | # processors = {} 330 | 331 | # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 332 | # if hasattr(module, "set_processor"): 333 | # processors[f"{name}.processor"] = module.processor 334 | 335 | # for sub_name, child in module.named_children(): 336 | # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 337 | 338 | # return processors 339 | 340 | # for name, module in self.named_children(): 341 | # fn_recursive_add_processors(name, module, processors) 342 | 343 | # return processors 344 | 345 | # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 346 | # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 347 | # r""" 348 | # Parameters: 349 | # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): 350 | # The instantiated processor class or a dictionary of processor classes that will be set as the processor 351 | # of **all** `Attention` layers. 
352 | # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: 353 | 354 | # """ 355 | # count = len(self.attn_processors.keys()) 356 | 357 | # if isinstance(processor, dict) and len(processor) != count: 358 | # raise ValueError( 359 | # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 360 | # f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 361 | # ) 362 | 363 | # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 364 | # if hasattr(module, "set_processor"): 365 | # if not isinstance(processor, dict): 366 | # module.set_processor(processor) 367 | # else: 368 | # module.set_processor(processor.pop(f"{name}.processor")) 369 | 370 | # for sub_name, child in module.named_children(): 371 | # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 372 | 373 | # for name, module in self.named_children(): 374 | # fn_recursive_attn_processor(name, module, processor) 375 | 376 | # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 377 | # def set_default_attn_processor(self): 378 | # """ 379 | # Disables custom attention processors and sets the default attention implementation. 380 | # """ 381 | # self.set_attn_processor(AttnProcessor()) 382 | 383 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 384 | def set_attention_slice(self, slice_size): 385 | r""" 386 | Enable sliced attention computation. 387 | 388 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 389 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 390 | 391 | Args: 392 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 393 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 394 | `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is 395 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 396 | must be a multiple of `slice_size`. 397 | """ 398 | sliceable_head_dims = [] 399 | 400 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 401 | if hasattr(module, "set_attention_slice"): 402 | sliceable_head_dims.append(module.sliceable_head_dim) 403 | 404 | for child in module.children(): 405 | fn_recursive_retrieve_sliceable_dims(child) 406 | 407 | # retrieve number of attention layers 408 | for module in self.children(): 409 | fn_recursive_retrieve_sliceable_dims(module) 410 | 411 | num_sliceable_layers = len(sliceable_head_dims) 412 | 413 | if slice_size == "auto": 414 | # half the attention head size is usually a good trade-off between 415 | # speed and memory 416 | slice_size = [dim // 2 for dim in sliceable_head_dims] 417 | elif slice_size == "max": 418 | # make smallest slice possible 419 | slice_size = num_sliceable_layers * [1] 420 | 421 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 422 | 423 | if len(slice_size) != len(sliceable_head_dims): 424 | raise ValueError( 425 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 426 | f" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 427 | ) 428 | 429 | for i in range(len(slice_size)): 430 | size = slice_size[i] 431 | dim = sliceable_head_dims[i] 432 | if size is not None and size > dim: 433 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 434 | 435 | # Recursively walk through all the children. 436 | # Any children which exposes the set_attention_slice method 437 | # gets the message 438 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 439 | if hasattr(module, "set_attention_slice"): 440 | module.set_attention_slice(slice_size.pop()) 441 | 442 | for child in module.children(): 443 | fn_recursive_set_attention_slice(child, slice_size) 444 | 445 | reversed_slice_size = list(reversed(slice_size)) 446 | for module in self.children(): 447 | fn_recursive_set_attention_slice(module, reversed_slice_size) 448 | 449 | def _set_gradient_checkpointing(self, module, value=False): 450 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 451 | module.gradient_checkpointing = value 452 | 453 | def forward( 454 | self, 455 | sample: torch.FloatTensor, 456 | timestep: Union[torch.Tensor, float, int], 457 | encoder_hidden_states: torch.Tensor, 458 | controlnet_cond: torch.FloatTensor, 459 | conditioning_scale: float = 1.0, 460 | class_labels: Optional[torch.Tensor] = None, 461 | timestep_cond: Optional[torch.Tensor] = None, 462 | attention_mask: Optional[torch.Tensor] = None, 463 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 464 | return_dict: bool = True, 465 | ) -> Union[ControlNetOutput, Tuple]: 466 | # check channel order 467 | channel_order = self.config.controlnet_conditioning_channel_order 468 | 469 | if channel_order == "rgb": 470 | # in rgb order by default 471 | ... 472 | elif channel_order == "bgr": 473 | controlnet_cond = torch.flip(controlnet_cond, dims=[1]) 474 | else: 475 | raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") 476 | 477 | # prepare attention_mask 478 | if attention_mask is not None: 479 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 480 | attention_mask = attention_mask.unsqueeze(1) 481 | 482 | # 1. time 483 | timesteps = timestep 484 | if not torch.is_tensor(timesteps): 485 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 486 | # This would be a good case for the `match` statement (Python 3.10+) 487 | is_mps = sample.device.type == "mps" 488 | if isinstance(timestep, float): 489 | dtype = torch.float32 if is_mps else torch.float64 490 | else: 491 | dtype = torch.int32 if is_mps else torch.int64 492 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 493 | elif len(timesteps.shape) == 0: 494 | timesteps = timesteps[None].to(sample.device) 495 | 496 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 497 | timesteps = timesteps.expand(sample.shape[0]) 498 | 499 | t_emb = self.time_proj(timesteps) 500 | 501 | # timesteps does not contain any weights and will always return f32 tensors 502 | # but time_embedding might actually be running in fp16. so we need to cast here. 503 | # there might be better ways to encapsulate this. 
504 | t_emb = t_emb.to(dtype=self.dtype) 505 | 506 | emb = self.time_embedding(t_emb, timestep_cond) 507 | 508 | if self.class_embedding is not None: 509 | if class_labels is None: 510 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 511 | 512 | if self.config.class_embed_type == "timestep": 513 | class_labels = self.time_proj(class_labels) 514 | 515 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 516 | emb = emb + class_emb 517 | 518 | # 2. pre-process 519 | sample = self.conv_in(sample) 520 | 521 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 522 | 523 | sample += controlnet_cond 524 | 525 | # 3. down 526 | down_block_res_samples = (sample,) 527 | for downsample_block in self.down_blocks: 528 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 529 | sample, res_samples = downsample_block( 530 | hidden_states=sample, 531 | temb=emb, 532 | encoder_hidden_states=encoder_hidden_states, 533 | attention_mask=attention_mask, 534 | # cross_attention_kwargs=cross_attention_kwargs, 535 | ) 536 | else: 537 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 538 | 539 | down_block_res_samples += res_samples 540 | 541 | # 4. mid 542 | if self.mid_block is not None: 543 | sample = self.mid_block( 544 | sample, 545 | emb, 546 | encoder_hidden_states=encoder_hidden_states, 547 | attention_mask=attention_mask, 548 | # cross_attention_kwargs=cross_attention_kwargs, 549 | ) 550 | 551 | # 5. Control net blocks 552 | 553 | controlnet_down_block_res_samples = () 554 | 555 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 556 | down_block_res_sample = controlnet_block(down_block_res_sample) 557 | controlnet_down_block_res_samples += (down_block_res_sample,) 558 | 559 | down_block_res_samples = controlnet_down_block_res_samples 560 | 561 | mid_block_res_sample = self.controlnet_mid_block(sample) 562 | 563 | # 6. scaling 564 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 565 | mid_block_res_sample *= conditioning_scale 566 | 567 | if not return_dict: 568 | return (down_block_res_samples, mid_block_res_sample) 569 | 570 | return ControlNetOutput( 571 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 572 | ) 573 | 574 | 575 | def zero_module(module): 576 | for p in module.parameters(): 577 | nn.init.zeros_(p) 578 | return module -------------------------------------------------------------------------------- /magicanimate/models/unet_3d_blocks.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | 9 | # Copyright 2023 The HuggingFace Team. All rights reserved. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 
13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | import torch 23 | from torch import nn 24 | 25 | from .attention import Transformer3DModel 26 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 27 | from .motion_module import get_motion_module 28 | 29 | 30 | def get_down_block( 31 | down_block_type, 32 | num_layers, 33 | in_channels, 34 | out_channels, 35 | temb_channels, 36 | add_downsample, 37 | resnet_eps, 38 | resnet_act_fn, 39 | attn_num_head_channels, 40 | resnet_groups=None, 41 | cross_attention_dim=None, 42 | downsample_padding=None, 43 | dual_cross_attention=False, 44 | use_linear_projection=False, 45 | only_cross_attention=False, 46 | upcast_attention=False, 47 | resnet_time_scale_shift="default", 48 | 49 | unet_use_cross_frame_attention=None, 50 | unet_use_temporal_attention=None, 51 | 52 | use_motion_module=None, 53 | 54 | motion_module_type=None, 55 | motion_module_kwargs=None, 56 | ): 57 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 58 | if down_block_type == "DownBlock3D": 59 | return DownBlock3D( 60 | num_layers=num_layers, 61 | in_channels=in_channels, 62 | out_channels=out_channels, 63 | temb_channels=temb_channels, 64 | add_downsample=add_downsample, 65 | resnet_eps=resnet_eps, 66 | resnet_act_fn=resnet_act_fn, 67 | resnet_groups=resnet_groups, 68 | downsample_padding=downsample_padding, 69 | resnet_time_scale_shift=resnet_time_scale_shift, 70 | 71 | use_motion_module=use_motion_module, 72 | motion_module_type=motion_module_type, 73 | motion_module_kwargs=motion_module_kwargs, 74 | ) 75 | elif down_block_type == "CrossAttnDownBlock3D": 76 | if cross_attention_dim is None: 77 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 78 | return CrossAttnDownBlock3D( 79 | num_layers=num_layers, 80 | in_channels=in_channels, 81 | out_channels=out_channels, 82 | temb_channels=temb_channels, 83 | add_downsample=add_downsample, 84 | resnet_eps=resnet_eps, 85 | resnet_act_fn=resnet_act_fn, 86 | resnet_groups=resnet_groups, 87 | downsample_padding=downsample_padding, 88 | cross_attention_dim=cross_attention_dim, 89 | attn_num_head_channels=attn_num_head_channels, 90 | dual_cross_attention=dual_cross_attention, 91 | use_linear_projection=use_linear_projection, 92 | only_cross_attention=only_cross_attention, 93 | upcast_attention=upcast_attention, 94 | resnet_time_scale_shift=resnet_time_scale_shift, 95 | 96 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 97 | unet_use_temporal_attention=unet_use_temporal_attention, 98 | 99 | use_motion_module=use_motion_module, 100 | motion_module_type=motion_module_type, 101 | motion_module_kwargs=motion_module_kwargs, 102 | ) 103 | raise ValueError(f"{down_block_type} does not exist.") 104 | 105 | 106 | def get_up_block( 107 | up_block_type, 108 | num_layers, 109 | in_channels, 110 | out_channels, 111 | prev_output_channel, 112 | temb_channels, 113 | add_upsample, 114 | resnet_eps, 115 | resnet_act_fn, 116 | attn_num_head_channels, 117 | resnet_groups=None, 118 | cross_attention_dim=None, 119 | dual_cross_attention=False, 120 | 
use_linear_projection=False, 121 | only_cross_attention=False, 122 | upcast_attention=False, 123 | resnet_time_scale_shift="default", 124 | 125 | unet_use_cross_frame_attention=None, 126 | unet_use_temporal_attention=None, 127 | 128 | use_motion_module=None, 129 | motion_module_type=None, 130 | motion_module_kwargs=None, 131 | ): 132 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 133 | if up_block_type == "UpBlock3D": 134 | return UpBlock3D( 135 | num_layers=num_layers, 136 | in_channels=in_channels, 137 | out_channels=out_channels, 138 | prev_output_channel=prev_output_channel, 139 | temb_channels=temb_channels, 140 | add_upsample=add_upsample, 141 | resnet_eps=resnet_eps, 142 | resnet_act_fn=resnet_act_fn, 143 | resnet_groups=resnet_groups, 144 | resnet_time_scale_shift=resnet_time_scale_shift, 145 | 146 | use_motion_module=use_motion_module, 147 | motion_module_type=motion_module_type, 148 | motion_module_kwargs=motion_module_kwargs, 149 | ) 150 | elif up_block_type == "CrossAttnUpBlock3D": 151 | if cross_attention_dim is None: 152 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 153 | return CrossAttnUpBlock3D( 154 | num_layers=num_layers, 155 | in_channels=in_channels, 156 | out_channels=out_channels, 157 | prev_output_channel=prev_output_channel, 158 | temb_channels=temb_channels, 159 | add_upsample=add_upsample, 160 | resnet_eps=resnet_eps, 161 | resnet_act_fn=resnet_act_fn, 162 | resnet_groups=resnet_groups, 163 | cross_attention_dim=cross_attention_dim, 164 | attn_num_head_channels=attn_num_head_channels, 165 | dual_cross_attention=dual_cross_attention, 166 | use_linear_projection=use_linear_projection, 167 | only_cross_attention=only_cross_attention, 168 | upcast_attention=upcast_attention, 169 | resnet_time_scale_shift=resnet_time_scale_shift, 170 | 171 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 172 | unet_use_temporal_attention=unet_use_temporal_attention, 173 | 174 | use_motion_module=use_motion_module, 175 | motion_module_type=motion_module_type, 176 | motion_module_kwargs=motion_module_kwargs, 177 | ) 178 | raise ValueError(f"{up_block_type} does not exist.") 179 | 180 | 181 | class UNetMidBlock3DCrossAttn(nn.Module): 182 | def __init__( 183 | self, 184 | in_channels: int, 185 | temb_channels: int, 186 | dropout: float = 0.0, 187 | num_layers: int = 1, 188 | resnet_eps: float = 1e-6, 189 | resnet_time_scale_shift: str = "default", 190 | resnet_act_fn: str = "swish", 191 | resnet_groups: int = 32, 192 | resnet_pre_norm: bool = True, 193 | attn_num_head_channels=1, 194 | output_scale_factor=1.0, 195 | cross_attention_dim=1280, 196 | dual_cross_attention=False, 197 | use_linear_projection=False, 198 | upcast_attention=False, 199 | 200 | unet_use_cross_frame_attention=None, 201 | unet_use_temporal_attention=None, 202 | 203 | use_motion_module=None, 204 | 205 | motion_module_type=None, 206 | motion_module_kwargs=None, 207 | ): 208 | super().__init__() 209 | 210 | self.has_cross_attention = True 211 | self.attn_num_head_channels = attn_num_head_channels 212 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 213 | 214 | # there is always at least one resnet 215 | resnets = [ 216 | ResnetBlock3D( 217 | in_channels=in_channels, 218 | out_channels=in_channels, 219 | temb_channels=temb_channels, 220 | eps=resnet_eps, 221 | groups=resnet_groups, 222 | dropout=dropout, 223 | time_embedding_norm=resnet_time_scale_shift, 224 | 
non_linearity=resnet_act_fn, 225 | output_scale_factor=output_scale_factor, 226 | pre_norm=resnet_pre_norm, 227 | ) 228 | ] 229 | attentions = [] 230 | motion_modules = [] 231 | 232 | for _ in range(num_layers): 233 | if dual_cross_attention: 234 | raise NotImplementedError 235 | attentions.append( 236 | Transformer3DModel( 237 | attn_num_head_channels, 238 | in_channels // attn_num_head_channels, 239 | in_channels=in_channels, 240 | num_layers=1, 241 | cross_attention_dim=cross_attention_dim, 242 | norm_num_groups=resnet_groups, 243 | use_linear_projection=use_linear_projection, 244 | upcast_attention=upcast_attention, 245 | 246 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 247 | unet_use_temporal_attention=unet_use_temporal_attention, 248 | ) 249 | ) 250 | motion_modules.append( 251 | get_motion_module( 252 | in_channels=in_channels, 253 | motion_module_type=motion_module_type, 254 | motion_module_kwargs=motion_module_kwargs, 255 | ) if use_motion_module else None 256 | ) 257 | resnets.append( 258 | ResnetBlock3D( 259 | in_channels=in_channels, 260 | out_channels=in_channels, 261 | temb_channels=temb_channels, 262 | eps=resnet_eps, 263 | groups=resnet_groups, 264 | dropout=dropout, 265 | time_embedding_norm=resnet_time_scale_shift, 266 | non_linearity=resnet_act_fn, 267 | output_scale_factor=output_scale_factor, 268 | pre_norm=resnet_pre_norm, 269 | ) 270 | ) 271 | 272 | self.attentions = nn.ModuleList(attentions) 273 | self.resnets = nn.ModuleList(resnets) 274 | self.motion_modules = nn.ModuleList(motion_modules) 275 | 276 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 277 | hidden_states = self.resnets[0](hidden_states, temb) 278 | for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): 279 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 280 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 281 | hidden_states = resnet(hidden_states, temb) 282 | 283 | return hidden_states 284 | 285 | 286 | class CrossAttnDownBlock3D(nn.Module): 287 | def __init__( 288 | self, 289 | in_channels: int, 290 | out_channels: int, 291 | temb_channels: int, 292 | dropout: float = 0.0, 293 | num_layers: int = 1, 294 | resnet_eps: float = 1e-6, 295 | resnet_time_scale_shift: str = "default", 296 | resnet_act_fn: str = "swish", 297 | resnet_groups: int = 32, 298 | resnet_pre_norm: bool = True, 299 | attn_num_head_channels=1, 300 | cross_attention_dim=1280, 301 | output_scale_factor=1.0, 302 | downsample_padding=1, 303 | add_downsample=True, 304 | dual_cross_attention=False, 305 | use_linear_projection=False, 306 | only_cross_attention=False, 307 | upcast_attention=False, 308 | 309 | unet_use_cross_frame_attention=None, 310 | unet_use_temporal_attention=None, 311 | 312 | use_motion_module=None, 313 | 314 | motion_module_type=None, 315 | motion_module_kwargs=None, 316 | ): 317 | super().__init__() 318 | resnets = [] 319 | attentions = [] 320 | motion_modules = [] 321 | 322 | self.has_cross_attention = True 323 | self.attn_num_head_channels = attn_num_head_channels 324 | 325 | for i in range(num_layers): 326 | in_channels = in_channels if i == 0 else out_channels 327 | resnets.append( 328 | ResnetBlock3D( 329 | in_channels=in_channels, 330 | out_channels=out_channels, 331 | temb_channels=temb_channels, 332 | eps=resnet_eps, 333 | groups=resnet_groups, 334 | dropout=dropout, 
335 | time_embedding_norm=resnet_time_scale_shift, 336 | non_linearity=resnet_act_fn, 337 | output_scale_factor=output_scale_factor, 338 | pre_norm=resnet_pre_norm, 339 | ) 340 | ) 341 | if dual_cross_attention: 342 | raise NotImplementedError 343 | attentions.append( 344 | Transformer3DModel( 345 | attn_num_head_channels, 346 | out_channels // attn_num_head_channels, 347 | in_channels=out_channels, 348 | num_layers=1, 349 | cross_attention_dim=cross_attention_dim, 350 | norm_num_groups=resnet_groups, 351 | use_linear_projection=use_linear_projection, 352 | only_cross_attention=only_cross_attention, 353 | upcast_attention=upcast_attention, 354 | 355 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 356 | unet_use_temporal_attention=unet_use_temporal_attention, 357 | ) 358 | ) 359 | motion_modules.append( 360 | get_motion_module( 361 | in_channels=out_channels, 362 | motion_module_type=motion_module_type, 363 | motion_module_kwargs=motion_module_kwargs, 364 | ) if use_motion_module else None 365 | ) 366 | 367 | self.attentions = nn.ModuleList(attentions) 368 | self.resnets = nn.ModuleList(resnets) 369 | self.motion_modules = nn.ModuleList(motion_modules) 370 | 371 | if add_downsample: 372 | self.downsamplers = nn.ModuleList( 373 | [ 374 | Downsample3D( 375 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 376 | ) 377 | ] 378 | ) 379 | else: 380 | self.downsamplers = None 381 | 382 | self.gradient_checkpointing = False 383 | 384 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 385 | output_states = () 386 | 387 | for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): 388 | if self.training and self.gradient_checkpointing: 389 | 390 | def create_custom_forward(module, return_dict=None): 391 | def custom_forward(*inputs): 392 | if return_dict is not None: 393 | return module(*inputs, return_dict=return_dict) 394 | else: 395 | return module(*inputs) 396 | 397 | return custom_forward 398 | 399 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 400 | hidden_states = torch.utils.checkpoint.checkpoint( 401 | create_custom_forward(attn, return_dict=False), 402 | hidden_states, 403 | encoder_hidden_states, 404 | )[0] 405 | if motion_module is not None: 406 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 407 | 408 | else: 409 | hidden_states = resnet(hidden_states, temb) 410 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 411 | 412 | # add motion module 413 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 414 | 415 | output_states += (hidden_states,) 416 | 417 | if self.downsamplers is not None: 418 | for downsampler in self.downsamplers: 419 | hidden_states = downsampler(hidden_states) 420 | 421 | output_states += (hidden_states,) 422 | 423 | return hidden_states, output_states 424 | 425 | 426 | class DownBlock3D(nn.Module): 427 | def __init__( 428 | self, 429 | in_channels: int, 430 | out_channels: int, 431 | temb_channels: int, 432 | dropout: float = 0.0, 433 | num_layers: int = 1, 434 | resnet_eps: float = 1e-6, 435 | resnet_time_scale_shift: str = "default", 436 | resnet_act_fn: str = "swish", 437 | resnet_groups: int = 32, 438 | resnet_pre_norm: bool = True, 439 | 
output_scale_factor=1.0, 440 | add_downsample=True, 441 | downsample_padding=1, 442 | 443 | use_motion_module=None, 444 | motion_module_type=None, 445 | motion_module_kwargs=None, 446 | ): 447 | super().__init__() 448 | resnets = [] 449 | motion_modules = [] 450 | 451 | for i in range(num_layers): 452 | in_channels = in_channels if i == 0 else out_channels 453 | resnets.append( 454 | ResnetBlock3D( 455 | in_channels=in_channels, 456 | out_channels=out_channels, 457 | temb_channels=temb_channels, 458 | eps=resnet_eps, 459 | groups=resnet_groups, 460 | dropout=dropout, 461 | time_embedding_norm=resnet_time_scale_shift, 462 | non_linearity=resnet_act_fn, 463 | output_scale_factor=output_scale_factor, 464 | pre_norm=resnet_pre_norm, 465 | ) 466 | ) 467 | motion_modules.append( 468 | get_motion_module( 469 | in_channels=out_channels, 470 | motion_module_type=motion_module_type, 471 | motion_module_kwargs=motion_module_kwargs, 472 | ) if use_motion_module else None 473 | ) 474 | 475 | self.resnets = nn.ModuleList(resnets) 476 | self.motion_modules = nn.ModuleList(motion_modules) 477 | 478 | if add_downsample: 479 | self.downsamplers = nn.ModuleList( 480 | [ 481 | Downsample3D( 482 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 483 | ) 484 | ] 485 | ) 486 | else: 487 | self.downsamplers = None 488 | 489 | self.gradient_checkpointing = False 490 | 491 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 492 | output_states = () 493 | 494 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 495 | if self.training and self.gradient_checkpointing: 496 | def create_custom_forward(module): 497 | def custom_forward(*inputs): 498 | return module(*inputs) 499 | 500 | return custom_forward 501 | 502 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 503 | if motion_module is not None: 504 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 505 | else: 506 | hidden_states = resnet(hidden_states, temb) 507 | 508 | # add motion module 509 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 510 | 511 | output_states += (hidden_states,) 512 | 513 | if self.downsamplers is not None: 514 | for downsampler in self.downsamplers: 515 | hidden_states = downsampler(hidden_states) 516 | 517 | output_states += (hidden_states,) 518 | 519 | return hidden_states, output_states 520 | 521 | 522 | class CrossAttnUpBlock3D(nn.Module): 523 | def __init__( 524 | self, 525 | in_channels: int, 526 | out_channels: int, 527 | prev_output_channel: int, 528 | temb_channels: int, 529 | dropout: float = 0.0, 530 | num_layers: int = 1, 531 | resnet_eps: float = 1e-6, 532 | resnet_time_scale_shift: str = "default", 533 | resnet_act_fn: str = "swish", 534 | resnet_groups: int = 32, 535 | resnet_pre_norm: bool = True, 536 | attn_num_head_channels=1, 537 | cross_attention_dim=1280, 538 | output_scale_factor=1.0, 539 | add_upsample=True, 540 | dual_cross_attention=False, 541 | use_linear_projection=False, 542 | only_cross_attention=False, 543 | upcast_attention=False, 544 | 545 | unet_use_cross_frame_attention=None, 546 | unet_use_temporal_attention=None, 547 | 548 | use_motion_module=None, 549 | 550 | motion_module_type=None, 551 | motion_module_kwargs=None, 552 | ): 553 | super().__init__() 554 | resnets = [] 
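        # resnets, attentions and the optional temporal motion_modules are kept as
        # three aligned lists, one entry per layer; forward() walks them in lockstep
        # with zip(), concatenating a skip connection from the down path onto
        # hidden_states before each resnet.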
555 | attentions = [] 556 | motion_modules = [] 557 | 558 | self.has_cross_attention = True 559 | self.attn_num_head_channels = attn_num_head_channels 560 | 561 | for i in range(num_layers): 562 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 563 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 564 | 565 | resnets.append( 566 | ResnetBlock3D( 567 | in_channels=resnet_in_channels + res_skip_channels, 568 | out_channels=out_channels, 569 | temb_channels=temb_channels, 570 | eps=resnet_eps, 571 | groups=resnet_groups, 572 | dropout=dropout, 573 | time_embedding_norm=resnet_time_scale_shift, 574 | non_linearity=resnet_act_fn, 575 | output_scale_factor=output_scale_factor, 576 | pre_norm=resnet_pre_norm, 577 | ) 578 | ) 579 | if dual_cross_attention: 580 | raise NotImplementedError 581 | attentions.append( 582 | Transformer3DModel( 583 | attn_num_head_channels, 584 | out_channels // attn_num_head_channels, 585 | in_channels=out_channels, 586 | num_layers=1, 587 | cross_attention_dim=cross_attention_dim, 588 | norm_num_groups=resnet_groups, 589 | use_linear_projection=use_linear_projection, 590 | only_cross_attention=only_cross_attention, 591 | upcast_attention=upcast_attention, 592 | 593 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 594 | unet_use_temporal_attention=unet_use_temporal_attention, 595 | ) 596 | ) 597 | motion_modules.append( 598 | get_motion_module( 599 | in_channels=out_channels, 600 | motion_module_type=motion_module_type, 601 | motion_module_kwargs=motion_module_kwargs, 602 | ) if use_motion_module else None 603 | ) 604 | 605 | self.attentions = nn.ModuleList(attentions) 606 | self.resnets = nn.ModuleList(resnets) 607 | self.motion_modules = nn.ModuleList(motion_modules) 608 | 609 | if add_upsample: 610 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 611 | else: 612 | self.upsamplers = None 613 | 614 | self.gradient_checkpointing = False 615 | 616 | def forward( 617 | self, 618 | hidden_states, 619 | res_hidden_states_tuple, 620 | temb=None, 621 | encoder_hidden_states=None, 622 | upsample_size=None, 623 | attention_mask=None, 624 | ): 625 | for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): 626 | # pop res hidden states 627 | res_hidden_states = res_hidden_states_tuple[-1] 628 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 629 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 630 | 631 | if self.training and self.gradient_checkpointing: 632 | 633 | def create_custom_forward(module, return_dict=None): 634 | def custom_forward(*inputs): 635 | if return_dict is not None: 636 | return module(*inputs, return_dict=return_dict) 637 | else: 638 | return module(*inputs) 639 | 640 | return custom_forward 641 | 642 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 643 | hidden_states = torch.utils.checkpoint.checkpoint( 644 | create_custom_forward(attn, return_dict=False), 645 | hidden_states, 646 | encoder_hidden_states, 647 | )[0] 648 | if motion_module is not None: 649 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 650 | 651 | else: 652 | hidden_states = resnet(hidden_states, temb) 653 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 654 | 655 | # add motion module 656 | hidden_states = 
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 657 | 658 | if self.upsamplers is not None: 659 | for upsampler in self.upsamplers: 660 | hidden_states = upsampler(hidden_states, upsample_size) 661 | 662 | return hidden_states 663 | 664 | 665 | class UpBlock3D(nn.Module): 666 | def __init__( 667 | self, 668 | in_channels: int, 669 | prev_output_channel: int, 670 | out_channels: int, 671 | temb_channels: int, 672 | dropout: float = 0.0, 673 | num_layers: int = 1, 674 | resnet_eps: float = 1e-6, 675 | resnet_time_scale_shift: str = "default", 676 | resnet_act_fn: str = "swish", 677 | resnet_groups: int = 32, 678 | resnet_pre_norm: bool = True, 679 | output_scale_factor=1.0, 680 | add_upsample=True, 681 | 682 | use_motion_module=None, 683 | motion_module_type=None, 684 | motion_module_kwargs=None, 685 | ): 686 | super().__init__() 687 | resnets = [] 688 | motion_modules = [] 689 | 690 | for i in range(num_layers): 691 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 692 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 693 | 694 | resnets.append( 695 | ResnetBlock3D( 696 | in_channels=resnet_in_channels + res_skip_channels, 697 | out_channels=out_channels, 698 | temb_channels=temb_channels, 699 | eps=resnet_eps, 700 | groups=resnet_groups, 701 | dropout=dropout, 702 | time_embedding_norm=resnet_time_scale_shift, 703 | non_linearity=resnet_act_fn, 704 | output_scale_factor=output_scale_factor, 705 | pre_norm=resnet_pre_norm, 706 | ) 707 | ) 708 | motion_modules.append( 709 | get_motion_module( 710 | in_channels=out_channels, 711 | motion_module_type=motion_module_type, 712 | motion_module_kwargs=motion_module_kwargs, 713 | ) if use_motion_module else None 714 | ) 715 | 716 | self.resnets = nn.ModuleList(resnets) 717 | self.motion_modules = nn.ModuleList(motion_modules) 718 | 719 | if add_upsample: 720 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 721 | else: 722 | self.upsamplers = None 723 | 724 | self.gradient_checkpointing = False 725 | 726 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): 727 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 728 | # pop res hidden states 729 | res_hidden_states = res_hidden_states_tuple[-1] 730 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 731 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 732 | 733 | if self.training and self.gradient_checkpointing: 734 | def create_custom_forward(module): 735 | def custom_forward(*inputs): 736 | return module(*inputs) 737 | 738 | return custom_forward 739 | 740 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 741 | if motion_module is not None: 742 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 743 | else: 744 | hidden_states = resnet(hidden_states, temb) 745 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 746 | 747 | if self.upsamplers is not None: 748 | for upsampler in self.upsamplers: 749 | hidden_states = upsampler(hidden_states, upsample_size) 750 | 751 | return hidden_states --------------------------------------------------------------------------------
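As a quick sanity check of the block factories above, the following standalone snippet builds a plain DownBlock3D through get_down_block and pushes a dummy video latent through it. This is a minimal sketch, not code from the repository: it assumes the magicanimate package and its dependencies (PyTorch, diffusers, einops) are importable, and the channel counts, frame count and tensor shapes are illustrative choices rather than values taken from the project configs.

import torch

from magicanimate.models.unet_3d_blocks import get_down_block

# Build a plain DownBlock3D: no cross attention and no temporal motion module,
# so only the ResnetBlock3D / Downsample3D path defined above is exercised.
block = get_down_block(
    "DownBlock3D",
    num_layers=2,
    in_channels=32,                 # illustrative sizes, not config values
    out_channels=32,
    temb_channels=128,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    attn_num_head_channels=8,       # required positionally, unused by DownBlock3D
    resnet_groups=32,
    downsample_padding=1,
    use_motion_module=False,        # leaves the motion-module slots as None
)

# Video latents use a (batch, channels, frames, height, width) layout.
sample = torch.randn(1, 32, 8, 16, 16)
temb = torch.randn(1, 128)          # one time-embedding vector per batch element

with torch.no_grad():
    hidden_states, output_states = block(sample, temb=temb)

print(hidden_states.shape)          # H and W halved by the Downsample3D layer
print(len(output_states))           # one residual per resnet plus the downsampled output

Swapping "DownBlock3D" for "CrossAttnDownBlock3D" additionally requires a cross_attention_dim argument and an encoder_hidden_states tensor in the forward call, matching the checks in get_down_block above.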