├── .DS_Store ├── .gitignore ├── LICENSE.txt ├── README.md ├── animatediff ├── data │ ├── dataset_train_realestate10k.py │ └── dataset_validation.py ├── models │ ├── attention.py │ ├── attention_processor.py │ ├── epi_module.py │ ├── motion_module.py │ ├── pose_adaptor.py │ ├── resnet.py │ ├── sparse_controlnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ ├── pipeline_animation_epi.py │ └── pipeline_animation_epi_advanced.py └── utils │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── assets ├── .DS_Store ├── 2c80f9eb0d3b2bb4.txt ├── 2f25826f0d0ef09a.txt ├── cameractrl_prompts.json ├── cameractrl_prompts_for_circle.json ├── cameractrl_prompts_for_interpolate.json └── pose_files │ ├── 0bf152ef84195293.txt │ ├── 0c11dbe781b1c11c.txt │ ├── 0c9b371cc6225682.txt │ ├── 0f47577ab3441480.txt │ ├── 0f68374b76390082.txt │ ├── 2c80f9eb0d3b2bb4.txt │ ├── 2f25826f0d0ef09a.txt │ ├── 3c35b868a8ec3433.txt │ ├── 3f79dc32d575bcdc.txt │ └── 4a2d6753676df096.txt ├── configs ├── inference_config.yaml └── validation_prompts.txt ├── dist_run.sh ├── docs ├── badge-website.svg └── teaser.png ├── environment.yaml ├── inference_epi.py ├── inference_epi_advanced.py ├── requirements.txt ├── run_inference_advanced.sh ├── run_inference_simple.sh ├── tools ├── merge_lora2unet.py └── visualize_trajectory.py └── train_epi_control.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CollaborativeVideoDiffusion/CVD/107f299bd75c7a37158c52427de473cec86c649a/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | debugs/ 3 | output/ 4 | samples/ 5 | __pycache__/ 6 | ossutil_output/ 7 | .ossutil_checkpoint/ 8 | 9 | test_epi/ 10 | test_homography/ 11 | test_pose/ 12 | test_bullet_time/ 13 | results/ 14 | scripts/ 15 | !scripts/animate.py 16 | results*/ 17 | 18 | assets/reference_videos/ 19 | *.ipynb 20 | *.safetensors 21 | *.ckpt 22 | .idea 23 | # *.json 24 | *.csv 25 | # *.txt 26 | temp_attention_mask_for_debug.pt 27 | *.npy 28 | 29 | rendering_visualizing_tools/ 30 | batch_vis_scripts/ 31 | probing_tools/ 32 | models/* 33 | !models/StableDiffusion/ 34 | models/StableDiffusion/* 35 | !models/StableDiffusion/*.txt 36 | !models/Motion_Module/ 37 | !models/Motion_Module/*.txt 38 | !models/DreamBooth_LoRA/ 39 | !models/DreamBooth_LoRA/*.txt 40 | !models/MotionLoRA/ 41 | !models/MotionLoRA/*.txt 42 | fid_workspace/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Collaborative Video Diffusion: Consistent Multi-video Generation with Camera Control 2 | 3 | **NeurIPS 2024** 4 | 5 | This repository represents the official implementation of the paper titled "Collaborative Video Diffusion: Consistent Multi-video Generation with Camera Control". 6 | 7 | *This repository is still under construction; many updates will be applied in the near future.* 8 | 9 | [![Website](docs/badge-website.svg)](https://collaborativevideodiffusion.github.io/) 10 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2405.17414) 11 | 12 | [Zhengfei Kuang*](https://zhengfeikuang.com/), 13 | [Shengqu Cai*](https://primecai.github.io/), 14 | [Hao He](https://hehao13.github.io/), 15 | [Yinghao Xu](https://justimyhxu.github.io/), 16 | [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/), 17 | [Leonidas Guibas](https://www.cs.stanford.edu/people/leonidas-guibas), 18 | [Gordon Wetzstein](https://stanford.edu/~gordonwz/) 19 | 20 | Research on video generation has recently made tremendous progress, enabling high-quality videos to be generated from text prompts or images. Adding control to the video generation process is an important goal moving forward, and recent approaches that condition video generation models on camera trajectories make strides towards it. Yet, it remains challenging to generate a video of the same scene from multiple different camera trajectories. Solutions to this multi-video generation problem could enable large-scale 3D scene generation with editable camera trajectories, among other applications. We introduce collaborative video diffusion (CVD) as an important step towards this vision. The CVD framework includes a novel cross-video synchronization module that promotes consistency between corresponding frames of the same video rendered from different camera poses using an epipolar attention mechanism. Trained on top of a state-of-the-art camera-control module for video generation, CVD generates multiple videos rendered from different camera trajectories with significantly better consistency than baselines, as shown in extensive experiments. 21 | 22 | ![teaser](docs/teaser.png) 23 | 24 | 25 | ## 🛠️ Setup 26 | 27 | ### 📦 Repository 28 | 29 | Clone the repository (requires git): 30 | 31 | ```bash 32 | git clone https://github.com/CollaborativeVideoDiffusion/CVD 33 | cd CVD 34 | ``` 35 | 36 | ### 💻 Dependencies 37 | To set up the environment, run: 38 | 39 | ``` 40 | conda env create -f environment.yaml 41 | 42 | conda activate CVD 43 | 44 | pip install torch==2.2+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 45 | 46 | pip install -r requirements.txt 47 | ``` 48 | CVD is built on top of AnimateDiff and CameraCtrl, so their weights need to be prepared first: 49 | - Download Stable Diffusion V1.5 (SD1.5) from [HuggingFace](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main).
50 | - Download the checkpoints of the AnimateDiffV3 (ADV3) adaptor and motion module from [AnimateDiff](https://github.com/guoyww/AnimateDiff). 51 | - Run `tools/merge_lora2unet.py` to merge the ADV3 adaptor weights into the SD1.5 UNet and save the result to a new subfolder (e.g., `unet_webvidlora_v3`) under the SD1.5 folder. 52 | - Download CameraCtrl's camera control model from [Google Drive](https://drive.google.com/file/d/1mlNaX8ipJylTHq2MHV2_mOQEegKr1YXc/view?usp=share_link). 53 | - Download our synchronization module from [Google Drive](https://drive.google.com/file/d/1z6cR3PbqnrlVjXJJlk6AYxdl8z18hvtL/view?usp=sharing). 54 | 55 | By default, all of the models should be downloaded to `./models` under the root directory. 56 | 57 | 60 | 61 | ## 🏃 Inference 62 | We provide two scripts for sampling random but mutually consistent videos: the simplest two-video generation, and the advanced generation of multiple videos and complex trajectories. 63 | ### 🎞️ Simple two-video generation 64 | 65 | To sample the simplest setup of CVD, i.e., two videos representing the same underlying scene but captured from different camera poses, run the following: 66 | ```bash 67 | bash run_inference_simple.sh 68 | ``` 69 | You might need to modify the model paths to point to your download location. 70 | 71 | #### ⚙️ Inference settings 72 | 73 | The camera trajectories of the two videos are specified by two pose files. You will need to define: 74 | - `--pose_file_0`: Camera trajectory for the first video. 75 | - `--pose_file_1`: Camera trajectory for the second video. 76 | 77 | To specify the prompts, modify `assets/cameractrl_prompts.json`. 78 | 79 | To run CVD on a LoRA model, simply specify either: 80 | - `--civitai_base_model` for a LoRA-tuned base model, or 81 | - `--civitai_lora_ckpt` for loading LoRA checkpoints. 82 | 83 | To get the best results, play around with `--guidance_scale`. Depending on the desired content, we find that a range of 8-30 typically provides decent results. 84 | 85 | Feel free to tune the parameters (such as the guidance scale and LoRA weights) for varied results. 86 | 87 | ### 🎞️ Advanced generation 88 | In addition to paired video generation, we also provide scripts for generating more videos. Here we provide three camera trajectory patterns: 89 | - 'circle': Each camera trajectory starts from the same position and spans out to a different location on a circle perpendicular to the look-at direction. 90 | - 'upper_hemi': Similar to the 'circle' mode, but only covers the upper half of the circle (no camera trajectory moves downwards). 91 | - 'interpolate': Each camera trajectory starts from the same position and moves to a target interpolated from two given positions. 92 | To generate arbitrary views under these patterns, run: 93 | ```bash 94 | bash run_inference_advanced.sh 95 | ``` 96 | 97 | #### ⚙️ Inference settings 98 | Different from the simple mode, the camera poses here are procedurally generated in the script, so no pose files are required. Instead, 99 | - `--cam_pattern` determines how the camera trajectories are generated. 100 | 101 | Some other important parameters: 102 | - `--view_num`: The number of views to be generated. 103 | - `--multistep`: Number of recurrent steps within each denoising step. Set to 3 for 4-view generation and 6 for 6-view generation by default. 104 | - `--accumulate_step`: Number of pairs assigned to each video. Set to 1 for 4-view generation and 2 for 6-view generation by default.
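As a concrete reference for the parameters above, the sketch below shows how an advanced run might be launched directly. Only the flags documented in this README are used; the actual entry point defaults, model/checkpoint paths, and any further required arguments are defined in `run_inference_advanced.sh`, so treat this as an illustrative sketch rather than a verified command line.

```bash
# Hypothetical 6-view 'circle' run. Flag names are taken from the list above;
# model paths and any additional required arguments should be copied from
# run_inference_advanced.sh.
python inference_epi_advanced.py \
    --cam_pattern circle \
    --view_num 6 \
    --multistep 6 \
    --accumulate_step 2 \
    --guidance_scale 14
```

Following the defaults quoted above, a 4-view run would instead use `--multistep 3` and `--accumulate_step 1`.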
105 | 106 | 109 | 110 | ## 🎓 Citation 111 | 112 | Please cite our paper: 113 | 114 | ```bibtex 115 | @inproceedings{kuang2024cvd, 116 | author={Kuang, Zhengfei and Cai, Shengqu and He, Hao and Xu, Yinghao and Li, Hongsheng and Guibas, Leonidas and Wetzstein, Gordon.}, 117 | title={Collaborative Video Diffusion: Consistent Multi-video Generation with Camera Control}, 118 | booktitle={arXiv}, 119 | year={2024} 120 | } 121 | ``` -------------------------------------------------------------------------------- /animatediff/data/dataset_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import torch 5 | 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | import torchvision.transforms.functional as F 9 | import numpy as np 10 | 11 | from decord import VideoReader 12 | from torch.utils.data.dataset import Dataset 13 | from packaging import version as pver 14 | from scipy.spatial.transform import Rotation as R 15 | from scipy.spatial.transform import Slerp 16 | 17 | import glob 18 | import imageio 19 | import cv2 20 | # from animatediff.data.dinov2matcher import Dinov2Matcher 21 | 22 | # a -> [ax] 23 | # [0, -a3, a2] 24 | # [a3, 0, -a1] 25 | # [-a2, a1, 0] 26 | # input: a vector (shape: 3) or vector array (shape: N x 3) 27 | # output: a matrix (shape: 3 x 3) of array of matrix (shape: N x 3 x 3) 28 | def calc_cross_product_matrix(vec): 29 | is_array = False if len(vec.shape) == 1 else True 30 | if not is_array: 31 | vec = vec[np.newaxis, :] # 1 x 3 32 | ret_mat = np.zeros(list(vec.shape)+[3]) 33 | ret_mat[:, 0, 1] = -vec[:, 2] 34 | ret_mat[:, 0, 2] = vec[:, 1] 35 | ret_mat[:, 1, 2] = -vec[:, 0] 36 | ret_mat -= ret_mat.transpose((0, 2, 1)) 37 | if not is_array: 38 | ret_mat = ret_mat[0] 39 | return ret_mat 40 | 41 | # T_mat shape is 4 x 4 42 | # x2 = T_mat * x1 = R_mat*x1 + t = R_mat * ( x1 - (-R_mat^T * t) ) 43 | # let R_ess = R_mat, t_ess = -R_mat^T*t 44 | # then E_mat = R_ess*[t_ess x] 45 | def calc_essential_matrix(T_mat): 46 | R_mat = T_mat[:3, :3] 47 | t = T_mat[:3, 3] # t 48 | t_ess = -np.matmul(R_mat.transpose(), t) 49 | E_mat = np.matmul(R_mat, calc_cross_product_matrix(t_ess)) 50 | return E_mat 51 | 52 | # T_mat: from camera 1 to camera 2 53 | # x2 = T_mat * x1 54 | # because in essential matrix we have x2^t E x1 = 0, 55 | # and x_{1,2} = K_{1,2}^-1 * coord_{1,2}, 56 | # we can get F = K2^-T * E * K1^-1 57 | def calc_fundamental_matrix(T_mat, K_mat1, K_mat2): 58 | E_mat = calc_essential_matrix(T_mat) 59 | 60 | K2_invT = np.linalg.inv(K_mat2).transpose() 61 | K1_inv = np.linalg.inv(K_mat1) 62 | F_mat = np.matmul(np.matmul(K2_invT, E_mat), K1_inv) 63 | 64 | return F_mat 65 | 66 | # Assume cx=H/2, cy=W/2 67 | def K_mat_from_fov(fov_deg, H, W): 68 | fx = (W/2) / math.tan(fov_deg/2) 69 | fy = (H/2) / math.tan(fov_deg/2) 70 | K_mat = np.array( 71 | [ 72 | [fx, 0, W/2], 73 | [0, fy, H/2], 74 | [0, 0, 1] 75 | ] 76 | ) 77 | return K_mat 78 | 79 | class Camera(object): 80 | def __init__(self, entry): 81 | self.cid = entry[0] 82 | fx, fy, cx, cy = entry[1:5] 83 | self.fx = fx 84 | self.fy = fy 85 | self.cx = cx 86 | self.cy = cy 87 | w2c_mat = np.array(entry[7:]).reshape(3, 4) 88 | w2c_mat_4x4 = np.eye(4) 89 | w2c_mat_4x4[:3, :] = w2c_mat 90 | self.w2c_mat = w2c_mat_4x4 91 | self.c2w_mat = np.linalg.inv(w2c_mat_4x4) 92 | 93 | 94 | def custom_meshgrid(*args): 95 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 96 | if 
pver.parse(torch.__version__) < pver.parse('1.10'): 97 | return torch.meshgrid(*args) 98 | else: 99 | return torch.meshgrid(*args, indexing='ij') 100 | 101 | 102 | def ray_condition(K, c2w, H, W, device, flip_flag=None): 103 | # c2w: B, V, 4, 4 104 | # K: B, V, 4 105 | 106 | B, V = K.shape[:2] 107 | 108 | j, i = custom_meshgrid( 109 | torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), 110 | torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), 111 | ) 112 | i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW] 113 | j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW] 114 | 115 | n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0 116 | if n_flip > 0: 117 | j_flip, i_flip = custom_meshgrid( 118 | torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), 119 | torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype) 120 | ) 121 | i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 122 | j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 123 | i[:, flip_flag, ...] = i_flip 124 | j[:, flip_flag, ...] = j_flip 125 | 126 | fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 127 | 128 | zs = torch.ones_like(i) # [B, V, HxW] 129 | xs = (i - cx) / fx * zs 130 | ys = (j - cy) / fy * zs 131 | zs = zs.expand_as(ys) 132 | 133 | directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 134 | directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 135 | 136 | rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3 137 | rays_o = c2w[..., :3, 3] # B, V, 3 138 | rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3 139 | # c2w @ dirctions 140 | rays_dxo = torch.cross(rays_o, rays_d) # B, V, HW, 3 141 | plucker = torch.cat([rays_dxo, rays_d], dim=-1) 142 | plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 143 | # plucker = plucker.permute(0, 1, 4, 2, 3) 144 | return plucker 145 | 146 | class ValRealEstate10KPoseFolded(Dataset): 147 | def __init__( 148 | self, 149 | sample_n_frames=16, 150 | relative_pose=False, 151 | sample_size=256, 152 | validation_prompts=None, 153 | validation_negative_prompts=None, 154 | mode="train", 155 | pose_file_0=None, 156 | pose_file_1=None 157 | ): 158 | self.relative_pose = relative_pose 159 | self.sample_n_frames = sample_n_frames 160 | self.validation_prompts = validation_prompts 161 | self.validation_negative_prompts = validation_negative_prompts 162 | self.mode = mode 163 | self.pose_file_0 = pose_file_0 164 | self.pose_file_1 = pose_file_1 165 | 166 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 167 | self.sample_size = sample_size 168 | 169 | pixel_transforms = transforms.Compose([transforms.Resize(sample_size[0]), 170 | transforms.CenterCrop(self.sample_size), 171 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]) 172 | 173 | self.pixel_transforms = pixel_transforms 174 | 175 | def get_relative_pose(self, c2w_list, tar_idx=0): 176 | abs2rel = np.linalg.inv(c2w_list[tar_idx]) 177 | ret_poses = [abs2rel @ c2w for c2w in c2w_list] 178 | return np.array(ret_poses, dtype=np.float32) 179 | 180 | def load_cameras_specific(self): 181 | # To extract two camera poses from the same start, here we first load two trajectory files, 182 | # then normalize the poses to the same start (identical matrix), and finally calculate the relative poses 183 | 184 | pose_file_0 = os.path.join(self.pose_file_0) 185 | with open(pose_file_0, 'r') as f: 186 
| poses_0 = f.readlines() 187 | pose_file_1 = os.path.join(self.pose_file_1) 188 | with open(pose_file_1, 'r') as f: 189 | poses_1 = f.readlines() 190 | poses_0 = [pose.strip().split(' ') for pose in poses_0[1:]] 191 | cam_params_0 = [[float(x) for x in pose] for pose in poses_0] 192 | cam_params_0 = [Camera(cam_param) for cam_param in cam_params_0] 193 | poses_1 = [pose.strip().split(' ') for pose in poses_1[1:]] 194 | cam_params_1 = [[float(x) for x in pose] for pose in poses_1] 195 | cam_params_1 = [Camera(cam_param) for cam_param in cam_params_1] 196 | cam_params_1.reverse() 197 | 198 | c2w_pose_list_0 = [] 199 | K_mat_list_0 = [] 200 | intrinsic_list_0 = [] 201 | for frame_idx in range(len(cam_params_0)): 202 | cam = cam_params_0[frame_idx] 203 | H, W = 1280, 720 204 | 205 | crop_size = min(H, W) 206 | rescale = self.sample_size[0] / crop_size 207 | dH, dW = (H-crop_size)/2, (W-crop_size)/2 208 | K_mat = np.array([[W*rescale*cam.fx, 0, (W*cam.cx-dW)*rescale], [0, H*rescale*cam.fy, (H*cam.cy-dH)*rescale], [0, 0, 1]]) 209 | intrinsics = [K_mat[0, 0], K_mat[1, 1], K_mat[0, 2], K_mat[1, 2]] 210 | # While the statement in realestate10K states that the extrinsics are w2c, 211 | # Seems they are c2w instead 212 | c2w_pose_list_0.append(cam.c2w_mat) 213 | K_mat_list_0.append(K_mat) 214 | intrinsic_list_0.append(intrinsics) 215 | 216 | c2w_pose_list_1 = [] 217 | K_mat_list_1 = [] 218 | intrinsic_list_1 = [] 219 | for frame_idx in range(len(cam_params_1)): 220 | cam = cam_params_1[frame_idx] 221 | H, W = 1280, 720 222 | 223 | crop_size = min(H, W) 224 | rescale = self.sample_size[0] / crop_size 225 | dH, dW = (H-crop_size)/2, (W-crop_size)/2 226 | K_mat = np.array([[W*rescale*cam.fx, 0, (W*cam.cx-dW)*rescale], [0, H*rescale*cam.fy, (H*cam.cy-dH)*rescale], [0, 0, 1]]) 227 | intrinsics = [K_mat[0, 0], K_mat[1, 1], K_mat[0, 2], K_mat[1, 2]] 228 | # While the statement in realestate10K states that the extrinsics are w2c, 229 | # Seems they are c2w instead 230 | c2w_pose_list_1.append(cam.c2w_mat) 231 | K_mat_list_1.append(K_mat) 232 | intrinsic_list_1.append(intrinsics) 233 | 234 | c2w_pose_list_0 = self.get_relative_pose(c2w_pose_list_0, tar_idx = 0) 235 | c2w_pose_list_1 = self.get_relative_pose(c2w_pose_list_1, tar_idx = 0) 236 | c2w_pose_list = np.concatenate([c2w_pose_list_0[1:][::-1], c2w_pose_list_1], axis=0) 237 | # force k mat to be the same 238 | K_mat_list = np.concatenate([K_mat_list_0[1:][::-1], K_mat_list_0], axis=0) 239 | intrinsic_list = np.concatenate([intrinsic_list_0[1:][::-1], intrinsic_list_1], axis=0) 240 | return c2w_pose_list, K_mat_list, intrinsic_list 241 | 242 | def get_batch(self, validation_idx): 243 | 244 | validation_prompt = self.validation_prompts[validation_idx] 245 | 246 | if self.validation_negative_prompts is not None: 247 | validation_negative_prompt = self.validation_negative_prompts[validation_idx] 248 | else: 249 | validation_negative_prompt = None 250 | 251 | c2w_pose_list, K_mat_list, intrinsic_list = self.load_cameras_specific() 252 | 253 | intrinsics = torch.as_tensor(np.array(intrinsic_list)).float()[None] # [1, n_frame, 4] 254 | c2w = torch.as_tensor(c2w_pose_list)[None] # [1, n_frame, 4, 4] 255 | 256 | plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu' 257 | )[0].permute(0, 3, 1, 2) # n_frame, channel, H, W 258 | 259 | # Folding camera poses 260 | F_mat_list = [] 261 | for i in range(self.sample_n_frames): 262 | sid = self.sample_n_frames - 1 - i 263 | tid = self.sample_n_frames - 1 + i 264 | s2t 
= np.linalg.inv(c2w_pose_list[tid]) @ c2w_pose_list[sid] 265 | F_mat = calc_fundamental_matrix(s2t, K_mat_list[sid], K_mat_list[tid]) 266 | F_mat_list.append(torch.from_numpy(F_mat)) 267 | 268 | F_mats = torch.as_tensor(np.array(F_mat_list)).float() # [n_frame, 3, 3] 269 | 270 | # Fold all vectors 271 | F_mats = torch.cat([F_mats, F_mats.permute(0, 2, 1)], dim=0).contiguous() 272 | fold_indices = torch.arange(self.sample_n_frames) 273 | fold_indices = torch.cat([self.sample_n_frames - 1 - fold_indices, 274 | self.sample_n_frames - 1 + fold_indices]) 275 | 276 | plucker_embedding = plucker_embedding[fold_indices].contiguous() 277 | 278 | ret_c2w = c2w[:, fold_indices] 279 | ret_K_mats = np.stack(K_mat_list, axis=0)[fold_indices] 280 | 281 | return plucker_embedding, F_mats, validation_prompt, validation_negative_prompt, ret_c2w, ret_K_mats 282 | 283 | def __len__(self): 284 | return len(self.validation_prompts) 285 | 286 | def __getitem__(self, idx): 287 | plucker_embedding, F_mats, validation_prompt, validation_negative_prompt, ret_c2w, ret_K_mats = self.get_batch(idx) 288 | 289 | ret_sample = { 290 | "validation_prompt": validation_prompt, 291 | "plucker_embedding": plucker_embedding, 292 | "F_mats": F_mats, 293 | "ret_c2w": ret_c2w, 294 | "ret_K_mats": ret_K_mats 295 | } 296 | if validation_negative_prompt is not None: 297 | ret_sample["validation_negative_prompt"] = validation_negative_prompt 298 | 299 | return ret_sample 300 | -------------------------------------------------------------------------------- /animatediff/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from diffusers.configuration_utils import ConfigMixin, register_to_config 10 | from diffusers.models.modeling_utils import ModelMixin 11 | from diffusers.utils import BaseOutput 12 | from diffusers.models.attention import BasicTransformerBlock 13 | from einops import rearrange, repeat 14 | 15 | 16 | @dataclass 17 | class Transformer3DModelOutput(BaseOutput): 18 | sample: torch.FloatTensor 19 | 20 | 21 | class Transformer3DModel(ModelMixin, ConfigMixin): 22 | @register_to_config 23 | def __init__( 24 | self, 25 | num_attention_heads: int = 16, 26 | attention_head_dim: int = 88, 27 | in_channels: Optional[int] = None, 28 | num_layers: int = 1, 29 | dropout: float = 0.0, 30 | norm_num_groups: int = 32, 31 | cross_attention_dim: Optional[int] = None, 32 | attention_bias: bool = False, 33 | activation_fn: str = "geglu", 34 | num_embeds_ada_norm: Optional[int] = None, 35 | use_linear_projection: bool = False, 36 | only_cross_attention: bool = False, 37 | upcast_attention: bool = False, 38 | norm_type: str = "layer_norm", 39 | norm_elementwise_affine: bool = True, 40 | ): 41 | super().__init__() 42 | self.use_linear_projection = use_linear_projection 43 | self.num_attention_heads = num_attention_heads 44 | self.attention_head_dim = attention_head_dim 45 | inner_dim = num_attention_heads * attention_head_dim 46 | 47 | # Define input layers 48 | self.in_channels = in_channels 49 | 50 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 51 | if use_linear_projection: 52 | self.proj_in = nn.Linear(in_channels, inner_dim) 53 | else: 54 | self.proj_in = nn.Conv2d(in_channels, inner_dim, 
kernel_size=1, stride=1, padding=0) 55 | 56 | # Define transformers blocks 57 | self.transformer_blocks = nn.ModuleList( 58 | [ 59 | BasicTransformerBlock( 60 | inner_dim, 61 | num_attention_heads, 62 | attention_head_dim, 63 | dropout=dropout, 64 | cross_attention_dim=cross_attention_dim, 65 | activation_fn=activation_fn, 66 | num_embeds_ada_norm=num_embeds_ada_norm, 67 | attention_bias=attention_bias, 68 | only_cross_attention=only_cross_attention, 69 | upcast_attention=upcast_attention, 70 | norm_type=norm_type, 71 | norm_elementwise_affine=norm_elementwise_affine, 72 | ) 73 | for d in range(num_layers) 74 | ] 75 | ) 76 | 77 | # 4. Define output layers 78 | if use_linear_projection: 79 | self.proj_out = nn.Linear(in_channels, inner_dim) 80 | else: 81 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 82 | 83 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 84 | # Input 85 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 86 | batch_size, _, video_length = hidden_states.shape[:3] 87 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 88 | 89 | if encoder_hidden_states.shape[0] == batch_size: 90 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 91 | 92 | elif encoder_hidden_states.shape[0] == batch_size * video_length: 93 | pass 94 | else: 95 | raise ValueError 96 | 97 | batch, channel, height, weight = hidden_states.shape 98 | residual = hidden_states 99 | 100 | hidden_states = self.norm(hidden_states) 101 | if not self.use_linear_projection: 102 | hidden_states = self.proj_in(hidden_states) 103 | inner_dim = hidden_states.shape[1] 104 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 105 | else: 106 | inner_dim = hidden_states.shape[1] 107 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 108 | hidden_states = self.proj_in(hidden_states) 109 | 110 | # Blocks 111 | for block in self.transformer_blocks: 112 | hidden_states = block( 113 | hidden_states, 114 | encoder_hidden_states=encoder_hidden_states, 115 | timestep=timestep, 116 | ) 117 | 118 | # Output 119 | if not self.use_linear_projection: 120 | hidden_states = ( 121 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 122 | ) 123 | hidden_states = self.proj_out(hidden_states) 124 | else: 125 | hidden_states = self.proj_out(hidden_states) 126 | hidden_states = ( 127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | ) 129 | 130 | output = hidden_states + residual 131 | 132 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 133 | if not return_dict: 134 | return (output,) 135 | 136 | return Transformer3DModelOutput(sample=output) 137 | -------------------------------------------------------------------------------- /animatediff/models/epi_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from diffusers.utils import BaseOutput 9 | from diffusers.models.attention_processor import Attention 10 | from diffusers.models.attention import FeedForward 11 | 12 | from animatediff.models.resnet import InflatedGroupNorm 13 | from 
typing import Dict, Any 14 | from animatediff.models.attention_processor import PoseAdaptorAttnProcessor, EpiAttnProcessor 15 | 16 | from einops import rearrange 17 | import math 18 | 19 | 20 | def zero_module(module): 21 | # Zero out the parameters of a module and return it. 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | 26 | 27 | @dataclass 28 | class TemporalTransformer3DModelOutput(BaseOutput): 29 | sample: torch.FloatTensor 30 | 31 | 32 | def get_epi_module( 33 | in_channels, 34 | epi_module_kwargs: dict 35 | ): 36 | return EpiModule(in_channels=in_channels, **epi_module_kwargs) 37 | 38 | class EpiModule(nn.Module): 39 | def __init__( 40 | self, 41 | in_channels, 42 | num_attention_heads=8, 43 | num_transformer_block=2, 44 | attention_block_types=("Epi_Self",), 45 | epi_position_encoding = True, 46 | epi_position_encoding_feat_max_size= 64, 47 | epi_position_encoding_F_mat_size = 256, 48 | epi_no_attention_mask = False, 49 | epi_mono_direction = False, 50 | epi_fix_firstframe = False, 51 | epi_rand_slope_ff = False, 52 | cross_attention_dim=320, 53 | zero_initialize=True, 54 | encoder_hidden_states_query=(False, False), 55 | attention_activation_scale=1.0, 56 | attention_processor_kwargs: Dict = {}, 57 | rescale_output_factor=1.0 58 | ): 59 | super().__init__() 60 | 61 | self.epi_transformer = EpiTransformer3DModel( 62 | in_channels=in_channels, 63 | num_attention_heads=num_attention_heads, 64 | attention_head_dim=in_channels // num_attention_heads, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_attention_dim=cross_attention_dim, 68 | epi_position_encoding=epi_position_encoding, 69 | epi_position_encoding_feat_max_size=epi_position_encoding_feat_max_size, 70 | epi_position_encoding_F_mat_size=epi_position_encoding_F_mat_size, 71 | epi_no_attention_mask = epi_no_attention_mask, 72 | epi_mono_direction=epi_mono_direction, 73 | epi_fix_firstframe=epi_fix_firstframe, 74 | epi_rand_slope_ff=epi_rand_slope_ff, 75 | encoder_hidden_states_query=encoder_hidden_states_query, 76 | attention_activation_scale=attention_activation_scale, 77 | attention_processor_kwargs=attention_processor_kwargs, 78 | rescale_output_factor=rescale_output_factor 79 | ) 80 | 81 | if zero_initialize: 82 | self.epi_transformer.proj_out = zero_module(self.epi_transformer.proj_out) 83 | 84 | def forward(self, hidden_states, F_mats=None, H_mats=None, temb=None, encoder_hidden_states=None, attention_mask=None, 85 | cross_attention_kwargs: Dict[str, Any] = {}): 86 | hidden_states, aux = self.epi_transformer(hidden_states, F_mats, H_mats, encoder_hidden_states, attention_mask, cross_attention_kwargs=cross_attention_kwargs) 87 | 88 | output = hidden_states 89 | return output, aux 90 | 91 | 92 | class EpiTransformer3DModel(nn.Module): 93 | def __init__( 94 | self, 95 | in_channels, 96 | num_attention_heads, 97 | attention_head_dim, 98 | num_layers, 99 | attention_block_types=("Epi_Self",), 100 | dropout=0.0, 101 | norm_num_groups=32, 102 | cross_attention_dim=320, 103 | activation_fn="geglu", 104 | attention_bias=False, 105 | upcast_attention=False, 106 | epi_position_encoding=False, 107 | epi_position_encoding_feat_max_size=32, 108 | epi_position_encoding_F_mat_size=256, 109 | epi_no_attention_mask=False, 110 | epi_mono_direction=False, 111 | epi_fix_firstframe=False, 112 | epi_rand_slope_ff=False, 113 | encoder_hidden_states_query=(False, False), 114 | attention_activation_scale=1.0, 115 | attention_processor_kwargs: Dict = {}, 116 | 117 
| rescale_output_factor=1.0 118 | ): 119 | super().__init__() 120 | 121 | inner_dim = num_attention_heads * attention_head_dim 122 | 123 | self.norm = InflatedGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 124 | self.proj_in = nn.Linear(in_channels, inner_dim) 125 | 126 | self.transformer_blocks = nn.ModuleList( 127 | [ 128 | EpiTransformerBlock( 129 | dim=inner_dim, 130 | num_attention_heads=num_attention_heads, 131 | attention_head_dim=attention_head_dim, 132 | attention_block_types=attention_block_types, 133 | dropout=dropout, 134 | norm_num_groups=norm_num_groups, 135 | cross_attention_dim=cross_attention_dim, 136 | activation_fn=activation_fn, 137 | attention_bias=attention_bias, 138 | upcast_attention=upcast_attention, 139 | epi_position_encoding=epi_position_encoding, 140 | epi_position_encoding_feat_max_size=epi_position_encoding_feat_max_size, 141 | epi_position_encoding_F_mat_size=epi_position_encoding_F_mat_size, 142 | encoder_hidden_states_query=encoder_hidden_states_query, 143 | epi_no_attention_mask=epi_no_attention_mask, 144 | epi_mono_direction=epi_mono_direction, 145 | epi_fix_firstframe=epi_fix_firstframe, 146 | epi_rand_slope_ff=epi_rand_slope_ff, 147 | attention_activation_scale=attention_activation_scale, 148 | attention_processor_kwargs=attention_processor_kwargs, 149 | rescale_output_factor=rescale_output_factor, 150 | ) 151 | for d in range(num_layers) 152 | ] 153 | ) 154 | self.proj_out = nn.Linear(inner_dim, in_channels) 155 | 156 | 157 | def forward(self, hidden_states, F_mats=None, H_mats=None, encoder_hidden_states=None, attention_mask=None, 158 | cross_attention_kwargs: Dict[str, Any] = {},): 159 | residual = hidden_states 160 | 161 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
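# Note: the rearranges below flatten video frames into the batch dimension and spatial
# positions into the token dimension, i.e. (b, c, f, h, w) -> ((b*f), (h*w), c), so attention
# operates over the h*w spatial tokens of each frame; the per-frame 3x3 fundamental matrices
# in F_mats are flattened the same way before being used to build the epipolar attention mask.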
162 | video_length, height, width = hidden_states.shape[-3:] 163 | hidden_states = self.norm(hidden_states) 164 | # hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c") 165 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) (h w) c") 166 | if F_mats is not None: 167 | if isinstance(F_mats, torch.Tensor): 168 | F_mats = rearrange(F_mats, "b f h w -> (b f) h w") 169 | else: 170 | F_mats = [rearrange(F_mats[0], "b f h w -> (b f) h w"), F_mats[1]] 171 | if H_mats is not None: 172 | H_mats = rearrange(H_mats, "b f h w -> (b f) h w") 173 | 174 | hidden_states = self.proj_in(hidden_states) 175 | 176 | # Transformer Blocks 177 | additional_outputs = [] 178 | for block in self.transformer_blocks: 179 | hidden_states, aux = block(hidden_states, F_mats, H_mats, encoder_hidden_states=encoder_hidden_states, 180 | attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs) 181 | additional_outputs += aux 182 | hidden_states = self.proj_out(hidden_states) 183 | 184 | # hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width) 185 | hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width) 186 | 187 | output = hidden_states + residual 188 | return output, additional_outputs 189 | 190 | 191 | class EpiTransformerBlock(nn.Module): 192 | def __init__( 193 | self, 194 | dim, 195 | num_attention_heads, 196 | attention_head_dim, 197 | attention_block_types=("Temporal_Self", "Temporal_Self",), 198 | dropout=0.0, 199 | norm_num_groups=32, 200 | cross_attention_dim=768, 201 | activation_fn="geglu", 202 | attention_bias=False, 203 | upcast_attention=False, 204 | epi_position_encoding=False, 205 | epi_position_encoding_feat_max_size=32, 206 | epi_position_encoding_F_mat_size=256, 207 | epi_no_attention_mask=False, 208 | epi_mono_direction=False, 209 | epi_fix_firstframe=False, 210 | epi_rand_slope_ff=False, 211 | encoder_hidden_states_query=(False, False), 212 | attention_activation_scale=1.0, 213 | attention_processor_kwargs: Dict = {}, 214 | rescale_output_factor=1.0 215 | ): 216 | super().__init__() 217 | 218 | attention_blocks = [] 219 | norms = [] 220 | self.attention_block_types = attention_block_types 221 | 222 | for block_idx, block_name in enumerate(attention_block_types): 223 | attention_blocks.append( 224 | EpiSelfAttention( 225 | attention_mode=block_name, 226 | cross_attention_dim=None, 227 | query_dim=dim, 228 | heads=num_attention_heads, 229 | dim_head=attention_head_dim, 230 | dropout=dropout, 231 | bias=attention_bias, 232 | upcast_attention=upcast_attention, 233 | epi_position_encoding=epi_position_encoding, 234 | epi_position_encoding_feat_max_size=epi_position_encoding_feat_max_size, 235 | epi_position_encoding_F_mat_size=epi_position_encoding_F_mat_size, 236 | epi_no_attention_mask=epi_no_attention_mask, 237 | epi_mono_direction=epi_mono_direction, 238 | epi_fix_firstframe=epi_fix_firstframe, 239 | epi_rand_slope_ff=epi_rand_slope_ff, 240 | rescale_output_factor=rescale_output_factor, 241 | ) 242 | ) 243 | norms.append(nn.LayerNorm(dim)) 244 | 245 | self.attention_blocks = nn.ModuleList(attention_blocks) 246 | self.norms = nn.ModuleList(norms) 247 | 248 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 249 | self.ff_norm = nn.LayerNorm(dim) 250 | 251 | def forward(self, hidden_states, F_mats=None, H_mats=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs: Dict[str, Any] = {}): 252 | additional_outputs = [] 253 | for 
attention_block, norm, attention_block_type in zip(self.attention_blocks, self.norms, self.attention_block_types): 254 | norm_hidden_states = norm(hidden_states) 255 | res, aux = attention_block( 256 | norm_hidden_states, 257 | F_mats=F_mats, 258 | H_mats=H_mats, 259 | encoder_hidden_states=norm_hidden_states if attention_block_type == 'Temporal_Self' else encoder_hidden_states, 260 | attention_mask=attention_mask, 261 | **cross_attention_kwargs 262 | ) 263 | hidden_states = hidden_states + res 264 | additional_outputs.append(aux) 265 | 266 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 267 | 268 | output = hidden_states 269 | return output, additional_outputs 270 | 271 | class EpiEncoding(nn.Module): 272 | def __init__( 273 | self, 274 | d_model, 275 | dropout = 0., 276 | max_feat_size = 128, 277 | F_mat_size = 256, 278 | rand_slope_on_first_frame = False, 279 | ): 280 | super().__init__() 281 | self.F_mat_size = F_mat_size 282 | self.rand_slope_on_first_frame = rand_slope_on_first_frame 283 | self.dropout = nn.Dropout(p=dropout) 284 | coords = torch.arange(max_feat_size) 285 | coords_x, coords_y = torch.meshgrid(coords, coords, indexing='xy') 286 | coords = torch.stack([coords_x, coords_y, coords_x*0+1], dim=-1) # 64 x 64 x 3 287 | 288 | self.register_buffer('coords', coords) 289 | 290 | def forward(self, x, F_mats=None): 291 | pass 292 | 293 | def get_attn_map(self, x, F_mats=None, H_mats=None, pixel_band=3, decay_alpha=3): 294 | feat_size = int(x.shape[1] ** 0.5) 295 | 296 | selected_coords = self.coords[:feat_size, :feat_size].reshape(-1, 3) 297 | # Rescale pixel coordinates to where the F matrix is defined 298 | coords = ((self.F_mat_size / feat_size) * selected_coords + (self.F_mat_size / feat_size-1) / 2)[None] # 1 x feat_size^2 x 3 299 | coords[..., -1] = 1 300 | 301 | if H_mats is not None: 302 | # Get F_coords by homography transformation 303 | # In case where H_mats is given, pseodo epipolar lines are generated 304 | batch_size = H_mats.shape[0] 305 | H_coords = coords.repeat(batch_size, 1, 1) # B x feat_size^2 x 3 306 | H_coords[...,:2] = H_coords[...,:2] - (self.F_mat_size-1) / 2 307 | H_coords = torch.bmm(H_mats.float(), H_coords.permute(0, 2, 1)).permute(0, 2, 1) 308 | H_coords = H_coords / (H_coords[...,2:]+1e-6) 309 | H_coords[...,:2] = H_coords[...,:2] + (self.F_mat_size-1) / 2 310 | F_coords = self.get_pseudo_F_coords(H_coords, random_slope=True) 311 | elif F_mats is not None: 312 | # Get F_coords by epipolar transformation 313 | batch_size = F_mats.shape[0] 314 | F_coords = coords.repeat(batch_size, 1, 1) 315 | F_coords = torch.bmm(F_mats.float(), F_coords.float().permute(0, 2, 1)).permute(0, 2, 1) # B x feat_size^2 x 3 316 | F_coords[::16] = self.get_pseudo_F_coords(coords[::16], random_slope=self.rand_slope_on_first_frame) 317 | else: 318 | # Get F_coords by identity transformation 319 | batch_size = x.shape[0] 320 | F_coords = self.get_pseudo_F_coords(coords.repeat(batch_size, 1, 1), random_slope=True) 321 | 322 | ab_norm = (F_coords[:, :, :2] * F_coords[:, :, :2]).sum(-1).sqrt()[:, :, None] 323 | cFc = torch.bmm(F_coords, coords.repeat(batch_size, 1, 1).permute(0, 2, 1)).abs() 324 | cFc = cFc / (ab_norm+1e-6) 325 | normed_pixel_band = (pixel_band / (self.F_mat_size // 2) * cFc.reshape(cFc.shape[0], -1).max(dim=-1)[0])[:, None, None] 326 | map_weight_decay = decay_alpha / (normed_pixel_band+1e-6) 327 | attn_mask = - (cFc-normed_pixel_band).clip(0) * map_weight_decay # B x feat_size^2 x feat_size^2 328 | # attn_mask = 1. - torch.sigmoid(50. 
* (cFc/256. - 0.01)) 329 | return attn_mask.detach() 330 | 331 | def get_pseudo_F_coords(self, coords, random_slope=False): 332 | feat_size = int(coords.shape[1] ** 0.5) 333 | batch_size = coords.shape[0] 334 | if random_slope is True: 335 | slope = torch.rand([batch_size], device=coords.device) * math.pi 336 | F_coords_a = torch.cos(slope)[:, None, None].repeat(1, feat_size**2, 1) 337 | F_coords_b = torch.sin(slope)[:, None, None].repeat(1, feat_size**2, 1) 338 | F_coords_c = -(F_coords_a * coords[...,0:1] + F_coords_b * coords[...,1:2]) 339 | F_coords = torch.cat([F_coords_a, F_coords_b, F_coords_c], dim=-1) 340 | else: 341 | F_coords_a = torch.zeros([1, feat_size**2, 1], device=coords.device).repeat(batch_size, 1, 1) 342 | F_coords_b = -torch.ones([1, feat_size**2, 1], device=coords.device).repeat(batch_size, 1, 1) 343 | F_coords_c = coords[...,1:2] 344 | 345 | F_coords = torch.cat([F_coords_a, F_coords_b, F_coords_c], dim=-1) 346 | return F_coords 347 | 348 | 349 | 350 | class EpiSelfAttention(Attention): 351 | def __init__( 352 | self, 353 | attention_mode=None, 354 | epi_position_encoding=False, 355 | epi_position_encoding_feat_max_size=32, 356 | epi_position_encoding_F_mat_size=256, 357 | epi_no_attention_mask=False, 358 | epi_mono_direction=False, 359 | epi_fix_firstframe=False, 360 | epi_rand_slope_ff=False, 361 | rescale_output_factor=1.0, 362 | *args, **kwargs 363 | ): 364 | super().__init__(*args, **kwargs) 365 | assert attention_mode == "Epi_Self" 366 | 367 | self.pos_encoder = EpiEncoding( 368 | kwargs["query_dim"], 369 | dropout=0., 370 | max_feat_size=epi_position_encoding_feat_max_size, 371 | F_mat_size=epi_position_encoding_F_mat_size, 372 | rand_slope_on_first_frame=epi_rand_slope_ff 373 | ) if epi_position_encoding else None 374 | self.rescale_output_factor = rescale_output_factor 375 | self.epi_no_attention_mask = epi_no_attention_mask 376 | self.epi_mono_direction = epi_mono_direction 377 | self.epi_fix_firstframe = epi_fix_firstframe 378 | 379 | def set_use_memory_efficient_attention_xformers( 380 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None 381 | ): 382 | # disable motion module efficient xformers to avoid bad results, don't know why 383 | # TODO: fix this bug 384 | pass 385 | 386 | def forward(self, hidden_states, F_mats=None, H_mats=None, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): 387 | # The `Attention` class can call different attention processors / attention functions 388 | # here we simply pass along all tensors to the selected processor class 389 | # For standard processors that are defined here, `**cross_attention_kwargs` is empty 390 | 391 | # add position encoding 392 | if self.pos_encoder is not None and not self.epi_no_attention_mask: 393 | with torch.no_grad(): 394 | # hidden_states = self.pos_encoder(hidden_states) 395 | attention_mask = self.pos_encoder.get_attn_map(hidden_states, 396 | F_mats = F_mats[0] if isinstance(F_mats, list) else F_mats, 397 | H_mats = H_mats) 398 | if attention_mask.shape[0] != hidden_states.shape[0]: 399 | assert attention_mask.shape[0] % hidden_states.shape[0] == 0 400 | B, N, C = hidden_states.shape 401 | attention_mask = attention_mask.reshape(-1, B, N, N) 402 | attention_mask = attention_mask.permute(1, 2, 3, 0).reshape(B, N, -1) 403 | 404 | if torch.isnan(attention_mask).any(): 405 | print("attention_mask contains NaN") 406 | 407 | torch.nan_to_num(attention_mask, nan=0.0, posinf=0.0, neginf=0.0, out=attention_mask) 408 | 409 | # if 
attention_mask.shape[1] == 1024: 410 | # import pdb 411 | # pdb.set_trace() 412 | # torch.save(attention_mask.detach().cpu(), "temp_attention_mask_for_debug.pt") 413 | # attention_mask=None 414 | 415 | # if "pose_feature" in cross_attention_kwargs: 416 | # pose_feature = cross_attention_kwargs["pose_feature"] 417 | # if pose_feature.ndim == 5: 418 | # pose_feature = rearrange(pose_feature, "b c f h w -> (b h w) f c") 419 | # else: 420 | # assert pose_feature.ndim == 3 421 | # cross_attention_kwargs["pose_feature"] = pose_feature 422 | 423 | assert isinstance(self.processor, EpiAttnProcessor) 424 | if attention_mask is not None and (attention_mask.shape[0] > 200 or attention_mask.shape[-1] > 2048): # memorrrrrrrry...... 425 | bs = hidden_states.shape[0] 426 | chunk_num = 128 * 1024 // max(1024, attention_mask.shape[-1]) 427 | # hidden_states_0, hidden_states_1 = hidden_states.chunk(2) 428 | if isinstance(F_mats, list): 429 | kv_index=F_mats[1] 430 | encoder_hidden_states = hidden_states[kv_index] 431 | # encoder_hidden_states_0, encoder_hidden_states_1 = encoder_hidden_states.chunk(2) 432 | else: 433 | encoder_hidden_states = hidden_states 434 | # encoder_hidden_states_0, encoder_hidden_states_1 = hidden_states_1, hidden_states_0 435 | hid_list = [] 436 | for i in range(bs // chunk_num): 437 | st, ed = i*chunk_num, (i+1)*chunk_num 438 | hid_list.append(self.processor( 439 | self, 440 | hidden_states[st:ed], 441 | encoder_hidden_states=encoder_hidden_states[st:ed], 442 | attention_mask=attention_mask[st:ed], 443 | kv_index=None, # no need to input kv_index here since encoder_hidden_states is assigned 444 | mono_direction=self.epi_mono_direction, 445 | fix_firstframe=self.epi_fix_firstframe, 446 | **cross_attention_kwargs, 447 | )[0]) 448 | torch.cuda.empty_cache() 449 | hidden_states = torch.cat(hid_list, dim=0) 450 | aux = None # {k:torch.cat([hid_0[1][k], hid_1[1][k]], dim=0) for k in hid_0[1].keys()} 451 | return hidden_states, aux 452 | else: 453 | return self.processor( 454 | self, 455 | hidden_states, 456 | encoder_hidden_states=None, 457 | attention_mask=attention_mask, 458 | kv_index=F_mats[1] if isinstance(F_mats, list) else None, 459 | mono_direction=self.epi_mono_direction, 460 | fix_firstframe=self.epi_fix_firstframe, 461 | **cross_attention_kwargs, 462 | ) 463 | -------------------------------------------------------------------------------- /animatediff/models/pose_adaptor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | from animatediff.models.motion_module import TemporalTransformerBlock 6 | 7 | 8 | def get_parameter_dtype(parameter: torch.nn.Module): 9 | try: 10 | params = tuple(parameter.parameters()) 11 | if len(params) > 0: 12 | return params[0].dtype 13 | 14 | buffers = tuple(parameter.buffers()) 15 | if len(buffers) > 0: 16 | return buffers[0].dtype 17 | 18 | except StopIteration: 19 | # For torch.nn.DataParallel compatibility in PyTorch 1.5 20 | 21 | def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: 22 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 23 | return tuples 24 | 25 | gen = parameter._named_members(get_members_fn=find_tensor_attributes) 26 | first_tuple = next(gen) 27 | return first_tuple[1].dtype 28 | 29 | 30 | def conv_nd(dims, *args, **kwargs): 31 | """ 32 | Create a 1D, 2D, or 3D convolution module. 
33 | """ 34 | if dims == 1: 35 | return nn.Conv1d(*args, **kwargs) 36 | elif dims == 2: 37 | return nn.Conv2d(*args, **kwargs) 38 | elif dims == 3: 39 | return nn.Conv3d(*args, **kwargs) 40 | raise ValueError(f"unsupported dimensions: {dims}") 41 | 42 | 43 | def avg_pool_nd(dims, *args, **kwargs): 44 | """ 45 | Create a 1D, 2D, or 3D average pooling module. 46 | """ 47 | if dims == 1: 48 | return nn.AvgPool1d(*args, **kwargs) 49 | elif dims == 2: 50 | return nn.AvgPool2d(*args, **kwargs) 51 | elif dims == 3: 52 | return nn.AvgPool3d(*args, **kwargs) 53 | raise ValueError(f"unsupported dimensions: {dims}") 54 | 55 | 56 | class PoseAdaptor2D(nn.Module): 57 | def __init__(self, unet, pose_encoder): 58 | super(PoseAdaptor2D, self).__init__() 59 | self.unet = unet 60 | self.pose_encoder = pose_encoder 61 | 62 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, pose_embedding, F_mats=None, H_mats=None): 63 | if pose_embedding is not None: 64 | assert pose_embedding.ndim == 5 65 | bs = pose_embedding.shape[0] 66 | pose_embedding = rearrange(pose_embedding, "b c f h w -> (b f) c h w") 67 | pose_embedding_features = self.pose_encoder(pose_embedding) 68 | pose_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs) 69 | for x in pose_embedding_features] 70 | else: 71 | pose_embedding_features = None 72 | noise_pred = self.unet(noisy_latents, 73 | timesteps, 74 | encoder_hidden_states, 75 | F_mats=F_mats, 76 | H_mats=H_mats, 77 | pose_embedding_features=pose_embedding_features).sample 78 | return noise_pred 79 | 80 | 81 | class PoseAdaptor(nn.Module): 82 | def __init__(self, unet, pose_encoder): 83 | super().__init__() 84 | self.unet = unet 85 | self.pose_encoder = pose_encoder 86 | 87 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, pose_embedding, F_mats=None, H_mats=None): 88 | if pose_embedding is not None: 89 | assert pose_embedding.ndim == 5 90 | bs = pose_embedding.shape[0] # b c f h w 91 | pose_embedding_features = self.pose_encoder(pose_embedding) # bf c h w 92 | pose_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs) 93 | for x in pose_embedding_features] 94 | else: 95 | pose_embedding_features = None 96 | output = self.unet(noisy_latents, 97 | timesteps, 98 | encoder_hidden_states, 99 | F_mats=F_mats, 100 | H_mats=H_mats, 101 | pose_embedding_features=pose_embedding_features) 102 | noise_pred = output.sample 103 | auxiliary = output.auxiliary 104 | return noise_pred, auxiliary 105 | 106 | 107 | class Downsample(nn.Module): 108 | """ 109 | A downsampling layer with an optional convolution. 110 | :param channels: channels in the inputs and outputs. 111 | :param use_conv: a bool determining if a convolution is applied. 112 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 113 | downsampling occurs in the inner-two dimensions. 
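    :param out_channels: if given, the number of output channels; defaults to `channels`.
    :param padding: padding passed to the convolution when `use_conv` is True
        (average pooling is used otherwise and ignores it).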
114 | """ 115 | 116 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 117 | super().__init__() 118 | self.channels = channels 119 | self.out_channels = out_channels or channels 120 | self.use_conv = use_conv 121 | self.dims = dims 122 | stride = 2 if dims != 3 else (1, 2, 2) 123 | if use_conv: 124 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) 125 | else: 126 | assert self.channels == self.out_channels 127 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 128 | 129 | def forward(self, x): 130 | assert x.shape[1] == self.channels 131 | return self.op(x) 132 | 133 | 134 | class ResnetBlock(nn.Module): 135 | 136 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): 137 | super().__init__() 138 | ps = ksize // 2 139 | if in_c != out_c or sk == False: 140 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) 141 | else: 142 | self.in_conv = None 143 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) 144 | self.act = nn.ReLU() 145 | self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps) 146 | if sk == False: 147 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 148 | else: 149 | self.skep = None 150 | 151 | self.down = down 152 | if self.down == True: 153 | self.down_opt = Downsample(in_c, use_conv=use_conv) 154 | 155 | def forward(self, x): 156 | if self.down == True: 157 | x = self.down_opt(x) 158 | if self.in_conv is not None: # edit 159 | x = self.in_conv(x) 160 | 161 | h = self.block1(x) 162 | h = self.act(h) 163 | h = self.block2(h) 164 | if self.skep is not None: 165 | return h + self.skep(x) 166 | else: 167 | return h + x 168 | 169 | 170 | class CameraPoseEncoder2D(nn.Module): 171 | 172 | def __init__(self, 173 | downscale_factor, 174 | channels=[320, 640, 1280, 1280], 175 | nums_rb=3, 176 | cin=64, 177 | ksize=3, 178 | sk=False, 179 | use_conv=True, 180 | compression_factor=1): 181 | super(CameraPoseEncoder2D, self).__init__() 182 | self.unshuffle = nn.PixelUnshuffle(downscale_factor) 183 | self.channels = channels 184 | self.nums_rb = nums_rb 185 | self.encoder_down_blocks = nn.ModuleList() 186 | for i in range(len(channels)): 187 | down_layers = [] 188 | for j in range(nums_rb): 189 | if j == 0 and i != 0: 190 | down_layers.append(ResnetBlock(channels[i - 1], int(channels[i] / compression_factor), 191 | down=True, ksize=ksize, sk=sk, use_conv=use_conv)) 192 | elif j == 0: 193 | down_layers.append(ResnetBlock(channels[0], int(channels[i] / compression_factor), 194 | down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 195 | elif j == nums_rb - 1: 196 | down_layers.append(ResnetBlock(int(channels[i] / compression_factor), channels[i], 197 | down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 198 | else: 199 | down_layers.append(ResnetBlock(int(channels[i] / compression_factor), 200 | int(channels[i] / compression_factor), 201 | down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 202 | self.encoder_down_blocks.append(nn.Sequential(*down_layers)) 203 | self.encoder_conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1) 204 | 205 | @property 206 | def dtype(self) -> torch.dtype: 207 | """ 208 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
209 | """ 210 | return get_parameter_dtype(self) 211 | 212 | def forward(self, x): 213 | # unshuffle 214 | x = self.unshuffle(x) 215 | # extract features 216 | features = [] 217 | x = self.encoder_conv_in(x) # bf c h w 218 | 219 | for down_block in self.encoder_down_blocks: 220 | x = down_block(x) 221 | features.append(x) 222 | 223 | return features 224 | 225 | 226 | class PositionalEncoding(nn.Module): 227 | def __init__( 228 | self, 229 | d_model, 230 | dropout=0., 231 | max_len=32, 232 | ): 233 | super().__init__() 234 | self.dropout = nn.Dropout(p=dropout) 235 | position = torch.arange(max_len).unsqueeze(1) 236 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 237 | pe = torch.zeros(1, max_len, d_model) 238 | pe[0, :, 0::2, ...] = torch.sin(position * div_term) 239 | pe[0, :, 1::2, ...] = torch.cos(position * div_term) 240 | pe.unsqueeze_(-1).unsqueeze_(-1) 241 | self.register_buffer('pe', pe) 242 | 243 | def forward(self, x): 244 | x = x + self.pe[:, :x.size(1), ...] 245 | return self.dropout(x) 246 | 247 | 248 | class CameraPoseEncoder(nn.Module): 249 | 250 | def __init__(self, 251 | downscale_factor, 252 | channels=[320, 640, 1280, 1280], 253 | nums_rb=3, 254 | cin=64, 255 | ksize=3, 256 | sk=False, 257 | use_conv=True, 258 | compression_factor=1, 259 | temporal_attention_nhead=8, 260 | attention_block_types=("Temporal_Self", ), 261 | temporal_position_encoding=False, 262 | temporal_position_encoding_max_len=16, 263 | rescale_output_factor=1.0): 264 | super(CameraPoseEncoder, self).__init__() 265 | self.unshuffle = nn.PixelUnshuffle(downscale_factor) 266 | self.channels = channels 267 | self.nums_rb = nums_rb 268 | self.encoder_down_conv_blocks = nn.ModuleList() 269 | self.encoder_down_attention_blocks = nn.ModuleList() 270 | for i in range(len(channels)): 271 | conv_layers = nn.ModuleList() 272 | temporal_attention_layers = nn.ModuleList() 273 | for j in range(nums_rb): 274 | if j == 0 and i != 0: 275 | in_dim = channels[i - 1] 276 | out_dim = int(channels[i] / compression_factor) 277 | conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv) 278 | elif j == 0: 279 | in_dim = channels[0] 280 | out_dim = int(channels[i] / compression_factor) 281 | conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv) 282 | elif j == nums_rb - 1: 283 | in_dim = channels[i] / compression_factor 284 | out_dim = channels[i] 285 | conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv) 286 | else: 287 | in_dim = int(channels[i] / compression_factor) 288 | out_dim = int(channels[i] / compression_factor) 289 | conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv) 290 | temporal_attention_layer = TemporalTransformerBlock(dim=out_dim, 291 | num_attention_heads=temporal_attention_nhead, 292 | attention_head_dim=int(out_dim / temporal_attention_nhead), 293 | attention_block_types=attention_block_types, 294 | dropout=0.0, 295 | cross_attention_dim=None, 296 | temporal_position_encoding=temporal_position_encoding, 297 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 298 | rescale_output_factor=rescale_output_factor) 299 | conv_layers.append(conv_layer) 300 | temporal_attention_layers.append(temporal_attention_layer) 301 | self.encoder_down_conv_blocks.append(conv_layers) 302 | self.encoder_down_attention_blocks.append(temporal_attention_layers) 303 | 304 | self.encoder_conv_in = nn.Conv2d(cin, channels[0], 
3, 1, 1) 305 | 306 | @property 307 | def dtype(self) -> torch.dtype: 308 | """ 309 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 310 | """ 311 | return get_parameter_dtype(self) 312 | 313 | def forward(self, x): 314 | # unshuffle 315 | bs = x.shape[0] 316 | x = rearrange(x, "b c f h w -> (b f) c h w") 317 | x = self.unshuffle(x) 318 | # extract features 319 | features = [] 320 | x = self.encoder_conv_in(x) 321 | for res_block, attention_block in zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks): 322 | for res_layer, attention_layer in zip(res_block, attention_block): 323 | x = res_layer(x) 324 | h, w = x.shape[-2:] 325 | x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs) 326 | x = attention_layer(x) 327 | x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w) 328 | features.append(x) 329 | return features 330 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | from einops import rearrange, repeat 4 | from functools import partial 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from diffusers.models.activations import get_activation 12 | from diffusers.models.normalization import AdaGroupNorm 13 | from diffusers.models.attention_processor import SpatialNorm 14 | 15 | from diffusers.models.resnet import ResnetBlock2D 16 | from diffusers.models.transformer_2d import Transformer2DModel 17 | 18 | 19 | class InflatedConv3d(nn.Conv2d): 20 | def forward(self, x): 21 | video_length = x.shape[2] 22 | 23 | x = rearrange(x, "b c f h w -> (b f) c h w") 24 | x = super().forward(x) 25 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 26 | 27 | return x 28 | 29 | 30 | class InflatedGroupNorm(nn.GroupNorm): 31 | def forward(self, x): 32 | # return super().forward(x) 33 | 34 | video_length = x.shape[2] 35 | 36 | x = rearrange(x, "b c f h w -> (b f) c h w") 37 | x = super().forward(x) 38 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 39 | 40 | return x 41 | 42 | def zero_module(module): 43 | # Zero out the parameters of a module and return it. 44 | for p in module.parameters(): 45 | p.detach().zero_() 46 | return module 47 | 48 | 49 | class FusionBlock2D(nn.Module): 50 | r""" 51 | A Resnet block. 52 | 53 | Parameters: 54 | in_channels (`int`): The number of channels in the input. 55 | out_channels (`int`, *optional*, default to be `None`): 56 | The number of output channels for the first conv2d layer. If None, same as `in_channels`. 57 | dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. 58 | temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. 59 | groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. 60 | groups_out (`int`, *optional*, default to None): 61 | The number of groups to use for the second normalization layer. if set to None, same as `groups`. 62 | eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 63 | non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. 64 | time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. 
65 | By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or 66 | "ada_group" for a stronger conditioning with scale and shift. 67 | kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see 68 | [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. 69 | output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. 70 | use_in_shortcut (`bool`, *optional*, default to `True`): 71 | If `True`, add a 1x1 nn.conv2d layer for skip-connection. 72 | up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. 73 | down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. 74 | conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the 75 | `conv_shortcut` output. 76 | conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. 77 | If None, same as `out_channels`. 78 | """ 79 | 80 | def __init__( 81 | self, 82 | *, 83 | in_channels, 84 | out_channels=None, 85 | conv_shortcut=False, 86 | dropout=0.0, 87 | temb_channels=512, 88 | groups=32, 89 | groups_out=None, 90 | pre_norm=True, 91 | eps=1e-6, 92 | non_linearity="swish", 93 | skip_time_act=False, 94 | time_embedding_norm="default", # default, scale_shift, ada_group, spatial 95 | kernel=None, 96 | output_scale_factor=1.0, 97 | use_in_shortcut=None, 98 | up=False, 99 | down=False, 100 | conv_shortcut_bias: bool = True, 101 | conv_2d_out_channels: Optional[int] = None, 102 | 103 | zero_init=True, 104 | ): 105 | super().__init__() 106 | self.pre_norm = pre_norm 107 | self.pre_norm = True 108 | 109 | in_channels = in_channels * 2 110 | self.in_channels = in_channels 111 | 112 | out_channels = in_channels * 3 if out_channels is None else out_channels * 3 113 | self.out_channels = out_channels 114 | 115 | self.use_conv_shortcut = conv_shortcut 116 | self.up = up 117 | self.down = down 118 | self.output_scale_factor = output_scale_factor 119 | self.time_embedding_norm = time_embedding_norm 120 | self.skip_time_act = skip_time_act 121 | 122 | if groups_out is None: 123 | groups_out = groups 124 | 125 | if self.time_embedding_norm == "ada_group": 126 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) 127 | elif self.time_embedding_norm == "spatial": 128 | self.norm1 = SpatialNorm(in_channels, temb_channels) 129 | else: 130 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 131 | 132 | self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 133 | 134 | if temb_channels is not None: 135 | if self.time_embedding_norm == "default": 136 | self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) 137 | elif self.time_embedding_norm == "scale_shift": 138 | self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) 139 | elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 140 | self.time_emb_proj = None 141 | else: 142 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 143 | else: 144 | self.time_emb_proj = None 145 | 146 | if self.time_embedding_norm == "ada_group": 147 | self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) 148 | elif self.time_embedding_norm == "spatial": 149 | self.norm2 = SpatialNorm(out_channels, temb_channels) 150 | else: 151 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, 
num_channels=out_channels, eps=eps, affine=True) 152 | 153 | self.dropout = torch.nn.Dropout(dropout) 154 | conv_2d_out_channels = conv_2d_out_channels or out_channels 155 | self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0) 156 | 157 | self.nonlinearity = get_activation(non_linearity) 158 | 159 | self.upsample = self.downsample = None 160 | if self.up: 161 | if kernel == "fir": 162 | fir_kernel = (1, 3, 3, 1) 163 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) 164 | elif kernel == "sde_vp": 165 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 166 | else: 167 | self.upsample = Upsample2D(in_channels, use_conv=False) 168 | elif self.down: 169 | if kernel == "fir": 170 | fir_kernel = (1, 3, 3, 1) 171 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) 172 | elif kernel == "sde_vp": 173 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) 174 | else: 175 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") 176 | 177 | self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut 178 | 179 | self.conv_shortcut = None 180 | if self.use_in_shortcut: 181 | self.conv_shortcut = torch.nn.Conv2d( 182 | in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias 183 | ) 184 | 185 | conv_out = torch.nn.Conv2d( 186 | conv_2d_out_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, 187 | ) 188 | self.conv_out = zero_module(conv_out) if zero_init else conv_out 189 | 190 | def forward(self, init_hidden_state, post_hidden_states, temb): 191 | # init_hidden_state: b c 1 h w 192 | # post_hidden_states: b c (f-1) h w 193 | 194 | video_length = post_hidden_states.shape[2] 195 | repeated_init_hidden_state = repeat(init_hidden_state, "b c f h w -> b c (n f) h w", n=video_length) 196 | 197 | hidden_states = torch.cat([repeated_init_hidden_state, post_hidden_states], dim=1) 198 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 199 | input_tensor = hidden_states 200 | 201 | if temb.shape[0] != input_tensor.shape[0]: 202 | temb = repeat(temb, "b c -> (b n) c", n=input_tensor.shape[0] // temb.shape[0]) 203 | assert temb.shape[0] == input_tensor.shape[0], f"{temb.shape}, {input_tensor.shape}" 204 | 205 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 206 | hidden_states = self.norm1(hidden_states, temb) 207 | else: 208 | hidden_states = self.norm1(hidden_states) 209 | 210 | hidden_states = self.nonlinearity(hidden_states) 211 | 212 | if self.upsample is not None: 213 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 214 | if hidden_states.shape[0] >= 64: 215 | input_tensor = input_tensor.contiguous() 216 | hidden_states = hidden_states.contiguous() 217 | input_tensor = self.upsample(input_tensor) 218 | hidden_states = self.upsample(hidden_states) 219 | elif self.downsample is not None: 220 | input_tensor = self.downsample(input_tensor) 221 | hidden_states = self.downsample(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if self.time_emb_proj is not None: 226 | if not self.skip_time_act: 227 | temb = self.nonlinearity(temb) 228 | temb = self.time_emb_proj(temb)[:, :, None, None] 229 | 230 | if temb is not None and self.time_embedding_norm == "default": 231 | hidden_states = hidden_states + temb 232 | 233 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 234 | hidden_states = self.norm2(hidden_states, temb) 235 | else: 236 | hidden_states = self.norm2(hidden_states) 237 | 238 | if temb is not None and self.time_embedding_norm == "scale_shift": 239 | scale, shift = torch.chunk(temb, 2, dim=1) 240 | hidden_states = hidden_states * (1 + scale) + shift 241 | 242 | hidden_states = self.nonlinearity(hidden_states) 243 | 244 | hidden_states = self.dropout(hidden_states) 245 | hidden_states = self.conv2(hidden_states) 246 | 247 | if self.conv_shortcut is not None: 248 | input_tensor = self.conv_shortcut(input_tensor) 249 | 250 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 251 | 252 | output_tensor = self.conv_out(output_tensor) 253 | 254 | output_tensor = rearrange(output_tensor, "(b f) c h w -> b c f h w", f=video_length) 255 | scale_1, scale_2, shift = output_tensor.chunk(3, dim=1) 256 | 257 | # output_tensor = (1 + scale_1) * repeated_init_hidden_state + scale_2 * post_hidden_states + shift 258 | output_tensor = scale_1 * repeated_init_hidden_state + (1 + scale_2) * post_hidden_states + shift 259 | 260 | return output_tensor 261 | 262 | class Upsample3D(nn.Module): 263 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 264 | super().__init__() 265 | self.channels = channels 266 | self.out_channels = out_channels or channels 267 | self.use_conv = use_conv 268 | self.use_conv_transpose = use_conv_transpose 269 | self.name = name 270 | 271 | conv = None 272 | if use_conv_transpose: 273 | raise NotImplementedError 274 | elif use_conv: 275 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 276 | 277 | def forward(self, hidden_states, output_size=None): 278 | assert hidden_states.shape[1] == self.channels 279 | 280 | if self.use_conv_transpose: 281 | raise NotImplementedError 282 | 283 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 284 | dtype = hidden_states.dtype 285 | if dtype == torch.bfloat16: 286 | hidden_states = hidden_states.to(torch.float32) 287 | 288 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 289 | if hidden_states.shape[0] >= 64: 290 | hidden_states = hidden_states.contiguous() 291 | 292 | # if `output_size` is passed we force the interpolation output 293 | # size and do not make use of `scale_factor=2` 294 | if output_size is None: 295 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 296 | else: 297 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 298 | 299 | # If the input is bfloat16, we cast back to bfloat16 300 | if dtype == torch.bfloat16: 301 | hidden_states = hidden_states.to(dtype) 302 | 303 | # if self.use_conv: 304 | # if self.name == "conv": 305 | # hidden_states = self.conv(hidden_states) 306 | # else: 307 | # hidden_states = self.Conv2d_0(hidden_states) 308 | hidden_states = self.conv(hidden_states) 309 | 310 | return hidden_states 311 | 312 | 313 | class Downsample3D(nn.Module): 314 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 315 | super().__init__() 316 | self.channels = channels 317 | self.out_channels = out_channels or channels 318 | self.use_conv = use_conv 319 | self.padding = padding 320 | stride = 2 321 | self.name = name 322 | 323 | if use_conv: 324 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 325 | else: 326 | raise NotImplementedError 327 | 328 | def forward(self, hidden_states): 329 | assert hidden_states.shape[1] == self.channels 330 | if self.use_conv and self.padding == 0: 331 | raise NotImplementedError 332 | 333 | assert hidden_states.shape[1] == self.channels 334 | hidden_states = self.conv(hidden_states) 335 | 336 | return hidden_states 337 | 338 | 339 | class ResnetBlock3D(nn.Module): 340 | def __init__( 341 | self, 342 | *, 343 | in_channels, 344 | out_channels=None, 345 | conv_shortcut=False, 346 | dropout=0.0, 347 | temb_channels=512, 348 | groups=32, 349 | groups_out=None, 350 | pre_norm=True, 351 | eps=1e-6, 352 | non_linearity="swish", 353 | time_embedding_norm="default", 354 | output_scale_factor=1.0, 355 | use_in_shortcut=None, 356 | ): 357 | super().__init__() 358 | self.pre_norm = pre_norm 359 | self.pre_norm = True 360 | self.in_channels = in_channels 361 | out_channels = in_channels if out_channels is None else out_channels 362 | self.out_channels = out_channels 363 | self.use_conv_shortcut = conv_shortcut 364 | self.time_embedding_norm = time_embedding_norm 365 | self.output_scale_factor = output_scale_factor 366 | 367 | if groups_out is None: 368 | groups_out = groups 369 | 370 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 371 | 372 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 373 | 374 | if temb_channels is not None: 375 | if self.time_embedding_norm == "default": 376 | time_emb_proj_out_channels = out_channels 377 | elif self.time_embedding_norm == "scale_shift": 378 | time_emb_proj_out_channels = out_channels * 2 379 | else: 380 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 381 | 382 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 383 | else: 384 | self.time_emb_proj = None 385 | 386 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 387 | self.dropout = torch.nn.Dropout(dropout) 388 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 
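        # conv1 and conv2 are InflatedConv3d layers (defined at the top of this file): they fold
        # the frame axis into the batch axis, apply a plain 2D convolution per frame, and restore
        # the video layout, i.e.
        #     x = rearrange(x, "b c f h w -> (b f) c h w")
        #     x = nn.Conv2d.forward(self, x)
        #     x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)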
389 | 390 | if non_linearity == "swish": 391 | self.nonlinearity = lambda x: F.silu(x) 392 | elif non_linearity == "mish": 393 | self.nonlinearity = Mish() 394 | elif non_linearity == "silu": 395 | self.nonlinearity = nn.SiLU() 396 | 397 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 398 | 399 | self.conv_shortcut = None 400 | if self.use_in_shortcut: 401 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 402 | 403 | def forward(self, input_tensor, temb): 404 | # input: b c f h w 405 | 406 | hidden_states = input_tensor 407 | 408 | video_length = hidden_states.shape[2] 409 | emb = repeat(emb, "b c -> (b f) c", f=video_length) 410 | 411 | hidden_states = self.norm1(hidden_states) 412 | hidden_states = self.nonlinearity(hidden_states) 413 | 414 | hidden_states = self.conv1(hidden_states) 415 | 416 | if temb is not None: 417 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 418 | 419 | if temb is not None and self.time_embedding_norm == "default": 420 | hidden_states = hidden_states + temb 421 | 422 | hidden_states = self.norm2(hidden_states) 423 | 424 | if temb is not None and self.time_embedding_norm == "scale_shift": 425 | scale, shift = torch.chunk(temb, 2, dim=1) 426 | hidden_states = hidden_states * (1 + scale) + shift 427 | 428 | hidden_states = self.nonlinearity(hidden_states) 429 | 430 | hidden_states = self.dropout(hidden_states) 431 | hidden_states = self.conv2(hidden_states) 432 | 433 | if self.conv_shortcut is not None: 434 | input_tensor = self.conv_shortcut(input_tensor) 435 | 436 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 437 | 438 | return output_tensor 439 | 440 | 441 | class Mish(torch.nn.Module): 442 | def forward(self, hidden_states): 443 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ Conversion script for the LoRA's safetensors checkpoints. """ 17 | 18 | import argparse 19 | 20 | import torch 21 | from safetensors.torch import load_file 22 | 23 | from diffusers import StableDiffusionPipeline 24 | import pdb 25 | 26 | 27 | 28 | def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0): 29 | # directly update weight in diffusers model 30 | for key in state_dict: 31 | # only process lora down key 32 | if "up." 
in key: continue 33 | 34 | up_key = key.replace(".down.", ".up.") 35 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 36 | model_key = model_key.replace("to_out.", "to_out.0.") 37 | layer_infos = model_key.split(".")[:-1] 38 | 39 | curr_layer = pipeline.unet 40 | while len(layer_infos) > 0: 41 | temp_name = layer_infos.pop(0) 42 | curr_layer = curr_layer.__getattr__(temp_name) 43 | 44 | weight_down = state_dict[key] 45 | weight_up = state_dict[up_key] 46 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 47 | 48 | return pipeline 49 | 50 | 51 | 52 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 53 | # load base model 54 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 55 | 56 | # load LoRA weight from .safetensors 57 | # state_dict = load_file(checkpoint_path) 58 | 59 | visited = [] 60 | 61 | # directly update weight in diffusers model 62 | for key in state_dict: 63 | # it is suggested to print out the key, it usually will be something like below 64 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 65 | 66 | # as we have set the alpha beforehand, so just skip 67 | if ".alpha" in key or key in visited: 68 | continue 69 | 70 | if "text" in key: 71 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 72 | curr_layer = pipeline.text_encoder 73 | else: 74 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 75 | curr_layer = pipeline.unet 76 | 77 | # find the target layer 78 | temp_name = layer_infos.pop(0) 79 | while len(layer_infos) > -1: 80 | try: 81 | curr_layer = curr_layer.__getattr__(temp_name) 82 | if len(layer_infos) > 0: 83 | temp_name = layer_infos.pop(0) 84 | elif len(layer_infos) == 0: 85 | break 86 | except Exception: 87 | if len(temp_name) > 0: 88 | temp_name += "_" + layer_infos.pop(0) 89 | else: 90 | temp_name = layer_infos.pop(0) 91 | 92 | pair_keys = [] 93 | if "lora_down" in key: 94 | pair_keys.append(key.replace("lora_down", "lora_up")) 95 | pair_keys.append(key) 96 | else: 97 | pair_keys.append(key) 98 | pair_keys.append(key.replace("lora_up", "lora_down")) 99 | 100 | # update weight 101 | if len(state_dict[pair_keys[0]].shape) == 4: 102 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 103 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 104 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 105 | else: 106 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 107 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 108 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 109 | 110 | # update visited list 111 | for item in pair_keys: 112 | visited.append(item) 113 | 114 | return pipeline 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument( 121 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 122 | ) 123 | parser.add_argument( 124 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
125 | ) 126 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 127 | parser.add_argument( 128 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 129 | ) 130 | parser.add_argument( 131 | "--lora_prefix_text_encoder", 132 | default="lora_te", 133 | type=str, 134 | help="The prefix of text encoder weight in safetensors", 135 | ) 136 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 137 | parser.add_argument( 138 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 139 | ) 140 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 141 | 142 | args = parser.parse_args() 143 | 144 | base_model_path = args.base_model_path 145 | checkpoint_path = args.checkpoint_path 146 | dump_path = args.dump_path 147 | lora_prefix_unet = args.lora_prefix_unet 148 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 149 | alpha = args.alpha 150 | 151 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 152 | 153 | pipe = pipe.to(args.device) 154 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 155 | -------------------------------------------------------------------------------- /animatediff/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | import logging 4 | import sys 5 | import imageio 6 | import numpy as np 7 | import atexit 8 | from typing import Union 9 | from termcolor import colored 10 | 11 | import torch 12 | import torchvision 13 | 14 | from safetensors import safe_open 15 | from tqdm import tqdm 16 | from einops import rearrange 17 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 18 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers 19 | 20 | import importlib 21 | 22 | 23 | def instantiate_from_config(config, **additional_kwargs): 24 | if not "target" in config: 25 | if config == '__is_first_stage__': 26 | return None 27 | elif config == "__is_unconditional__": 28 | return None 29 | raise KeyError("Expected key `target` to instantiate.") 30 | 31 | additional_kwargs.update(config.get("kwargs", dict())) 32 | return get_obj_from_str(config["target"])(**additional_kwargs) 33 | 34 | 35 | def get_obj_from_str(string, reload=False): 36 | module, cls = string.rsplit(".", 1) 37 | if reload: 38 | module_imp = importlib.import_module(module) 39 | importlib.reload(module_imp) 40 | return getattr(importlib.import_module(module, package=None), cls) 41 | 42 | 43 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, mp4_also=False): 44 | videos = rearrange(videos, "b c t h w -> t b c h w") 45 | outputs = [] 46 | for x in videos: 47 | x = torchvision.utils.make_grid(x, nrow=n_rows) 48 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 49 | if rescale: 50 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 51 | x = (x * 255).numpy().astype(np.uint8) 52 | outputs.append(x) 53 | 54 | os.makedirs(os.path.dirname(path), exist_ok=True) 55 | imageio.mimsave(path, outputs, fps=fps) 56 | if mp4_also: 57 | imageio.mimsave(path.replace(".gif", ".mp4"), outputs, fps=fps) 58 | 59 | 60 | def save_video_as_images(video: 
torch.Tensor, path: str, rescale=False): 61 | video = rearrange(video, "c t h w -> t c h w") 62 | os.makedirs(os.path.dirname(path), exist_ok=True) 63 | output_filenames = [] 64 | for idx, frame in enumerate(video): 65 | frame = frame.transpose(0, 1).transpose(1, 2).squeeze(-1) 66 | if rescale: 67 | frame = (frame + 1.0) / 2.0 # -1,1 -> 0,1 68 | frame = (frame * 255).numpy().astype(np.uint8) 69 | output_path = path+"_%d.png"%idx 70 | output_filenames.append(output_path) 71 | imageio.imsave(output_path, frame) 72 | return output_filenames 73 | 74 | # DDIM Inversion 75 | @torch.no_grad() 76 | def init_prompt(prompt, pipeline): 77 | uncond_input = pipeline.tokenizer( 78 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 79 | return_tensors="pt" 80 | ) 81 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 82 | text_input = pipeline.tokenizer( 83 | [prompt], 84 | padding="max_length", 85 | max_length=pipeline.tokenizer.model_max_length, 86 | truncation=True, 87 | return_tensors="pt", 88 | ) 89 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 90 | context = torch.cat([uncond_embeddings, text_embeddings]) 91 | 92 | return context 93 | 94 | 95 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 96 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 97 | timestep, next_timestep = min( 98 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 99 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 100 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 101 | beta_prod_t = 1 - alpha_prod_t 102 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 103 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 104 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 105 | return next_sample 106 | 107 | 108 | def get_noise_pred_single(latents, t, context, unet): 109 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 110 | return noise_pred 111 | 112 | 113 | @torch.no_grad() 114 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 115 | context = init_prompt(prompt, pipeline) 116 | uncond_embeddings, cond_embeddings = context.chunk(2) 117 | all_latent = [latent] 118 | latent = latent.clone().detach() 119 | for i in tqdm(range(num_inv_steps)): 120 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 121 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 122 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 123 | all_latent.append(latent) 124 | return all_latent 125 | 126 | 127 | @torch.no_grad() 128 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 129 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 130 | return ddim_latents 131 | 132 | def load_weights( 133 | animation_pipeline, 134 | # motion module 135 | motion_module_path = "", 136 | motion_module_lora_configs = [], 137 | # image layers 138 | dreambooth_model_path = "", 139 | lora_model_path = "", 140 | lora_alpha = 0.8, 141 | ): 142 | # 1.1 motion module 143 | unet_state_dict = {} 144 | if motion_module_path != "": 145 | print(f"load motion module from {motion_module_path}") 146 | motion_module_state_dict = 
torch.load(motion_module_path, map_location="cpu") 147 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 148 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 149 | 150 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 151 | assert len(unexpected) == 0 152 | del unet_state_dict 153 | 154 | if dreambooth_model_path != "": 155 | print(f"load dreambooth model from {dreambooth_model_path}") 156 | if dreambooth_model_path.endswith(".safetensors"): 157 | dreambooth_state_dict = {} 158 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 159 | for key in f.keys(): 160 | dreambooth_state_dict[key] = f.get_tensor(key) 161 | elif dreambooth_model_path.endswith(".ckpt"): 162 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 163 | 164 | # 1. vae 165 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 166 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 167 | # 2. unet 168 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 169 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 170 | # 3. text_model 171 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 172 | del dreambooth_state_dict 173 | 174 | if lora_model_path != "": 175 | print(f"load lora model from {lora_model_path}") 176 | assert lora_model_path.endswith(".safetensors") 177 | lora_state_dict = {} 178 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 179 | for key in f.keys(): 180 | lora_state_dict[key] = f.get_tensor(key) 181 | 182 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 183 | del lora_state_dict 184 | 185 | 186 | for motion_module_lora_config in motion_module_lora_configs: 187 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 188 | print(f"load motion LoRA from {path}") 189 | 190 | motion_lora_state_dict = torch.load(path, map_location="cpu") 191 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 192 | 193 | animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha) 194 | 195 | return animation_pipeline 196 | 197 | 198 | class _ColorfulFormatter(logging.Formatter): 199 | def __init__(self, *args, **kwargs): 200 | self._root_name = kwargs.pop("root_name") + "." 201 | self._abbrev_name = kwargs.pop("abbrev_name", "") 202 | if len(self._abbrev_name): 203 | self._abbrev_name = self._abbrev_name + "." 
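        # setup_logger() below passes root_name="AnimateDiff" and abbrev_name="AD" by default, so
        # formatMessage() shortens record names such as "AnimateDiff.train" to "AD.train" in the
        # colored console output.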
204 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 205 | 206 | def formatMessage(self, record): 207 | record.name = record.name.replace(self._root_name, self._abbrev_name) 208 | log = super(_ColorfulFormatter, self).formatMessage(record) 209 | if record.levelno == logging.WARNING: 210 | prefix = colored("WARNING", "red", attrs=["blink"]) 211 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 212 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 213 | else: 214 | return log 215 | return prefix + " " + log 216 | 217 | 218 | # cache the opened file object, so that different calls to `setup_logger` 219 | # with the same file name can safely write to the same file. 220 | @functools.lru_cache(maxsize=None) 221 | def _cached_log_stream(filename): 222 | # use 1K buffer if writing to cloud storage 223 | io = open(filename, "a", buffering=1024 if "://" in filename else -1) 224 | atexit.register(io.close) 225 | return io 226 | 227 | @functools.lru_cache() 228 | def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None): 229 | logger = logging.getLogger(name) 230 | logger.setLevel(logging.DEBUG) 231 | logger.propagate = False 232 | 233 | if abbrev_name is None: 234 | abbrev_name = 'AD' 235 | plain_formatter = logging.Formatter( 236 | "[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 237 | ) 238 | 239 | # stdout logging: master only 240 | if distributed_rank == 0: 241 | ch = logging.StreamHandler(stream=sys.stdout) 242 | ch.setLevel(logging.DEBUG) 243 | if color: 244 | formatter = _ColorfulFormatter( 245 | colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s", 246 | datefmt="%m/%d %H:%M:%S", 247 | root_name=name, 248 | abbrev_name=str(abbrev_name), 249 | ) 250 | else: 251 | formatter = plain_formatter 252 | ch.setFormatter(formatter) 253 | logger.addHandler(ch) 254 | 255 | # file logging: all workers 256 | if output is not None: 257 | if output.endswith(".txt") or output.endswith(".log"): 258 | filename = output 259 | else: 260 | filename = os.path.join(output, "log.txt") 261 | if distributed_rank > 0: 262 | filename = filename + ".rank{}".format(distributed_rank) 263 | os.makedirs(os.path.dirname(filename), exist_ok=True) 264 | 265 | fh = logging.StreamHandler(_cached_log_stream(filename)) 266 | fh.setLevel(logging.DEBUG) 267 | fh.setFormatter(plain_formatter) 268 | logger.addHandler(fh) 269 | 270 | return logger 271 | 272 | 273 | def format_time(elapsed_time): 274 | # Time thresholds 275 | minute = 60 276 | hour = 60 * minute 277 | day = 24 * hour 278 | 279 | days, remainder = divmod(elapsed_time, day) 280 | hours, remainder = divmod(remainder, hour) 281 | minutes, seconds = divmod(remainder, minute) 282 | 283 | formatted_time = "" 284 | 285 | if days > 0: 286 | formatted_time += f"{int(days)} days " 287 | if hours > 0: 288 | formatted_time += f"{int(hours)} hours " 289 | if minutes > 0: 290 | formatted_time += f"{int(minutes)} minutes " 291 | if seconds > 0: 292 | formatted_time += f"{seconds:.2f} seconds" 293 | 294 | return formatted_time.strip() 295 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CollaborativeVideoDiffusion/CVD/107f299bd75c7a37158c52427de473cec86c649a/assets/.DS_Store -------------------------------------------------------------------------------- 
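The .txt files below (and those under assets/pose_files/) appear to follow the RealEstate10K camera-trajectory format also used by CameraCtrl: the first line is the source YouTube URL, and each subsequent line holds a frame timestamp in microseconds, the normalized intrinsics fx fy cx cy, two unused zeros, and a row-major 3x4 world-to-camera extrinsic matrix. A minimal parsing sketch under that assumption (the helper name load_trajectory is illustrative, not part of this repo):

import numpy as np

def load_trajectory(path):
    # Parse a RealEstate10K-style pose file: returns the source URL and one
    # (timestamp, (fx, fy, cx, cy), 3x4 extrinsic) tuple per frame.
    with open(path) as f:
        url = f.readline().strip()
        frames = []
        for line in f:
            vals = [float(v) for v in line.split()]
            timestamp = int(vals[0])                  # frame time in microseconds
            fx, fy, cx, cy = vals[1:5]                # intrinsics, normalized to [0, 1]
            w2c = np.array(vals[7:19]).reshape(3, 4)  # world-to-camera [R|t], row-major
            frames.append((timestamp, (fx, fy, cx, cy), w2c))
    return url, frames

This is presumably the same layout consumed by the RealEstate10K dataset loader and tools/visualize_trajectory.py.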
/assets/2c80f9eb0d3b2bb4.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=sLIFyXD2ujI 2 | 77444033 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.980310440 0.030424286 -0.195104495 -0.195846403 -0.034550700 0.999244750 -0.017780757 0.034309913 0.194416180 0.024171660 0.980621278 -0.178639121 3 | 77610867 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.973806441 0.034138829 -0.224801064 -0.221452338 -0.039088678 0.999080658 -0.017603843 0.038706263 0.223993421 0.025929911 0.974245667 -0.219951444 4 | 77777700 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.965910375 0.037603889 -0.256131083 -0.242696017 -0.043735024 0.998875856 -0.018281631 0.046505467 0.255155712 0.028860316 0.966469169 -0.265310453 5 | 77944533 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.956261098 0.040829532 -0.289650917 -0.252766079 -0.048421524 0.998644531 -0.019089982 0.054620904 0.288478881 0.032280345 0.956941962 -0.321621308 6 | 78144733 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.941536248 0.043692805 -0.334066480 -0.250198162 -0.053955212 0.998311937 -0.021497937 0.069548726 0.332563221 0.038265716 0.942304313 -0.401964240 7 | 78311567 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.926658213 0.046350952 -0.373036444 -0.239336491 -0.058738846 0.998033047 -0.021904159 0.077439241 0.371287435 0.042209402 0.927558064 -0.474019461 8 | 78478400 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.909880757 0.048629351 -0.412009954 -0.218247042 -0.063676558 0.997708619 -0.022863906 0.088967126 0.409954011 0.047038805 0.910892427 -0.543114491 9 | 78645233 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.891359746 0.050869841 -0.450433195 -0.185763327 -0.067926541 0.997452736 -0.021771761 0.093745158 0.448178291 0.050002839 0.892544627 -0.611223637 10 | 78845433 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.870108902 0.054302387 -0.489858925 -0.153515269 -0.074510135 0.996981323 -0.021829695 0.107765162 0.487194777 0.055493668 0.871528387 -0.691303250 11 | 79012267 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.852772951 0.056338910 -0.519234240 -0.128052677 -0.078825951 0.996660411 -0.021319628 0.116291007 0.516299069 0.059109934 0.854366004 -0.760654136 12 | 79179100 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.835254073 0.059146367 -0.546673834 -0.101344556 -0.084243484 0.996225357 -0.020929486 0.126763936 0.543372452 0.063535146 0.837083995 -0.832841061 13 | 79345933 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.818755865 0.062443536 -0.570736051 -0.077325807 -0.089739971 0.995768547 -0.019791666 0.136091605 0.567085147 0.067422375 0.820895016 -0.908256727 14 | 79546133 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.798365474 0.066207208 -0.598522484 -0.043774887 -0.096616283 0.995144248 -0.018795265 0.150808225 0.594371796 0.072832510 0.800885499 -0.994657638 15 | 79712967 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.781648815 0.069040783 -0.619885862 -0.013285614 -0.101820730 0.994646847 -0.017611075 0.161173621 0.615351617 0.076882906 0.784494340 -1.070102980 16 | 79879800 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.765168309 0.072694756 
-0.639713168 0.019850080 -0.108554602 0.993946910 -0.016894773 0.177612448 0.634612799 0.082371153 0.768428028 -1.147576811 17 | 80080000 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.745406330 0.077463314 -0.662094295 0.062107046 -0.117075674 0.993000031 -0.015629012 0.200140798 0.656248927 0.089165099 0.749257565 -1.238600776 18 | -------------------------------------------------------------------------------- /assets/2f25826f0d0ef09a.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=t-mlAKnESzQ 2 | 167200000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991872013 -0.011311784 0.126735851 0.400533760 0.012037775 0.999915242 -0.004963919 -0.047488550 -0.126668960 0.006449190 0.991924107 -0.414499612 3 | 167467000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991945148 -0.011409644 0.126153216 0.506974565 0.012122569 0.999914587 -0.004884966 -0.069421149 -0.126086697 0.006374919 0.991998732 -0.517325825 4 | 167734000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.992271781 -0.010751382 0.123616949 0.590358341 0.011312660 0.999928653 -0.003839425 -0.085158661 -0.123566844 0.005208189 0.992322564 -0.599035085 5 | 168034000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993287027 -0.009973313 0.115245141 0.673577580 0.010455138 0.999938965 -0.003577147 -0.104263255 -0.115202427 0.004758038 0.993330657 -0.691557669 6 | 168301000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993988216 -0.009955904 0.109033749 0.753843765 0.010435819 0.999938190 -0.003831771 -0.106670354 -0.108988866 0.004946592 0.994030654 -0.805538867 7 | 168602000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994774222 -0.010583352 0.101549298 0.846176230 0.011122120 0.999926925 -0.004740742 -0.089426372 -0.101491705 0.005845411 0.994819224 -0.933449460 8 | 168869000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995415390 -0.010595482 0.095057681 0.913119395 0.011053002 0.999929726 -0.004287821 -0.072756893 -0.095005572 0.005318835 0.995462537 -1.037255409 9 | 169169000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996029556 -0.009414902 0.088523701 0.977259045 0.009879347 0.999939620 -0.004809874 -0.042104006 -0.088473074 0.005665333 0.996062458 -1.127427189 10 | 169436000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996695757 -0.008890423 0.080737323 1.025351476 0.009221899 0.999950528 -0.003733651 -0.007486727 -0.080700137 0.004465866 0.996728420 -1.188659636 11 | 169736000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997404695 -0.008404067 0.071506783 1.073562767 0.008649707 0.999957681 -0.003126226 0.054879890 -0.071477488 0.003736625 0.997435212 -1.216979926 12 | 170003000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997887254 -0.008228444 0.064446673 1.110116903 0.008409287 0.999961436 -0.002535321 0.124372514 -0.064423330 0.003071915 0.997917950 -1.231904045 13 | 170303000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998332024 -0.007790270 0.057205603 1.136173895 0.007975516 0.999963641 -0.003010646 0.212542522 -0.057180069 0.003461868 0.998357892 -1.242942079 14 | 170570000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998471320 -0.007715963 0.054730706 
1.159189486 0.007868989 0.999965727 -0.002581036 0.310163907 -0.054708913 0.003007766 0.998497844 -1.245661417 15 | 170871000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998552144 -0.007742116 0.053231847 1.173763753 0.007991423 0.999958038 -0.004472161 0.412779543 -0.053194992 0.004891084 0.998572171 -1.229165757 16 | 171137000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998553872 -0.007909958 0.053175092 1.179029258 0.008138723 0.999958515 -0.004086919 0.509089997 -0.053140558 0.004513786 0.998576820 -1.196146494 17 | 171438000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998469293 -0.008281939 0.054685175 1.181414517 0.008542870 0.999953210 -0.004539483 0.618089736 -0.054645021 0.004999703 0.998493314 -1.159911786 18 | -------------------------------------------------------------------------------- /assets/cameractrl_prompts.json: -------------------------------------------------------------------------------- 1 | {"prompts":["A serene mountain lake at sunrise, with mist hovering over the water.", 2 | "A still life of vintage objects on a wooden table.", 3 | "Ripe apples on a wooden table.", 4 | "Turtle swimming in ocean.", 5 | "A horse is eating grass on the grassland.", 6 | "Natural hot spring the steam floats up.", 7 | "A fireworks display illuminating the night sky.", 8 | "Rocky coastline with crashing waves.", 9 | "A fish is swimming in the aquarium tank.", 10 | "massive, multi-tiered elven palace adorned with flowing waterfalls, its cascades forming staircases between ethereal realms", 11 | "The sunflower swaying in the wind."], 12 | "negative_prompts":[ 13 | "worst quality, low quality, letterboxed", 14 | "worst quality, low quality, letterboxed", 15 | "worst quality, low quality, letterboxed", 16 | "worst quality, low quality, letterboxed", 17 | "worst quality, low quality, letterboxed", 18 | "worst quality, low quality, letterboxed", 19 | "worst quality, low quality, letterboxed", 20 | "worst quality, low quality, letterboxed", 21 | "worst quality, low quality, letterboxed", 22 | "worst quality, low quality, letterboxed", 23 | "worst quality, low quality, letterboxed"]} -------------------------------------------------------------------------------- /assets/cameractrl_prompts_for_circle.json: -------------------------------------------------------------------------------- 1 | {"prompts":[ 2 | "Tranquil sunset over a calm lake.", 3 | "Cozy fireplace in a rustic cabin.", 4 | "Beach party with bonfire and music.", 5 | "Dynamic urban park with people relaxing and playing.", 6 | "Peaceful countryside landscape with farm animals.", 7 | "Magical fairy tale castle in a forest clearing.", 8 | "Vibrant market scene with stalls and shoppers.", 9 | "Futuristic city skyline with flying cars.", 10 | "Serene forest with sunlight filtering through trees."]} -------------------------------------------------------------------------------- /assets/cameractrl_prompts_for_interpolate.json: -------------------------------------------------------------------------------- 1 | {"prompts":[ 2 | "A still life of vintage objects on a wooden table.", 3 | "peaceful underwater journey featuring a coral reef, teeming with colorful fish and gentle movements of water plants", 4 | "A dynamic fly-through of a futuristic city with neon lights, high-tech structures, and hover cars zipping through aerial highways", 5 | "Turtle swimming in ocean.", 6 | "Waves crashing on a rocky shore at sunset, with the ocean reflecting the 
fiery colors of the sky and the sound of water echoing.", 7 | "A horse is eating grass on the grassland.", 8 | "Natural hot spring the steam floats up.", 9 | "Rocky coastline with crashing waves.", 10 | "A fish is swimming in the aquarium tank.", 11 | "massive, multi-tiered elven palace adorned with flowing waterfalls, its cascades forming staircases between ethereal realms"]} -------------------------------------------------------------------------------- /assets/pose_files/0bf152ef84195293.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=QShWPZxTDoE 2 | 158692025 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.780003667 0.059620168 -0.622928321 0.726968666 -0.062449891 0.997897983 0.017311305 0.217967188 0.622651041 0.025398925 0.782087326 -1.002211444 3 | 158958959 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.743836701 0.064830206 -0.665209770 0.951841944 -0.068305343 0.997446954 0.020830527 0.206496789 0.664861917 0.029942872 0.746365905 -1.084913992 4 | 159225893 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.697046876 0.070604131 -0.713540971 1.208789672 -0.074218854 0.996899366 0.026138915 0.196421447 0.713174045 0.034738146 0.700125754 -1.130142078 5 | 159526193 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.635762572 0.077846259 -0.767949164 1.465161122 -0.080595709 0.996158004 0.034256749 0.157107229 0.767665446 0.040114246 0.639594078 -1.136893070 6 | 159793126 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.593250692 0.083153486 -0.800711632 1.635091834 -0.085384794 0.995539784 0.040124334 0.135863998 0.800476789 0.044564810 0.597704709 -1.166997229 7 | 160093427 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.555486798 0.087166689 -0.826943994 1.803789619 -0.089439675 0.994984210 0.044799786 0.145490422 0.826701283 0.049075913 0.560496747 -1.243827350 8 | 160360360 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.523399472 0.090266660 -0.847292721 1.945815368 -0.093254104 0.994468153 0.048340045 0.174777447 0.846969128 0.053712368 0.528921843 -1.336914479 9 | 160660661 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.491546303 0.092127070 -0.865964711 2.093852892 -0.095617607 0.994085968 0.051482171 0.196702533 0.865586221 0.057495601 0.497448236 -1.439709380 10 | 160927594 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.475284129 0.093297184 -0.874871790 2.200792438 -0.096743606 0.993874133 0.053430639 0.209217395 0.874497354 0.059243519 0.481398523 -1.547068315 11 | 161227895 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.464444131 0.093880348 -0.880612373 2.324141986 -0.097857766 0.993716478 0.054326952 0.220651207 0.880179226 0.060942926 0.470712721 -1.712512928 12 | 161494828 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.458157241 0.093640216 -0.883925021 2.443100890 -0.098046601 0.993691206 0.054448847 0.257385043 0.883447111 0.061719712 0.464447916 -1.885672329 13 | 161795128 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.457354397 0.093508720 -0.884354591 2.543246338 -0.097820736 0.993711591 0.054482624 0.281562244 0.883888066 0.061590351 0.463625461 -2.094829165 14 | 162062062 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.465170115 0.093944497 -0.880222261 
2.606377358 -0.097235762 0.993758380 0.054675922 0.277376127 0.879864752 0.060155477 0.471401453 -2.299280675 15 | 162362362 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.511845231 0.090872414 -0.854257941 2.576774100 -0.093636356 0.994366586 0.049672548 0.270516319 0.853959382 0.054564942 0.517470777 -2.624374352 16 | 162629296 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.590568483 0.083218277 -0.802685261 2.398318316 -0.085610889 0.995516419 0.040222570 0.282138215 0.802433550 0.044964414 0.595045030 -3.012309268 17 | 162929596 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.684302032 0.072693504 -0.725566208 2.086323553 -0.074529484 0.996780157 0.029575195 0.310959312 0.725379944 0.033837710 0.687516510 -3.456740526 18 | -------------------------------------------------------------------------------- /assets/pose_files/0c11dbe781b1c11c.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=a-Unpcomk5k 2 | 89889800 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.959632158 -0.051068146 0.276583046 0.339363991 0.046715312 0.998659134 0.022308502 0.111317310 -0.277351439 -0.008487292 0.960731030 -0.353512177 3 | 90156733 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.939171016 -0.057914909 0.338531673 0.380727498 0.052699961 0.998307705 0.024584483 0.134404073 -0.339382589 -0.005248427 0.940633774 -0.477942109 4 | 90423667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.913449824 -0.063028678 0.402040780 0.393354042 0.056629892 0.998008251 0.027794635 0.151535333 -0.402991891 -0.002621480 0.915199816 -0.622810637 5 | 90723967 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.879072070 -0.069992281 0.471522361 0.381271678 0.062575974 0.997545719 0.031412520 0.175549569 -0.472563744 0.001892101 0.881294429 -0.821022008 6 | 90990900 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.846152365 -0.078372896 0.527146876 0.360267421 0.071291871 0.996883452 0.033775900 0.212440374 -0.528151155 0.009001731 0.849102676 -1.013792538 7 | 91291200 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.806246638 -0.086898506 0.585162342 0.297888150 0.078344196 0.996124208 0.039983708 0.243578507 -0.586368918 0.013607344 0.809929788 -1.248063630 8 | 91558133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.771091938 -0.093814306 0.629774630 0.223948432 0.087357447 0.995320201 0.041307874 0.293608807 -0.630702674 0.023163332 0.775678813 -1.459775674 9 | 91858433 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.737968326 -0.099363215 0.667480111 0.145501271 0.093257703 0.994626462 0.044957232 0.329381977 -0.668360531 0.029070651 0.743269205 -1.688460978 10 | 92125367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.716826320 -0.101809755 0.689778805 0.086545731 0.098867603 0.994127929 0.043986596 0.379651732 -0.690206647 0.036666028 0.722682774 -1.885393814 11 | 92425667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.703360021 -0.101482928 0.703552365 0.039205180 0.098760851 0.994108558 0.044659954 0.417778776 -0.703939617 0.038071405 0.709238708 -2.106152155 12 | 92692600 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.699525177 -0.101035394 0.707429409 0.029387371 0.096523918 0.994241416 
0.046552572 0.439027166 -0.708059072 0.035719164 0.705249250 -2.314481674 13 | 92992900 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698582709 -0.101620331 0.708276451 0.018437890 0.096638583 0.994193733 0.047326516 0.478349552 -0.708973348 0.035385344 0.704347014 -2.540820022 14 | 93259833 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.704948425 -0.098988213 0.702316940 0.047566428 0.095107265 0.994462848 0.044701166 0.517456396 -0.702853024 0.035283424 0.710459530 -2.724204596 15 | 93560133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.714113414 -0.104848787 0.692133486 0.107161588 0.100486010 0.993833601 0.046875130 0.568063228 -0.692780316 0.036075566 0.720245779 -2.948379150 16 | 93827067 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717699587 -0.112323314 0.687234104 0.118765931 0.105546549 0.993049562 0.052081093 0.593900230 -0.688307464 0.035156611 0.724566638 -3.140363331 17 | 94127367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715531290 -0.122954883 0.687675118 0.089455249 0.115526602 0.991661787 0.057100743 0.643643035 -0.688961923 0.038587399 0.723769605 -3.310401931 18 | -------------------------------------------------------------------------------- /assets/pose_files/0c9b371cc6225682.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=_ca03xP_KUU 2 | 211244000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.984322786 0.006958477 -0.176239252 0.004217217 -0.005594095 0.999950409 0.008237306 -0.107944544 0.176287830 -0.007122268 0.984312892 -0.571743822 3 | 211511000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981951714 0.008860772 -0.188924149 0.000856103 -0.007234470 0.999930620 0.009296093 -0.149397579 0.188993424 -0.007761548 0.981947660 -0.776566486 4 | 211778000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981318414 0.010160952 -0.192122281 -0.005546933 -0.008323869 0.999911606 0.010366773 -0.170816348 0.192210630 -0.008573905 0.981316268 -0.981924227 5 | 212078000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981108844 0.010863926 -0.193151161 0.019480142 -0.008781361 0.999893725 0.011634931 -0.185801323 0.193257034 -0.009719004 0.981100023 -1.207220396 6 | 212345000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981263518 0.010073495 -0.192407012 0.069708411 -0.008015377 0.999902070 0.011472094 -0.203594876 0.192503735 -0.009714933 0.981248140 -1.408936391 7 | 212646000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980964184 0.009669405 -0.193947718 0.166020848 -0.007467276 0.999899149 0.012082115 -0.219176122 0.194044977 -0.010403861 0.980937481 -1.602649833 8 | 212913000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980841637 0.009196524 -0.194589555 0.262465567 -0.006609587 0.999880970 0.013939449 -0.224018296 0.194694594 -0.012386235 0.980785728 -1.740759996 9 | 213212000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980620921 0.008805701 -0.195716679 0.389752858 -0.006055873 0.999874413 0.014644019 -0.230312701 0.195821062 -0.013174997 0.980551124 -1.890949759 10 | 213479000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980327129 0.009317402 -0.197159693 0.505632551 -0.006113928 0.999839306 0.016850581 
-0.230702867 0.197285011 -0.015313662 0.980226576 -2.016199670 11 | 213779000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493963 0.009960363 -0.196296573 0.623893674 -0.006936011 0.999846518 0.016088497 -0.223079036 0.196426690 -0.014413159 0.980412602 -2.137999468 12 | 214046000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980032921 0.010318150 -0.198567480 0.754726451 -0.007264129 0.999843955 0.016102606 -0.222246314 0.198702648 -0.014338664 0.979954958 -2.230292399 13 | 214347000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.976653159 0.010179597 -0.214580998 0.946523963 -0.006709154 0.999834776 0.016895246 -0.210005171 0.214717537 -0.015061138 0.976560056 -2.305666573 14 | 214614000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.971478105 0.011535713 -0.236848563 1.096604956 -0.007706031 0.999824286 0.017088750 -0.192895049 0.237004071 -0.014776184 0.971396267 -2.365701917 15 | 214914000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.965282261 0.014877280 -0.260785013 1.237534109 -0.014124592 0.999888897 0.004760279 -0.136261458 0.260826856 -0.000911531 0.965385139 -2.458136272 16 | 215181000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.961933076 0.016891202 -0.272762626 1.331672110 -0.022902885 0.999559581 -0.018870916 -0.076291319 0.272323757 0.024399608 0.961896241 -2.579417067 17 | 215481000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959357142 0.017509742 -0.281651050 1.417338469 -0.039949402 0.996448219 -0.074127860 0.083949011 0.279352754 0.082366876 0.956649244 -2.712094466 18 | -------------------------------------------------------------------------------- /assets/pose_files/0f47577ab3441480.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=in69BD2eZqg 2 | 195562033 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999749303 -0.004518872 0.021929268 0.038810557 0.004613766 0.999980211 -0.004278630 0.328177052 -0.021909500 0.004378735 0.999750376 -0.278403591 3 | 195828967 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999336481 -0.006239665 0.035883281 0.034735125 0.006456365 0.999961615 -0.005926326 0.417233500 -0.035844926 0.006154070 0.999338388 -0.270773664 4 | 196095900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998902142 -0.007417044 0.046254709 0.033849936 0.007582225 0.999965489 -0.003396692 0.504852301 -0.046227921 0.003743677 0.998923898 -0.256677740 5 | 196396200 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998096347 -0.008753631 0.061049398 0.026475959 0.009088391 0.999945164 -0.005207890 0.583593760 -0.061000463 0.005752816 0.998121142 -0.236166024 6 | 196663133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997463286 -0.009214416 0.070583619 0.014842158 0.009590282 0.999941587 -0.004988078 0.634675512 -0.070533529 0.005652342 0.997493386 -0.198663134 7 | 196963433 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996699810 -0.009558053 0.080611102 0.003250557 0.009986609 0.999938071 -0.004914839 0.670145924 -0.080559134 0.005703651 0.996733487 -0.141256339 8 | 197230367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996102691 -0.010129508 0.087617576 -0.013035317 0.010638822 0.999929130 -0.005347892 
0.673139255 -0.087557197 0.006259197 0.996139824 -0.073934910 9 | 197530667 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995880842 -0.009925503 0.090126604 -0.036202423 0.010367444 0.999936402 -0.004436717 0.655632681 -0.090076834 0.005352824 0.995920420 0.017267095 10 | 197797600 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995802402 -0.010077500 0.090972595 -0.060858524 0.010445373 0.999939084 -0.003568561 0.618604505 -0.090931088 0.004503824 0.995846987 0.133592270 11 | 198097900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995846093 -0.010148350 0.090484887 -0.077962281 0.010412642 0.999942780 -0.002449236 0.561822755 -0.090454854 0.003381249 0.995894790 0.274195378 12 | 198364833 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995989919 -0.009936163 0.088912196 -0.082315587 0.010200773 0.999944806 -0.002522171 0.520613290 -0.088882230 0.003419030 0.996036291 0.395169547 13 | 198665133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997159958 -0.009351323 0.074730076 -0.068472873 0.009822783 0.999934077 -0.005943770 0.466061412 -0.074669570 0.006660947 0.997186065 0.549834051 14 | 198932067 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998626053 -0.008290987 0.051742285 -0.037270541 0.008407482 0.999962568 -0.002034174 0.410440195 -0.051723484 0.002466401 0.998658419 0.690111645 15 | 199232367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999980092 -0.004756952 0.004140501 -0.005957613 0.004773445 0.999980688 -0.003982662 0.354437092 -0.004121476 0.004002347 0.999983490 0.842797271 16 | 199499300 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998872638 0.001147069 -0.047456335 0.002603018 -0.001435435 0.999980688 -0.006042828 0.295339877 0.047448486 0.006104136 0.998855054 0.988644188 17 | 199799600 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.992951691 0.008710741 -0.118199304 -0.030798243 -0.009495872 0.999936402 -0.006080875 0.208803899 0.118138820 0.007160421 0.992971301 1.161643267 18 | -------------------------------------------------------------------------------- /assets/pose_files/0f68374b76390082.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=-aldZQifF2U 2 | 103736967 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.804089785 -0.073792785 0.589910388 -2.686968354 0.081914566 0.996554494 0.013005137 0.128970374 -0.588837504 0.037864953 0.807363987 -1.789505608 3 | 104003900 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.772824645 -0.077280566 0.629896700 -2.856354365 0.084460691 0.996253133 0.018602582 0.115028772 -0.628974140 0.038824979 0.776456118 -1.799931844 4 | 104270833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.740043461 -0.078656308 0.667943776 -3.017167990 0.086847030 0.995998919 0.021066183 0.116867188 -0.666928232 0.042419042 0.743913531 -1.815074499 5 | 104571133 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.696879685 -0.073477358 0.713414192 -3.221640235 0.086792909 0.996067226 0.017807571 0.133618379 -0.711916924 0.049509555 0.700516284 -1.784051774 6 | 104838067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.654997289 -0.066671766 0.752684176 -3.418233112 0.086666502 0.996154904 0.012819566 0.161623584 
-0.750644684 0.056835718 0.658256948 -1.733288907 7 | 105138367 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.603833497 -0.059696361 0.794871926 -3.619566170 0.087576874 0.996123314 0.008281946 0.184519895 -0.792284906 0.064611480 0.606720686 -1.643568460 8 | 105405300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.555575073 -0.055402864 0.829618514 -3.768244320 0.089813948 0.995938241 0.006363695 0.197587954 -0.826601386 0.070975810 0.558294415 -1.559717271 9 | 105705600 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.501615226 -0.052979972 0.863467038 -3.914896511 0.093892507 0.995560884 0.006539768 0.201989601 -0.859980464 0.077792637 0.504362881 -1.476983336 10 | 105972533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.454045177 -0.052372806 0.889438093 -4.034987790 0.099656843 0.994991958 0.007714771 0.211683202 -0.885387778 0.085135736 0.456990600 -1.405070279 11 | 106272833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.397668689 -0.051785514 0.916066527 -4.178181130 0.105599925 0.994354606 0.010369749 0.208751884 -0.911431968 0.092612833 0.400892258 -1.295093582 12 | 106539767 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.345666498 -0.052948993 0.936862350 -4.285116664 0.110631727 0.993743002 0.015344846 0.195070069 -0.931812882 0.098342501 0.349361509 -1.182773054 13 | 106840067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.284817457 -0.055293880 0.956985712 -4.392320606 0.115495987 0.993041575 0.023003323 0.168523273 -0.951598525 0.103976257 0.289221793 -1.053514096 14 | 107107000 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.228878200 -0.056077410 0.971838534 -4.485196000 0.120451130 0.992298782 0.028890507 0.159180748 -0.965974271 0.110446639 0.233870149 -0.923927626 15 | 107407300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.162932962 -0.053445265 0.985188544 -4.601126217 0.124115810 0.991709769 0.033272449 0.152041098 -0.978799343 0.116856292 0.168215603 -0.758111250 16 | 107674233 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.102818660 -0.051196381 0.993381739 -4.691710857 0.127722457 0.991087139 0.037858382 0.141352300 -0.986466050 0.122984610 0.108441174 -0.599244073 17 | 107974533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.034108389 -0.050325166 0.998150289 -4.758242879 0.132215530 0.990180492 0.045405328 0.118994547 -0.990633965 0.130422264 0.040427230 -0.433560831 18 | -------------------------------------------------------------------------------- /assets/pose_files/2c80f9eb0d3b2bb4.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=sLIFyXD2ujI 2 | 77444033 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.980310440 0.030424286 -0.195104495 -0.195846403 -0.034550700 0.999244750 -0.017780757 0.034309913 0.194416180 0.024171660 0.980621278 -0.178639121 3 | 77610867 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.973806441 0.034138829 -0.224801064 -0.221452338 -0.039088678 0.999080658 -0.017603843 0.038706263 0.223993421 0.025929911 0.974245667 -0.219951444 4 | 77777700 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.965910375 0.037603889 -0.256131083 -0.242696017 -0.043735024 0.998875856 -0.018281631 0.046505467 0.255155712 
0.028860316 0.966469169 -0.265310453 5 | 77944533 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.956261098 0.040829532 -0.289650917 -0.252766079 -0.048421524 0.998644531 -0.019089982 0.054620904 0.288478881 0.032280345 0.956941962 -0.321621308 6 | 78144733 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.941536248 0.043692805 -0.334066480 -0.250198162 -0.053955212 0.998311937 -0.021497937 0.069548726 0.332563221 0.038265716 0.942304313 -0.401964240 7 | 78311567 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.926658213 0.046350952 -0.373036444 -0.239336491 -0.058738846 0.998033047 -0.021904159 0.077439241 0.371287435 0.042209402 0.927558064 -0.474019461 8 | 78478400 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.909880757 0.048629351 -0.412009954 -0.218247042 -0.063676558 0.997708619 -0.022863906 0.088967126 0.409954011 0.047038805 0.910892427 -0.543114491 9 | 78645233 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.891359746 0.050869841 -0.450433195 -0.185763327 -0.067926541 0.997452736 -0.021771761 0.093745158 0.448178291 0.050002839 0.892544627 -0.611223637 10 | 78845433 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.870108902 0.054302387 -0.489858925 -0.153515269 -0.074510135 0.996981323 -0.021829695 0.107765162 0.487194777 0.055493668 0.871528387 -0.691303250 11 | 79012267 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.852772951 0.056338910 -0.519234240 -0.128052677 -0.078825951 0.996660411 -0.021319628 0.116291007 0.516299069 0.059109934 0.854366004 -0.760654136 12 | 79179100 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.835254073 0.059146367 -0.546673834 -0.101344556 -0.084243484 0.996225357 -0.020929486 0.126763936 0.543372452 0.063535146 0.837083995 -0.832841061 13 | 79345933 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.818755865 0.062443536 -0.570736051 -0.077325807 -0.089739971 0.995768547 -0.019791666 0.136091605 0.567085147 0.067422375 0.820895016 -0.908256727 14 | 79546133 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.798365474 0.066207208 -0.598522484 -0.043774887 -0.096616283 0.995144248 -0.018795265 0.150808225 0.594371796 0.072832510 0.800885499 -0.994657638 15 | 79712967 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.781648815 0.069040783 -0.619885862 -0.013285614 -0.101820730 0.994646847 -0.017611075 0.161173621 0.615351617 0.076882906 0.784494340 -1.070102980 16 | 79879800 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.765168309 0.072694756 -0.639713168 0.019850080 -0.108554602 0.993946910 -0.016894773 0.177612448 0.634612799 0.082371153 0.768428028 -1.147576811 17 | 80080000 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.745406330 0.077463314 -0.662094295 0.062107046 -0.117075674 0.993000031 -0.015629012 0.200140798 0.656248927 0.089165099 0.749257565 -1.238600776 18 | -------------------------------------------------------------------------------- /assets/pose_files/2f25826f0d0ef09a.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=t-mlAKnESzQ 2 | 167200000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991872013 -0.011311784 0.126735851 0.400533760 0.012037775 0.999915242 -0.004963919 -0.047488550 -0.126668960 0.006449190 
0.991924107 -0.414499612 3 | 167467000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991945148 -0.011409644 0.126153216 0.506974565 0.012122569 0.999914587 -0.004884966 -0.069421149 -0.126086697 0.006374919 0.991998732 -0.517325825 4 | 167734000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.992271781 -0.010751382 0.123616949 0.590358341 0.011312660 0.999928653 -0.003839425 -0.085158661 -0.123566844 0.005208189 0.992322564 -0.599035085 5 | 168034000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993287027 -0.009973313 0.115245141 0.673577580 0.010455138 0.999938965 -0.003577147 -0.104263255 -0.115202427 0.004758038 0.993330657 -0.691557669 6 | 168301000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993988216 -0.009955904 0.109033749 0.753843765 0.010435819 0.999938190 -0.003831771 -0.106670354 -0.108988866 0.004946592 0.994030654 -0.805538867 7 | 168602000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994774222 -0.010583352 0.101549298 0.846176230 0.011122120 0.999926925 -0.004740742 -0.089426372 -0.101491705 0.005845411 0.994819224 -0.933449460 8 | 168869000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995415390 -0.010595482 0.095057681 0.913119395 0.011053002 0.999929726 -0.004287821 -0.072756893 -0.095005572 0.005318835 0.995462537 -1.037255409 9 | 169169000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996029556 -0.009414902 0.088523701 0.977259045 0.009879347 0.999939620 -0.004809874 -0.042104006 -0.088473074 0.005665333 0.996062458 -1.127427189 10 | 169436000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996695757 -0.008890423 0.080737323 1.025351476 0.009221899 0.999950528 -0.003733651 -0.007486727 -0.080700137 0.004465866 0.996728420 -1.188659636 11 | 169736000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997404695 -0.008404067 0.071506783 1.073562767 0.008649707 0.999957681 -0.003126226 0.054879890 -0.071477488 0.003736625 0.997435212 -1.216979926 12 | 170003000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997887254 -0.008228444 0.064446673 1.110116903 0.008409287 0.999961436 -0.002535321 0.124372514 -0.064423330 0.003071915 0.997917950 -1.231904045 13 | 170303000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998332024 -0.007790270 0.057205603 1.136173895 0.007975516 0.999963641 -0.003010646 0.212542522 -0.057180069 0.003461868 0.998357892 -1.242942079 14 | 170570000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998471320 -0.007715963 0.054730706 1.159189486 0.007868989 0.999965727 -0.002581036 0.310163907 -0.054708913 0.003007766 0.998497844 -1.245661417 15 | 170871000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998552144 -0.007742116 0.053231847 1.173763753 0.007991423 0.999958038 -0.004472161 0.412779543 -0.053194992 0.004891084 0.998572171 -1.229165757 16 | 171137000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998553872 -0.007909958 0.053175092 1.179029258 0.008138723 0.999958515 -0.004086919 0.509089997 -0.053140558 0.004513786 0.998576820 -1.196146494 17 | 171438000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998469293 -0.008281939 0.054685175 1.181414517 0.008542870 0.999953210 -0.004539483 0.618089736 -0.054645021 0.004999703 0.998493314 -1.159911786 18 
| -------------------------------------------------------------------------------- /assets/pose_files/3c35b868a8ec3433.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=bJyPo9pESu0 2 | 189622767 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.966956913 0.041186374 -0.251590967 0.235831829 -0.037132759 0.999092996 0.020840336 0.069818943 0.252221137 -0.010809440 0.967609227 -0.850289525 3 | 189789600 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967445135 0.041703269 -0.249621317 0.217678822 -0.037349533 0.999056637 0.022154763 0.078295447 0.250309765 -0.012110277 0.968090057 -0.818677483 4 | 189956433 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967587769 0.043305319 -0.248794302 0.196350216 -0.038503598 0.998966932 0.024136283 0.085749990 0.249582499 -0.013774496 0.968255579 -0.778043636 5 | 190123267 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967742383 0.044170257 -0.248039767 0.170234078 -0.039154600 0.998917341 0.025120445 0.090556068 0.248880804 -0.014598221 0.968424082 -0.733500964 6 | 190323467 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.973553717 0.043272153 -0.224322766 0.120337922 -0.038196862 0.998907626 0.026917407 0.091227451 0.225242496 -0.017637115 0.974143088 -0.680520640 7 | 190490300 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.984184802 0.039637893 -0.172653258 0.065019106 -0.035194401 0.998967648 0.028723357 0.090969669 0.173613548 -0.022192664 0.984563768 -0.638603728 8 | 190657133 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.993078411 0.035358477 -0.112004772 0.011571313 -0.032207530 0.999036312 0.029818388 0.092482656 0.112951167 -0.026004599 0.993260205 -0.588118143 9 | 190823967 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.997807920 0.031166473 -0.058378015 -0.027908508 -0.029339414 0.999060452 0.031897116 0.092538838 0.059317287 -0.030114418 0.997784853 -0.529325066 10 | 191024167 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.999654651 0.026247790 -0.001263706 -0.064570799 -0.026190240 0.999087334 0.033742432 0.091922841 0.002148218 -0.033697683 0.999429762 -0.448626929 11 | 191191000 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.998773992 0.022079065 0.044305529 -0.084478169 -0.023622099 0.999121666 0.034611158 0.087434649 -0.043502431 -0.035615314 0.998418272 -0.371306296 12 | 191357833 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.995725632 0.017435640 0.090699598 -0.094868572 -0.020876031 0.999092638 0.037122324 0.082208324 -0.089970052 -0.038857099 0.995186150 -0.290596011 13 | 191524667 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.989503622 0.013347236 0.143890470 -0.096537122 -0.019140780 0.999057651 0.038954727 0.079283141 -0.143234938 -0.041300017 0.988826632 -0.207477308 14 | 191724867 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.975660741 0.006981443 0.219174415 -0.085240259 -0.016479453 0.999001026 0.041537181 0.072219148 -0.218665481 -0.044138070 0.974801123 -0.112100139 15 | 191891700 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.955792487 -0.000511726 0.294041574 -0.064476318 -0.012924311 0.998958945 0.043749433 0.061688334 -0.293757826 -0.045615666 0.954790831 -0.034724173 16 | 
192058533 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.925362229 -0.009678445 0.378960580 -0.029417786 -0.008889219 0.998845160 0.047216032 0.058476640 -0.378979892 -0.047060598 0.924207509 0.042010383 17 | 192258733 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.872846186 -0.021581186 0.487517983 0.038433307 -0.004890797 0.998584569 0.052961230 0.057516307 -0.487970918 -0.048611358 0.871505201 0.124675285 18 | -------------------------------------------------------------------------------- /assets/pose_files/3f79dc32d575bcdc.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=1qVpRlWxam4 2 | 86319567 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999183893 0.038032386 -0.013605987 -0.249154748 -0.038085770 0.999267697 -0.003686040 0.047875167 0.013455833 0.004201226 0.999900639 -0.566803149 3 | 86586500 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999392629 0.034676589 -0.003445767 -0.282371175 -0.034685481 0.999395013 -0.002555777 0.057086778 0.003355056 0.002673743 0.999990821 -0.624021456 4 | 86853433 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999498725 0.028301919 0.014187563 -0.320995587 -0.028301118 0.999599397 -0.000257314 0.061367205 -0.014189162 -0.000144339 0.999899328 -0.706664680 5 | 87153733 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999064565 0.022049030 0.037200645 -0.371910835 -0.022201553 0.999746680 0.003691827 0.063911726 -0.037109818 -0.004514286 0.999301016 -0.799748814 6 | 87420667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998171926 0.018552339 0.057520505 -0.440220060 -0.018887693 0.999807596 0.005291941 0.070160264 -0.057411261 -0.006368696 0.998330295 -0.853433007 7 | 87720967 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997675776 0.016262729 0.066170901 -0.486385324 -0.016915560 0.999813497 0.009317505 0.069230577 -0.066007033 -0.010415167 0.997764826 -0.912234761 8 | 87987900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998019218 0.015867118 0.060876362 -0.497549423 -0.016505934 0.999813735 0.010005167 0.076295227 -0.060706269 -0.010990170 0.998095155 -0.980435972 9 | 88288200 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999152124 0.018131699 0.036962789 -0.468507446 -0.018461898 0.999792457 0.008611582 0.087696066 -0.036798976 -0.009286684 0.999279559 -1.074633197 10 | 88555133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999717414 0.022977378 0.006097841 -0.420528982 -0.023013741 0.999717355 0.005961678 0.101216630 -0.005959134 -0.006100327 0.999963641 -1.169004730 11 | 88855433 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999106526 0.030726369 -0.029017152 -0.374249594 -0.030677194 0.999527037 0.002138488 0.120936030 0.029069137 -0.001246413 0.999576628 -1.251082317 12 | 89122367 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997359693 0.039784521 -0.060752310 -0.335843098 -0.039773725 0.999207735 0.001387495 0.132824955 0.060759377 0.001032514 0.998151898 -1.312258423 13 | 89422667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.992973983 0.050480653 -0.107025139 -0.253623964 -0.050627887 0.998716712 0.001342622 0.144421611 0.106955573 0.004085269 0.994255424 -1.394020432 14 | 89689600 
0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.986886561 0.059628733 -0.149997801 -0.173418608 -0.059660275 0.998209476 0.004293700 0.142984494 0.149985254 0.004711515 0.988677025 -1.462588413 15 | 89989900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.978200734 0.067550205 -0.196367815 -0.089199207 -0.067402542 0.997698128 0.007442682 0.141665403 0.196418539 0.005955252 0.980502069 -1.524381413 16 | 90256833 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.967793405 0.073765829 -0.240695804 0.013635864 -0.073441446 0.997246027 0.010330606 0.134276795 0.240794986 0.007679154 0.970545650 -1.588498428 17 | 90557133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.953711152 0.081056722 -0.289594263 0.148156165 -0.081249826 0.996628821 0.011376631 0.129987979 0.289540142 0.012679463 0.957081914 -1.633951355 18 | -------------------------------------------------------------------------------- /assets/pose_files/4a2d6753676df096.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/watch?v=mGFQkgadzRQ 2 | 123665000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.996869564 0.002875770 -0.079011612 -0.427841466 -0.002861131 0.999995887 0.000298484 -0.005788880 0.079012141 -0.000071487 0.996873677 0.132732609 3 | 123999000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.993462563 0.003229393 -0.114112593 -0.472377562 -0.003208589 0.999994814 0.000365978 -0.005932507 0.114113182 0.000002555 0.993467748 0.123959606 4 | 124332000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.988605380 0.003602870 -0.150487319 -0.517270184 -0.003599323 0.999993503 0.000295953 -0.005751638 0.150487408 0.000249071 0.988611877 0.113156366 5 | 124708000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.981692851 0.004048047 -0.190427750 -0.566330350 -0.004096349 0.999991596 0.000139980 -0.007622665 0.190426722 0.000642641 0.981701195 0.098572887 6 | 125041000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.974759340 0.004326052 -0.223216295 -0.606091424 -0.004403458 0.999990284 0.000150970 -0.009427620 0.223214790 0.000835764 0.974768937 0.084984909 7 | 125417000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.965238512 0.004419941 -0.261333257 -0.651601078 -0.004571608 0.999989569 0.000027561 -0.007437027 0.261330664 0.001168111 0.965248644 0.068577736 8 | 125750000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.953956902 0.004390486 -0.299911648 -0.697081969 -0.004806366 0.999988258 -0.000648964 -0.003676960 0.299905270 0.002060569 0.953966737 0.050264043 9 | 126126000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.940579295 0.004839818 -0.339539677 -0.744385684 -0.005527717 0.999984145 -0.001058831 -0.001820489 0.339529186 0.002872794 0.940591156 0.028560147 10 | 126459000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.928297341 0.004980532 -0.371805429 -0.781716025 -0.005848793 0.999982178 -0.001207554 -0.001832299 0.371792793 0.003295582 0.928309917 0.009470658 11 | 126835000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.913324535 0.005156573 -0.407199889 -0.824074795 -0.006227055 0.999979734 -0.001303667 -0.001894351 0.407184929 0.003726327 0.913338125 -0.013179829 12 | 127168000 0.591609280 
1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.898822486 0.005400294 -0.438279599 -0.860775204 -0.006702366 0.999976516 -0.001423908 -0.001209170 0.438261628 0.004217350 0.898837566 -0.034594674 13 | 127544000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.880397439 0.005455900 -0.474205226 -0.903308447 -0.007032821 0.999974072 -0.001551900 -0.000798134 0.474184483 0.004701289 0.880412936 -0.061250069 14 | 127877000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.862660766 0.005402398 -0.505754173 -0.939888304 -0.007276668 0.999972045 -0.001730187 -0.000489221 0.505730629 0.005172769 0.862675905 -0.086411685 15 | 128253000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.841714203 0.005442667 -0.539895892 -0.978630821 -0.007698633 0.999968529 -0.001921765 0.000975953 0.539868414 0.005774037 0.841729641 -0.115983579 16 | 128587000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.823229551 0.005282366 -0.567684054 -1.010071242 -0.007977336 0.999965608 -0.002263572 0.002284809 0.567652583 0.006392045 0.823243380 -0.141444392 17 | 128962000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.802855015 0.005112482 -0.596152425 -1.042319682 -0.008217614 0.999963105 -0.002491409 0.003637235 0.596117735 0.006899191 0.802867413 -0.169369454 18 | -------------------------------------------------------------------------------- /configs/inference_config.yaml: -------------------------------------------------------------------------------- 1 | validation_data: 2 | relative_pose: true 3 | 4 | unet_additional_kwargs: 5 | unet_use_cross_frame_attention : false 6 | unet_use_temporal_attention : false 7 | 8 | use_motion_module : true 9 | motion_module_resolutions : [ 1,2,4,8 ] 10 | motion_module_mid_block: false 11 | motion_module_type: Vanilla 12 | motion_module_kwargs: 13 | num_attention_heads : 8 14 | num_transformer_block : 1 15 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 16 | temporal_position_encoding : true 17 | temporal_position_encoding_max_len : 32 18 | temporal_attention_dim_div : 1 19 | zero_initialize : false 20 | 21 | use_epi_module : true 22 | epi_module_resolutions : [ 1,2,4,8 ] 23 | epi_module_mid_block: false 24 | epi_module_kwargs: 25 | num_attention_heads : 8 26 | num_transformer_block : 1 27 | attention_block_types : [ "Epi_Self", "Epi_Self" ] 28 | epi_position_encoding : true 29 | epi_position_encoding_feat_max_size: 64 30 | epi_position_encoding_F_mat_size : 256 31 | epi_rand_slope_ff : true 32 | zero_initialize : true 33 | 34 | pose_encoder_kwargs: 35 | downscale_factor: 8 36 | channels: [320, 640, 1280, 1280] 37 | nums_rb: 2 38 | cin: 384 39 | ksize: 1 40 | sk: true 41 | use_conv: false 42 | compression_factor: 1 43 | temporal_attention_nhead: 8 44 | attention_block_types: ["Temporal_Self", ] 45 | temporal_position_encoding: true 46 | temporal_position_encoding_max_len: 16 47 | 48 | attention_processor_kwargs: 49 | add_spatial: false 50 | spatial_attn_names: 'attn1' 51 | add_temporal: true 52 | temporal_attn_names: '0' 53 | pose_feature_dimensions: [320, 640, 1280, 1280] 54 | query_condition: true 55 | key_value_condition: true 56 | scale: 1.0 57 | 58 | noise_scheduler_kwargs: 59 | num_train_timesteps: 1000 60 | beta_start: 0.00085 61 | beta_end: 0.012 62 | beta_schedule: "linear" 63 | steps_offset: 1 64 | clip_sample: false -------------------------------------------------------------------------------- 
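Note: the inference_config.yaml above is consumed by inference_epi.py (further below) through OmegaConf. The snippet that follows is a minimal sketch of that loading step, mirroring main() in the script; the hard-coded config path and image_height value are illustrative assumptions only (in the script they come from the --model_config and --image_height arguments).

from omegaconf import OmegaConf

# Sketch of how main() in inference_epi.py reads this config.
# The path and image_height below are illustrative assumptions.
model_configs = OmegaConf.load("configs/inference_config.yaml")

unet_additional_kwargs = model_configs.get("unet_additional_kwargs", None)
noise_scheduler_kwargs = model_configs["noise_scheduler_kwargs"]
pose_encoder_kwargs = model_configs["pose_encoder_kwargs"]
attention_processor_kwargs = model_configs["attention_processor_kwargs"]
validation_configs = model_configs.get("validation_data", None)

# The script sizes the epipolar positional encoding to the output height and
# scales the pose-conditioned attention (from --pose_adaptor_scale) before
# building the pipeline with get_pipeline(...).
image_height = 256
unet_additional_kwargs["epi_module_kwargs"]["epi_position_encoding_F_mat_size"] = image_height
attention_processor_kwargs["scale"] = 1.0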
/configs/validation_prompts.txt: -------------------------------------------------------------------------------- 1 | "A joker dancing, 4K, movie quality" 2 | "Robot dancing in times square." 3 | "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." 4 | "A drone view of celebration with Christmas tree and fireworks, starry sky - background." 5 | "Pacific coast, carmel by the sea ocean and waves." -------------------------------------------------------------------------------- /dist_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | PT_SCRIPT=$3 8 | RANDOM_PORT=$((49152 + RANDOM % 16384)) 9 | 10 | python -m torch.distributed.launch \ 11 | --nproc_per_node=$GPUS \ 12 | --master_port=$RANDOM_PORT \ 13 | ${PT_SCRIPT} \ 14 | --config=${CONFIG} \ 15 | --launcher=pytorch \ 16 | --port=${RANDOM_PORT} -------------------------------------------------------------------------------- /docs/badge-website.svg: -------------------------------------------------------------------------------- (SVG markup not preserved in this dump; the file is a "Project | Website" badge linking to the project website.) -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CollaborativeVideoDiffusion/CVD/107f299bd75c7a37158c52427de473cec86c649a/docs/teaser.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: CVD 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.10 7 | - pip -------------------------------------------------------------------------------- /inference_epi.py: -------------------------------------------------------------------------------- 1 | # make sure you're logged in with `huggingface-cli login` 2 | import argparse 3 | import json 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | from packaging import version as pver 10 | from einops import rearrange 11 | from safetensors import safe_open 12 | 13 | from omegaconf import OmegaConf 14 | from diffusers import ( 15 | AutoencoderKL, 16 | DDIMScheduler 17 | ) 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | 20 | from animatediff.utils.util import save_videos_grid 21 | from animatediff.models.unet import UNet3DConditionModelPoseCond 22 | from animatediff.models.pose_adaptor import CameraPoseEncoder 23 | from animatediff.pipelines.pipeline_animation_epi import AnimationPipelineEpiControl 24 | from animatediff.data.dataset_validation import ValRealEstate10KPoseFolded 25 | from animatediff.data.dataset_validation import Camera 26 | 27 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_vae_checkpoint, \ 28 | convert_ldm_clip_checkpoint 29 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint 30 | import imageio 31 | 32 | from tools.visualize_trajectory import CameraPoseVisualizer 33 | 34 | 35 | def setup_for_distributed(is_master): 36 | """ 37 | This function disables printing when not in master process 38 | """ 39 | import builtins as __builtin__ 40 | builtin_print = __builtin__.print 41 | 42 | def print(*args, **kwargs): 43
| force = kwargs.pop('force', False) 44 | if is_master or force: 45 | builtin_print(*args, **kwargs) 46 | 47 | __builtin__.print = print 48 | 49 | def load_civitai_base_model(pipeline, civitai_base_model): 50 | print(f'Load civitai base model from {civitai_base_model}') 51 | if civitai_base_model.endswith(".safetensors"): 52 | dreambooth_state_dict = {} 53 | with safe_open(civitai_base_model, framework="pt", device="cpu") as f: 54 | for key in f.keys(): 55 | dreambooth_state_dict[key] = f.get_tensor(key) 56 | elif civitai_base_model.endswith(".ckpt"): 57 | dreambooth_state_dict = torch.load(civitai_base_model, map_location="cpu") 58 | 59 | # 1. vae 60 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) 61 | pipeline.vae.load_state_dict(converted_vae_checkpoint) 62 | # 2. unet 63 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) 64 | _, unetu = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 65 | assert len(unetu) == 0 66 | # 3. text_model 67 | pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, text_encoder=pipeline.text_encoder) 68 | del dreambooth_state_dict 69 | return pipeline 70 | 71 | 72 | def get_pipeline(ori_model_path, unet_subfolder, image_lora_rank, image_lora_ckpt, unet_additional_kwargs, 73 | unet_mm_ckpt, unet_epi_ckpt, pose_encoder_kwargs, attention_processor_kwargs, 74 | noise_scheduler_kwargs, pose_adaptor_ckpt, civitai_lora_ckpt, civitai_base_model, gpu_id, 75 | spatial_extended_attention=False): 76 | vae = AutoencoderKL.from_pretrained(ori_model_path, subfolder="vae") 77 | tokenizer = CLIPTokenizer.from_pretrained(ori_model_path, subfolder="tokenizer") 78 | text_encoder = CLIPTextModel.from_pretrained(ori_model_path, subfolder="text_encoder") 79 | unet = UNet3DConditionModelPoseCond.from_pretrained_2d(ori_model_path, subfolder=unet_subfolder, 80 | unet_additional_kwargs=unet_additional_kwargs) 81 | pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs) 82 | print(f"Setting the attention processors") 83 | unet.set_all_attn_processor(add_spatial_lora=image_lora_ckpt is not None, 84 | add_motion_lora=False, 85 | lora_kwargs={"lora_rank": image_lora_rank, "lora_scale": 1.0}, 86 | motion_lora_kwargs={"lora_rank": -1, "lora_scale": 1.0}, 87 | sync_lora_kwargs={"sync_lora_rank": 0, "sync_lora_scale": 0}, 88 | spatial_extended_attention=spatial_extended_attention, 89 | **attention_processor_kwargs) 90 | 91 | if image_lora_ckpt is not None: 92 | print(f"Loading the lora checkpoint from {image_lora_ckpt}") 93 | lora_checkpoints = torch.load(image_lora_ckpt, map_location=unet.device) 94 | if 'lora_state_dict' in lora_checkpoints.keys(): 95 | lora_checkpoints = lora_checkpoints['lora_state_dict'] 96 | _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False) 97 | assert len(lora_u) == 0 98 | print(f'Loading done') 99 | 100 | if unet_mm_ckpt is not None: 101 | print(f"Loading the motion module checkpoint from {unet_mm_ckpt}") 102 | mm_checkpoints = torch.load(unet_mm_ckpt, map_location=unet.device) 103 | _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False) 104 | assert len(mm_u) == 0 105 | print("Loading done") 106 | 107 | if unet_epi_ckpt is not None: 108 | print(f"Loading the epi module checkpoint from {unet_epi_ckpt}") 109 | ckpt = torch.load(unet_epi_ckpt, map_location=unet.device) 110 | unet_trainable_dict = ckpt['unet_trainable_dict'] 111 | _, epi_u = unet.load_state_dict(unet_trainable_dict, strict=False) 112 | 
assert len(epi_u) == 0 113 | print("Loading done") 114 | 115 | print(f"Loading pose adaptor") 116 | pose_adaptor_checkpoint = torch.load(pose_adaptor_ckpt, map_location='cpu') 117 | pose_encoder_state_dict = pose_adaptor_checkpoint['pose_encoder_state_dict'] 118 | pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict) 119 | assert len(pose_encoder_u) == 0 and len(pose_encoder_m) == 0 120 | attention_processor_state_dict = pose_adaptor_checkpoint['attention_processor_state_dict'] 121 | _, attn_proc_u = unet.load_state_dict(attention_processor_state_dict, strict=False) 122 | assert len(attn_proc_u) == 0 123 | print(f"Loading done") 124 | 125 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 126 | vae.to(gpu_id) 127 | text_encoder.to(gpu_id) 128 | unet.to(gpu_id) 129 | pose_encoder.to(gpu_id) 130 | pipe = AnimationPipelineEpiControl( 131 | vae=vae, 132 | text_encoder=text_encoder, 133 | tokenizer=tokenizer, 134 | unet=unet, 135 | scheduler=noise_scheduler, 136 | pose_encoder=pose_encoder) 137 | assert not (civitai_base_model and civitai_lora_ckpt) 138 | if civitai_lora_ckpt is not None: 139 | pipe.load_lora_weights(civitai_lora_ckpt) 140 | if civitai_base_model is not None: 141 | load_civitai_base_model(pipeline=pipe, civitai_base_model=civitai_base_model) 142 | pipe.enable_vae_slicing() 143 | pipe = pipe.to(gpu_id) 144 | 145 | return pipe 146 | 147 | 148 | def load_pair_poses(pose_file_0, pose_file_1): 149 | pose_file_0 = os.path.join(pose_file_0) 150 | with open(pose_file_0, 'r') as f: 151 | poses_0 = f.readlines() 152 | pose_file_1 = os.path.join(pose_file_1) 153 | with open(pose_file_1, 'r') as f: 154 | poses_1 = f.readlines() 155 | poses_0 = [pose.strip().split(' ') for pose in poses_0[1:]]  # skip the first line (video URL) 156 | cam_params_0 = [[float(x) for x in pose] for pose in poses_0] 157 | cam_params_0 = [Camera(cam_param) for cam_param in cam_params_0] 158 | poses_1 = [pose.strip().split(' ') for pose in poses_1[1:]] 159 | cam_params_1 = [[float(x) for x in pose] for pose in poses_1] 160 | cam_params_1 = [Camera(cam_param) for cam_param in cam_params_1] 161 | return cam_params_0, cam_params_1 162 | 163 | 164 | def main(args): 165 | os.makedirs(args.out_root, exist_ok=True) 166 | rank = args.local_rank 167 | setup_for_distributed(rank == 0) 168 | gpu_id = rank % torch.cuda.device_count() 169 | model_configs = OmegaConf.load(args.model_config) 170 | unet_additional_kwargs = model_configs[ 171 | 'unet_additional_kwargs'] if 'unet_additional_kwargs' in model_configs else None 172 | noise_scheduler_kwargs = model_configs['noise_scheduler_kwargs'] 173 | pose_encoder_kwargs = model_configs['pose_encoder_kwargs'] 174 | attention_processor_kwargs = model_configs['attention_processor_kwargs'] 175 | validation_configs = model_configs[ 176 | 'validation_data'] if 'validation_data' in model_configs else None 177 | unet_additional_kwargs['epi_module_kwargs']['epi_position_encoding_F_mat_size'] = args.image_height 178 | 179 | # overridden by command-line arguments 180 | attention_processor_kwargs["scale"] = args.pose_adaptor_scale 181 | 182 | print(f'Constructing pipeline') 183 | pipeline = get_pipeline(args.ori_model_path, args.unet_subfolder, args.image_lora_rank, args.image_lora_ckpt, 184 | unet_additional_kwargs, args.motion_module_ckpt, args.epi_module_ckpt, 185 | pose_encoder_kwargs, attention_processor_kwargs, 186 | noise_scheduler_kwargs, args.pose_adaptor_ckpt, args.civitai_lora_ckpt, 187 | args.civitai_base_model, f"cuda:{gpu_id}", 188 |
spatial_extended_attention=args.spatial_extended_attention) 189 | device = torch.device(f"cuda:{gpu_id}") 190 | print('Done') 191 | 192 | print(f'Loading Validation Dataset') 193 | 194 | # with open(args.validation_prompts_file, "r") as f: 195 | # validation_prompts = [x.replace("\n", "") for x in f.readlines()] 196 | if args.caption_file.endswith('.json'): 197 | json_file = json.load(open(args.caption_file, 'r')) 198 | captions = json_file['captions'] if 'captions' in json_file else json_file['prompts'] 199 | if args.use_negative_prompt: 200 | negative_prompts = json_file['negative_prompts'] 201 | else: 202 | negative_prompts = None 203 | if isinstance(captions[0], dict): 204 | captions = [cap['caption'] for cap in captions] 205 | if args.use_specific_seeds: 206 | specific_seeds = json_file['seeds'] 207 | else: 208 | specific_seeds = None 209 | elif args.caption_file.endswith('.txt'): 210 | with open(args.caption_file, 'r') as f: 211 | captions = f.readlines() 212 | captions = [cap.strip() for cap in captions] 213 | negative_prompts = None 214 | specific_seeds = None 215 | 216 | if args.num_videos is not None: 217 | captions = captions * args.num_videos 218 | negative_prompts = negative_prompts * args.num_videos 219 | validation_configs["validation_prompts"] = captions 220 | validation_configs["validation_negative_prompts"] = negative_prompts 221 | validation_configs["sample_size"] = args.image_height 222 | if args.pose_file_0 is not None and args.pose_file_1 is not None: 223 | validation_configs["pose_file_0"] = args.pose_file_0 224 | validation_configs["pose_file_1"] = args.pose_file_1 225 | validation_dataset = ValRealEstate10KPoseFolded(**validation_configs) 226 | validation_dataloader = torch.utils.data.DataLoader( 227 | validation_dataset, 228 | batch_size=1, 229 | shuffle=False, 230 | num_workers=0, 231 | pin_memory=True, 232 | drop_last=False 233 | ) 234 | print(f'Done') 235 | 236 | generator = torch.Generator(device=device) 237 | generator.manual_seed(args.global_seed) 238 | 239 | validation_data_iter = iter(validation_dataloader) 240 | sample_all = [] 241 | 242 | if args.no_lora_validation: 243 | pipeline.unet.set_image_layer_lora_scale(0.0) 244 | 245 | for idx, validation_batch in enumerate(validation_data_iter): 246 | if specific_seeds is not None: 247 | specific_seed = specific_seeds[idx] 248 | generator.manual_seed(specific_seed) 249 | 250 | F_mats = validation_batch['F_mats'].to(device=pipeline.unet.device) 251 | F_mats = torch.cat(F_mats.chunk(2, dim=1), dim=0) 252 | plucker_embedding = validation_batch['plucker_embedding'].to(device=pipeline.unet.device) 253 | plucker_embedding = torch.cat(plucker_embedding.chunk(2, dim=1), dim=0) 254 | plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b c f h w") 255 | 256 | print(validation_batch['validation_prompt']) 257 | print(validation_batch['validation_negative_prompt']) 258 | output = pipeline( 259 | F_mats=F_mats, 260 | prompt=validation_batch['validation_prompt'], 261 | negative_prompt=validation_batch['validation_negative_prompt'], 262 | pose_embedding=plucker_embedding, 263 | video_length=args.video_length, 264 | height=args.image_height, 265 | width=args.image_width, 266 | num_inference_steps=args.num_inference_steps, 267 | guidance_scale=args.guidance_scale, 268 | generator=generator, 269 | ) # [2 3 f h w] 270 | sample = output.videos # [2 3 f h w] 271 | # save images 272 | cur_out_root = os.path.join(args.out_root, str(idx)) 273 | img_out_root = os.path.join(cur_out_root, 'imgs') 274 | 
os.makedirs(img_out_root, exist_ok=True) 275 | for frame_i in range(sample.shape[2]): 276 | imageio.imwrite(f"{img_out_root}/{frame_i}-{0}.png", 277 | (sample[0, :, frame_i, ] * 255.0).clamp(0, 255).permute(1, 2, 0).detach().numpy().astype(np.uint8)) 278 | imageio.imwrite(f"{img_out_root}/{frame_i}-{1}.png", 279 | (sample[1, :, frame_i, ] * 255.0).clamp(0, 255).permute(1, 2, 0).detach().numpy().astype(np.uint8)) 280 | # save individual videos 281 | vid_out_root = os.path.join(cur_out_root, 'vids') 282 | for video_i in range(sample.shape[0]): 283 | save_path = f"{vid_out_root}/{video_i}.mp4" 284 | save_videos_grid(sample[video_i].unsqueeze(0), save_path) 285 | # save combined videos 286 | save_path = f"{vid_out_root}/horizontal.mp4" 287 | save_videos_grid(rearrange(sample, "b c f h w -> c f h (b w)").unsqueeze(0), save_path) 288 | save_path = f"{vid_out_root}/vertical.mp4" 289 | save_videos_grid(rearrange(sample, "b c f h w -> c f (b h) w").unsqueeze(0), save_path) 290 | # save trajectories 291 | ret_c2w = validation_batch['ret_c2w'].squeeze() 292 | ret_c2w_list = ret_c2w.chunk(2) 293 | pose_out_root = os.path.join(cur_out_root, "poses") 294 | os.makedirs(pose_out_root, exist_ok=True) 295 | for video_idx, ret_c2w in enumerate(ret_c2w_list): 296 | ret_c2w = ret_c2w.detach().cpu().numpy() 297 | visualizer = CameraPoseVisualizer([-1, 1], [-1, 1], [-1, 1]) 298 | transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4) 299 | for frame_idx, c2w in enumerate(ret_c2w): 300 | visualizer.extrinsic2pyramid(c2w @ transform_matrix, frame_idx / ret_c2w.shape[0], hw_ratio=args.image_width / args.image_height, base_xval=0.035, zval=0.04) 301 | visualizer.colorbar(16) 302 | pose_img_dir = os.path.join(pose_out_root, f"pose_img_{video_idx}.png") 303 | visualizer.show(pose_img_dir) 304 | ret_c2w_dir = os.path.join(pose_out_root, f"ret_c2w_{video_idx}.png") 305 | np.save(ret_c2w_dir, ret_c2w) 306 | 307 | sample = rearrange(sample, "b c f h w -> c f h (b w)") 308 | sample_all.append(sample) 309 | 310 | if args.no_lora_validation: 311 | pipeline.unet.set_image_layer_lora_scale(1.0) 312 | 313 | vid_out_root = os.path.join(args.out_root, 'vids') 314 | for video_i in range(len(sample_all)): 315 | save_path = f"{vid_out_root}/{video_i}.mp4" 316 | save_videos_grid(sample_all[video_i].unsqueeze(0), save_path) 317 | 318 | sample_all = torch.stack(sample_all, dim=0) # n x 3 x f x 2h x w 319 | save_path = f"{args.out_root}/results.gif" 320 | save_videos_grid(sample_all, save_path) 321 | print(f"Saved samples to {save_path}") 322 | save_path = f"{args.out_root}/results.mp4" 323 | save_videos_grid(sample_all, save_path) 324 | print(f"Saved samples to {save_path}") 325 | 326 | if __name__ == '__main__': 327 | parser = argparse.ArgumentParser() 328 | parser.add_argument("--out_root", type=str) 329 | parser.add_argument("--local-rank", type=int) 330 | 331 | parser.add_argument("--image_height", type=int, default=256) 332 | parser.add_argument("--image_width", type=int, default=384) 333 | parser.add_argument("--video_length", type=int, default=16) 334 | 335 | # Model Configs 336 | parser.add_argument("--ori_model_path", type=str, help='path to the sd model folder') 337 | parser.add_argument("--unet_subfolder", type=str, help='subfolder name of unet ckpt') 338 | parser.add_argument("--image_lora_rank", type=int, default=2) 339 | parser.add_argument("--image_lora_ckpt", default=None) 340 | parser.add_argument("--civitai_lora_ckpt", default=None) 341 | 
parser.add_argument("--civitai_base_model", default=None) 342 | parser.add_argument("--pose_adaptor_ckpt", default=None, help='path to the camera control model ckpt') 343 | parser.add_argument("--motion_module_ckpt", type=str, help='path to the animatediff motion module ckpt') 344 | parser.add_argument("--epi_module_ckpt", type=str, help='path to the epi module ckpt') 345 | parser.add_argument("--model_config", type=str) 346 | 347 | # Inference Configs 348 | parser.add_argument("--num_inference_steps", type=int, default=25) 349 | parser.add_argument("--guidance_scale", type=float, default=15.0) 350 | parser.add_argument("--caption_file", required=True, help='prompts path, json or txt') 351 | parser.add_argument("--use_negative_prompt", action='store_true', help='whether to use negative prompts') 352 | parser.add_argument("--use_specific_seeds", action='store_true', help='whether to use specific seeds for each prompt') 353 | parser.add_argument("--zero_first_frame_scale", action='store_true') 354 | parser.add_argument("--global_seed", type=int, default=1024) 355 | 356 | parser.add_argument("--spatial_extended_attention", action='store_true') 357 | parser.add_argument("--pose_adaptor_scale", type=float, default=1.0) 358 | 359 | parser.add_argument("--pose_file_0", default=None) 360 | parser.add_argument("--pose_file_1", default=None) 361 | parser.add_argument("--num_videos", type=int, default=None) 362 | 363 | # validation dataset configs 364 | parser.add_argument("--no_lora_validation", action='store_true') 365 | 366 | # DDP args 367 | parser.add_argument("--world_size", default=1, type=int, 368 | help="number of the distributed processes.") 369 | parser.add_argument('--local_rank', type=int, default=-1, 370 | help='Replica rank on the current node. 
This field is required ' 371 | 'by `torch.distributed.launch`.') 372 | args = parser.parse_args() 373 | 374 | main(args) -------------------------------------------------------------------------------- /inference_epi_advanced.py: -------------------------------------------------------------------------------- 1 | # make sure you're logged in with `huggingface-cli login` 2 | import argparse 3 | import json 4 | import os 5 | 6 | import numpy as np 7 | import math 8 | import torch 9 | from tqdm import tqdm 10 | from packaging import version as pver 11 | from einops import rearrange 12 | from safetensors import safe_open 13 | 14 | from omegaconf import OmegaConf 15 | from diffusers import ( 16 | AutoencoderKL, 17 | DDIMScheduler 18 | ) 19 | from transformers import CLIPTextModel, CLIPTokenizer 20 | 21 | from animatediff.utils.util import save_videos_grid, save_video_as_images 22 | from animatediff.models.unet import UNet3DConditionModelPoseCond 23 | from animatediff.models.pose_adaptor import CameraPoseEncoder 24 | from animatediff.data.dataset_validation import Camera 25 | 26 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_vae_checkpoint, \ 27 | convert_ldm_clip_checkpoint 28 | from animatediff.pipelines.pipeline_animation_epi_advanced import AnimationPipelineEpiControl 29 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint 30 | 31 | 32 | def setup_for_distributed(is_master): 33 | """ 34 | This function disables printing when not in master process 35 | """ 36 | import builtins as __builtin__ 37 | builtin_print = __builtin__.print 38 | 39 | def print(*args, **kwargs): 40 | force = kwargs.pop('force', False) 41 | if is_master or force: 42 | builtin_print(*args, **kwargs) 43 | 44 | __builtin__.print = print 45 | 46 | 47 | def custom_meshgrid(*args): 48 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 49 | if pver.parse(torch.__version__) < pver.parse('1.10'): 50 | return torch.meshgrid(*args) 51 | else: 52 | return torch.meshgrid(*args, indexing='ij') 53 | 54 | 55 | def get_relative_pose(cam_params, zero_first_frame_scale): 56 | abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] 57 | abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] 58 | source_cam_c2w = abs_c2ws[0] 59 | if zero_first_frame_scale: 60 | cam_to_origin = 0 61 | else: 62 | cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3]) 63 | target_cam_c2w = np.array([ 64 | [1, 0, 0, 0], 65 | [0, 1, 0, -cam_to_origin], 66 | [0, 0, 1, 0], 67 | [0, 0, 0, 1] 68 | ]) 69 | abs2rel = target_cam_c2w @ abs_w2cs[0] 70 | ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] 71 | ret_poses = np.array(ret_poses, dtype=np.float32) 72 | return ret_poses 73 | 74 | 75 | def ray_condition(K, c2w, H, W, device): 76 | # c2w: B, V, 4, 4 77 | # K: B, V, 4 78 | 79 | B = K.shape[0] 80 | 81 | j, i = custom_meshgrid( 82 | torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), 83 | torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), 84 | ) 85 | i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] 86 | j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] 87 | 88 | fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 89 | 90 | zs = torch.ones_like(i) # [B, HxW] 91 | xs = (i - cx) / fx * zs 92 | ys = (j - cy) / fy * zs 93 | zs = zs.expand_as(ys) 94 | 95 | directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 96 | directions = directions / directions.norm(dim=-1, 
keepdim=True) # B, V, HW, 3 97 | 98 | rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3 99 | rays_o = c2w[..., :3, 3] # B, V, 3 100 | rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3 101 | # c2w @ directions 102 | rays_dxo = torch.cross(rays_o, rays_d) 103 | plucker = torch.cat([rays_dxo, rays_d], dim=-1) 104 | plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 105 | # plucker = plucker.permute(0, 1, 4, 2, 3) 106 | return plucker 107 | 108 | from scipy.spatial.transform import Rotation as R 109 | from scipy.spatial.transform import Slerp 110 | def interpolate_pose(src_pose, tgt_pose, split_num, perturb_traj_norm): 111 | ret_poses = np.repeat(src_pose[None], split_num, axis=0) 112 | 113 | perturb_t = perturb_traj_norm * np.random.randn(3) 114 | # interpolate translation 115 | for i in range(split_num): 116 | alpha = i / (split_num-1) 117 | ret_poses[i, :3, 3] = src_pose[:3, 3] * (1-alpha) + (tgt_pose[:3, 3]+perturb_t) * alpha # blend translate 118 | 119 | # interpolate rotation 120 | src_quat = R.from_matrix(src_pose[:3, :3]) 121 | tgt_quat = R.from_matrix(tgt_pose[:3, :3]) 122 | interp_time = np.linspace(0, 1, split_num) 123 | sl = Slerp([0, 1], R.concatenate([src_quat, tgt_quat])) 124 | interp_quat = sl(interp_time) 125 | interp_rot = interp_quat.as_matrix() 126 | ret_poses[:, :3, :3] = interp_rot 127 | 128 | return ret_poses 129 | 130 | def load_civitai_base_model(pipeline, civitai_base_model): 131 | print(f'Load civitai base model from {civitai_base_model}') 132 | if civitai_base_model.endswith(".safetensors"): 133 | dreambooth_state_dict = {} 134 | with safe_open(civitai_base_model, framework="pt", device="cpu") as f: 135 | for key in f.keys(): 136 | dreambooth_state_dict[key] = f.get_tensor(key) 137 | elif civitai_base_model.endswith(".ckpt"): 138 | dreambooth_state_dict = torch.load(civitai_base_model, map_location="cpu") 139 | 140 | # 1. vae 141 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) 142 | pipeline.vae.load_state_dict(converted_vae_checkpoint) 143 | # 2. unet 144 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) 145 | _, unetu = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 146 | assert len(unetu) == 0 147 | # 3.
text_model 148 | pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, text_encoder=pipeline.text_encoder) 149 | del dreambooth_state_dict 150 | return pipeline 151 | 152 | 153 | def get_pipeline(args, ori_model_path, unet_subfolder, image_lora_rank, image_lora_ckpt, unet_additional_kwargs, 154 | unet_mm_ckpt, unet_epi_ckpt, pose_encoder_kwargs, attention_processor_kwargs, 155 | noise_scheduler_kwargs, pose_adaptor_ckpt, civitai_lora_ckpt, civitai_base_model, gpu_id, 156 | spatial_extended_attention=False): 157 | 158 | vae = AutoencoderKL.from_pretrained(ori_model_path, subfolder="vae") 159 | tokenizer = CLIPTokenizer.from_pretrained(ori_model_path, subfolder="tokenizer") 160 | text_encoder = CLIPTextModel.from_pretrained(ori_model_path, subfolder="text_encoder") 161 | 162 | unet_additional_kwargs['epi_module_kwargs']['epi_mono_direction'] = args.mono_direction 163 | unet_additional_kwargs['epi_module_kwargs']['epi_fix_firstframe'] = args.fix_firstframe 164 | unet_additional_kwargs['epi_module_kwargs']['epi_position_encoding_F_mat_size'] = args.image_height 165 | unet = UNet3DConditionModelPoseCond.from_pretrained_2d(ori_model_path, subfolder=unet_subfolder, 166 | unet_additional_kwargs=unet_additional_kwargs) 167 | 168 | pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs) 169 | print(f"Setting the attention processors") 170 | unet.set_all_attn_processor(add_spatial_lora=image_lora_ckpt is not None, 171 | add_motion_lora=False, 172 | lora_kwargs={"lora_rank": image_lora_rank, "lora_scale": 1.0}, 173 | motion_lora_kwargs={"lora_rank": -1, "lora_scale": 1.0}, 174 | sync_lora_kwargs={"sync_lora_rank": 0, "sync_lora_scale": 0}, 175 | spatial_extended_attention=spatial_extended_attention, 176 | **attention_processor_kwargs) 177 | 178 | if image_lora_ckpt is not None: 179 | print(f"Loading the lora checkpoint from {image_lora_ckpt}") 180 | lora_checkpoints = torch.load(image_lora_ckpt, map_location=unet.device) 181 | if 'lora_state_dict' in lora_checkpoints.keys(): 182 | lora_checkpoints = lora_checkpoints['lora_state_dict'] 183 | _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False) 184 | assert len(lora_u) == 0 185 | print(f'Loading done') 186 | 187 | if unet_mm_ckpt is not None: 188 | print(f"Loading the motion module checkpoint from {unet_mm_ckpt}") 189 | mm_checkpoints = torch.load(unet_mm_ckpt, map_location=unet.device) 190 | _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False) 191 | assert len(mm_u) == 0 192 | print("Loading done") 193 | 194 | if unet_epi_ckpt is not None: 195 | print(f"Loading the epi module checkpoint from {unet_epi_ckpt}") 196 | ckpt = torch.load(unet_epi_ckpt, map_location=unet.device) 197 | unet_trainable_dict = ckpt['unet_trainable_dict'] 198 | _, epi_u = unet.load_state_dict(unet_trainable_dict, strict=False) 199 | assert len(epi_u) == 0 200 | print("Loading done") 201 | 202 | print(f"Loading pose adaptor") 203 | pose_adaptor_checkpoint = torch.load(pose_adaptor_ckpt, map_location='cpu') 204 | pose_encoder_state_dict = pose_adaptor_checkpoint['pose_encoder_state_dict'] 205 | pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict) 206 | assert len(pose_encoder_u) == 0 and len(pose_encoder_m) == 0 207 | attention_processor_state_dict = pose_adaptor_checkpoint['attention_processor_state_dict'] 208 | _, attn_proc_u = unet.load_state_dict(attention_processor_state_dict, strict=False) 209 | assert len(attn_proc_u) == 0 210 | print(f"Loading done") 211 | 212 | noise_scheduler = 
DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 213 | vae.to(gpu_id) 214 | text_encoder.to(gpu_id) 215 | unet.to(gpu_id) 216 | pose_encoder.to(gpu_id) 217 | 218 | pipe = AnimationPipelineEpiControl( 219 | vae=vae, 220 | text_encoder=text_encoder, 221 | tokenizer=tokenizer, 222 | unet=unet, 223 | scheduler=noise_scheduler, 224 | pose_encoder=pose_encoder) 225 | 226 | assert not (civitai_base_model and civitai_lora_ckpt) 227 | if civitai_lora_ckpt is not None: 228 | pipe.load_lora_weights(civitai_lora_ckpt) 229 | if civitai_base_model is not None: 230 | load_civitai_base_model(pipeline=pipe, civitai_base_model=civitai_base_model) 231 | pipe.enable_vae_slicing() 232 | pipe = pipe.to(gpu_id) 233 | 234 | return pipe 235 | 236 | 237 | def main(args): 238 | os.makedirs(args.out_root, exist_ok=True) 239 | rank = args.local_rank 240 | setup_for_distributed(rank == 0) 241 | gpu_id = rank % torch.cuda.device_count() 242 | model_configs = OmegaConf.load(args.model_config) 243 | unet_additional_kwargs = model_configs[ 244 | 'unet_additional_kwargs'] if 'unet_additional_kwargs' in model_configs else None 245 | noise_scheduler_kwargs = model_configs['noise_scheduler_kwargs'] 246 | pose_encoder_kwargs = model_configs['pose_encoder_kwargs'] 247 | attention_processor_kwargs = model_configs['attention_processor_kwargs'] 248 | 249 | # overwritten 250 | attention_processor_kwargs["scale"] = args.pose_adaptor_scale 251 | 252 | print(f'Constructing pipeline') 253 | pipeline = get_pipeline(args, args.ori_model_path, args.unet_subfolder, args.image_lora_rank, args.image_lora_ckpt, 254 | unet_additional_kwargs, args.motion_module_ckpt, args.epi_module_ckpt, 255 | pose_encoder_kwargs, attention_processor_kwargs, 256 | noise_scheduler_kwargs, args.pose_adaptor_ckpt, args.civitai_lora_ckpt, 257 | args.civitai_base_model, f"cuda:{gpu_id}", 258 | spatial_extended_attention=args.spatial_extended_attention) 259 | device = torch.device(f"cuda:{gpu_id}") 260 | print('Done') 261 | 262 | print(f'Loading Validation Dataset') 263 | 264 | # with open(args.validation_prompts_file, "r") as f: 265 | # validation_prompts = [x.replace("\n", "") for x in f.readlines()] 266 | if args.caption_file.endswith('.json'): 267 | json_file = json.load(open(args.caption_file, 'r')) 268 | captions = json_file['captions'] if 'captions' in json_file else json_file['prompts'] 269 | if args.use_negative_prompt: 270 | negative_prompts = json_file['negative_prompts'] 271 | else: 272 | negative_prompts = None 273 | if isinstance(captions[0], dict): 274 | captions = [cap['caption'] for cap in captions] 275 | if args.use_specific_seeds and "seeds" in json_file.keys(): 276 | specific_seeds = json_file['seeds'] 277 | else: 278 | specific_seeds = None 279 | elif args.caption_file.endswith('.txt'): 280 | with open(args.caption_file, 'r') as f: 281 | captions = f.readlines() 282 | captions = [cap.strip() for cap in captions] 283 | negative_prompts = None 284 | specific_seeds = None 285 | else: 286 | raise ValueError("Invalid prompt file") 287 | 288 | print(f'Done') 289 | 290 | 291 | generator = torch.Generator(device=device) 292 | generator.manual_seed(42) 293 | 294 | c2ws_list = [] 295 | 296 | # Define camera trajectory (extrinsic and intrinsic) 297 | K_mats = np.array([[223.578, 0, 128], [0, 223.578, 128], [0, 0, 1]], dtype=np.float64) 298 | K_mats = np.repeat(K_mats[np.newaxis, ...], args.view_num*args.video_length, axis=0) 299 | K_mats[:, 0] *= args.image_width / 256 300 | K_mats[:, 1] *= args.image_height / 256 301 | 302 | if 
args.cam_pattern == "interpolate": 303 | for i in range(args.view_num): 304 | src_pose = np.eye(4) 305 | tgt_pose = src_pose.copy() 306 | 307 | angle = math.pi / (args.view_num-1) * i 308 | 309 | cam_at = np.array([math.cos(angle), math.cos(angle+0.5) * 0.3, - math.sin(angle) * 0.2]) * args.camera_dist 310 | look_at = np.array([0, 0, 1]) 311 | 312 | cam_z = look_at - cam_at 313 | cam_x = np.array([1, 0, 0]) 314 | cam_y = np.cross(cam_z, cam_x) 315 | cam_y = cam_y / (np.linalg.norm(cam_y)+1e-6) 316 | cam_x = np.cross(cam_y, cam_z) 317 | cam_x = cam_x / (np.linalg.norm(cam_x)+1e-6) 318 | tgt_pose[:3, :3] = np.stack([cam_x, cam_y, cam_z], axis=1) 319 | tgt_pose[:3, 3] = cam_at 320 | 321 | c2ws_list.append(interpolate_pose(src_pose, tgt_pose, args.video_length, args.cam_perturb_traj)) 322 | else: 323 | for i in range(args.view_num): 324 | src_pose = np.eye(4) 325 | tgt_pose = src_pose.copy() 326 | 327 | if args.cam_pattern == "upper_hemi": 328 | angle = math.pi / (args.view_num-1) * i + math.pi 329 | elif args.cam_pattern == "circle": 330 | angle = 2*math.pi / args.view_num * i 331 | 332 | cam_at = np.array([math.cos(angle), math.sin(angle), 0]) * args.camera_dist 333 | look_at = np.array([0, 0, 1]) 334 | 335 | cam_z = look_at - cam_at 336 | cam_x = np.array([1, 0, 0]) 337 | cam_y = np.cross(cam_z, cam_x) 338 | cam_y = cam_y / (np.linalg.norm(cam_y)+1e-6) 339 | cam_x = np.cross(cam_y, cam_z) 340 | cam_x = cam_x / (np.linalg.norm(cam_x)+1e-6) 341 | tgt_pose[:3, :3] = np.stack([cam_x, cam_y, cam_z], axis=1) 342 | tgt_pose[:3, 3] = cam_at 343 | c2ws_list.append(interpolate_pose(src_pose, tgt_pose, args.video_length, args.cam_perturb_traj)) 344 | 345 | c2ws_list = np.concatenate(c2ws_list, axis=0) # bf, 4, 4 346 | intrinsic = np.asarray([[K[0,0], K[1,1], K[0,2], K[1,2]] for K in K_mats], dtype=np.float32) 347 | 348 | K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] 349 | c2ws = torch.as_tensor(c2ws_list).float()[None] # [1, n_frame, 4, 4] 350 | K_mats = torch.as_tensor(K_mats).float()[None] 351 | plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu')[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W 352 | plucker_embedding = plucker_embedding.to(device) 353 | plucker_embedding = rearrange(plucker_embedding, "(b f) c h w -> b c f h w", f=args.video_length)# B V 6 H W 354 | plucker_embedding = plucker_embedding.to(device=pipeline.unet.device) 355 | 356 | for seed_id in range(args.multiseed): 357 | sample_all = [] 358 | for idx, prompt in enumerate(captions): 359 | sub_out_dir = os.path.join(args.out_root, "%d_%04d"%(seed_id, idx)) 360 | os.makedirs(sub_out_dir, exist_ok=True) 361 | 362 | save_transforms_json_file = open(os.path.join(sub_out_dir, "transforms.json"), "w") 363 | save_transforms_json = { 364 | "fl_x": float(intrinsic[0,0]), 365 | "fl_y": float(intrinsic[0,1]), 366 | "cx": float(intrinsic[0,2]), 367 | "cy": float(intrinsic[0,3]), 368 | "w": args.image_width, 369 | "h": args.image_height, 370 | "camera_model": "PINHOLE", 371 | "frames": [] 372 | } 373 | if specific_seeds is not None: 374 | specific_seed = specific_seeds[idx] 375 | generator.manual_seed(specific_seed) 376 | sample = pipeline( 377 | F_mats=None, 378 | prompt=prompt, 379 | pose_embedding=plucker_embedding, 380 | video_length=args.video_length, 381 | height=args.image_height, 382 | width=args.image_width, 383 | num_inference_steps=args.num_inference_steps, 384 | guidance_scale=args.guidance_scale, 385 | generator=generator, 386 | aux_c2w=c2ws, 387 | aux_K_mats=K_mats, 388 | 
multistep=args.multistep, 389 | accumulate_step=args.accumulate_step, 390 | ).videos # [b 3 f h w] 391 | 392 | sample_reshape = rearrange(sample, "b c f h w -> c f (b h) w") 393 | # sample_reshape = rearrange(sample, "b c f h w -> c f (b h) w") 394 | save_path = f"{sub_out_dir}/video.gif" 395 | save_videos_grid(sample_reshape[None], save_path, mp4_also=True) 396 | 397 | for video_idx, sample_video in enumerate(sample): # [3, f, h, w] 398 | image_save_path = f"{sub_out_dir}/images/{video_idx}" 399 | image_paths = save_video_as_images(sample_video, image_save_path) 400 | for img_idx, img_path in enumerate(image_paths): 401 | ref_img_path = img_path.replace(f"{sub_out_dir}/", "") 402 | c2w = c2ws[0,img_idx+video_idx*len(image_paths)].detach().cpu().numpy().copy() 403 | c2w[:3, 1] *= -1 404 | c2w[:3, 2] *= -1 # opencv 2 opengl 405 | # w2c = np.linalg.inv(c2w) 406 | c2w = [[float(c2w[i, j]) for j in range(4)] for i in range(4)] 407 | save_transforms_json['frames'].append({ 408 | "file_path": ref_img_path, 409 | "transform_matrix": c2w}) 410 | json.dump(save_transforms_json, save_transforms_json_file, indent=4) 411 | 412 | sample_all.append(sample_reshape) 413 | 414 | sample_all = torch.stack(sample_all, dim=0) # n x 3 x f x 2h x w 415 | save_path_all = f"{args.out_root}/results_all_{seed_id}.gif" 416 | save_videos_grid(sample_all, save_path_all, n_rows=8, mp4_also=True) 417 | print(f"Saved samples to {save_path_all}") 418 | 419 | if __name__ == '__main__': 420 | parser = argparse.ArgumentParser() 421 | parser.add_argument("--out_root", type=str) 422 | parser.add_argument("--local-rank", type=int) 423 | parser.add_argument("--image_height", type=int, default=256) 424 | parser.add_argument("--image_width", type=int, default=384) 425 | parser.add_argument("--video_length", type=int, default=16) 426 | 427 | # Model Configs 428 | parser.add_argument("--ori_model_path", type=str, help='path to the sd model folder') 429 | parser.add_argument("--unet_subfolder", type=str, help='subfolder name of unet ckpt') 430 | parser.add_argument("--image_lora_rank", type=int, default=2) 431 | parser.add_argument("--image_lora_ckpt", default=None) 432 | parser.add_argument("--civitai_lora_ckpt", default=None) 433 | parser.add_argument("--civitai_base_model", default=None) 434 | parser.add_argument("--pose_adaptor_ckpt", default=None, help='path to the camera control model ckpt') 435 | parser.add_argument("--motion_module_ckpt", type=str, help='path to the animatediff motion module ckpt') 436 | parser.add_argument("--epi_module_ckpt", type=str, help='path to the epi module ckpt') 437 | 438 | parser.add_argument("--model_config", type=str) 439 | 440 | # Inference Configs 441 | parser.add_argument("--num_inference_steps", type=int, default=25) 442 | parser.add_argument("--guidance_scale", type=float, default=14.0) 443 | parser.add_argument("--caption_file", required=True, help='prompts path, json or txt') 444 | parser.add_argument("--use_negative_prompt", action='store_true', help='whether to use negative prompts') 445 | parser.add_argument("--use_specific_seeds", action='store_true', help='whether to use specific seeds for each prompt') 446 | parser.add_argument("--zero_first_frame_scale", action='store_true') 447 | parser.add_argument("--multiseed", type=int, default=1) 448 | 449 | parser.add_argument("--cam_pattern", type=str, choices=["upper_hemi", "circle", "interpolate"]) 450 | parser.add_argument("--cam_perturb_traj", type=float, default=0) 451 | parser.add_argument("--camera_dist", type=float, default=0.5) 452 |
453 | parser.add_argument("--view_num", type=int, default=2) 454 | parser.add_argument("--multistep", type=int, default=1) 455 | parser.add_argument("--accumulate_step", type=int, default=1) 456 | parser.add_argument("--fix_firstframe", action='store_true') 457 | parser.add_argument("--mono_direction", action='store_true') 458 | parser.add_argument("--spatial_extended_attention", action='store_true') 459 | 460 | parser.add_argument("--pose_adaptor_scale", type=float, default=1.0) 461 | 462 | # DDP args 463 | parser.add_argument("--world_size", default=1, type=int, 464 | help="number of the distributed processes.") 465 | parser.add_argument('--local_rank', type=int, default=-1, 466 | help='Replica rank on the current node. This field is required ' 467 | 'by `torch.distributed.launch`.') 468 | args = parser.parse_args() 469 | 470 | main(args) 471 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | diffusers==0.24.0 3 | xformers==0.0.24 4 | imageio==2.27.0 5 | imageio[ffmpeg] 6 | opencv-python 7 | transformers 8 | gdown 9 | einops 10 | decord 11 | omegaconf 12 | safetensors 13 | gradio 14 | wandb 15 | triton 16 | termcolor 17 | accelerate 18 | huggingface_hub==0.25 19 | scipy -------------------------------------------------------------------------------- /run_inference_advanced.sh: -------------------------------------------------------------------------------- 1 | GPU=$1 2 | CAMERA_TYPE=$2 3 | VIEW_NUM=$3 4 | MASTER_PORT=$(expr 27000 + $GPU) 5 | 6 | # Parameters: 7 | # ori_model_path: path to the Stable Diffusion folder (fused with webvid lora) 8 | # pose_adaptor_ckpt: path to the CameraCtrl's pose module checkpoint 9 | # motion_module_ckpt: path to the AnimateDiff's motion module checkpoint 10 | # epi_module_ckpt: path to the trained CVD's module checkpoint 11 | # civitai_base_model (optional): Stable Diffusion's LoRA checkpoint. The webvid LoRA from AnimateDiff will be used if not specified. 
12 | # caption_file: Text prompt file 13 | # view_num: Number of generated multi-view videos 14 | # multistep: Number of recurrent steps for each denoising step 15 | # multiseed: Number of samples for each text prompt 16 | # accumulate_step: Number of pairs assigned to each video (default: 1) 17 | # cam_pattern: pattern of camera trajectories (supported inputs: circle, interpolate, upper_hemi) 18 | 19 | if [ $CAMERA_TYPE == "circle" ]; then 20 | CAMERA_CONFIG='--caption_file assets/cameractrl_prompts_for_circle.json --cam_pattern circle' 21 | elif [ $CAMERA_TYPE == "interpolate" ]; then 22 | CAMERA_CONFIG='--caption_file assets/cameractrl_prompts_for_interpolate.json --cam_pattern interpolate' 23 | else 24 | echo "Invalid camera trajectory" 25 | exit 1 26 | fi 27 | 28 | if [ $VIEW_NUM == "4" ]; then 29 | VIDEO_CONFIG='--view_num 4 --multistep 3' 30 | elif [ $VIEW_NUM == "6" ]; then 31 | VIDEO_CONFIG='--view_num 6 --multistep 6 --accumulate_step 2' 32 | else 33 | echo "Invalid video number" 34 | exit 1 35 | fi 36 | 37 | CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=1 --master_port=$MASTER_PORT inference_epi_advanced.py \ 38 | --out_root results/${CAMERA_TYPE}_$GPU \ 39 | --ori_model_path ./models/StableDiffusion --unet_subfolder unet_webvidlora_v3 \ 40 | --pose_adaptor_ckpt ./models/CameraCtrl.ckpt \ 41 | --motion_module_ckpt ./models/animatediff_mm.ckpt \ 42 | --epi_module_ckpt ./models/CVD.ckpt \ 43 | --model_config ./configs/inference_config.yaml \ 44 | --use_specific_seeds --zero_first_frame_scale \ 45 | --image_height 256 \ 46 | --image_width 256 \ 47 | --num_inference_steps 25 \ 48 | --multiseed 3 \ 49 | $CAMERA_CONFIG $VIDEO_CONFIG \ 50 | --civitai_base_model ./models/realisticVisionV60B1_v51VAE.safetensors \ -------------------------------------------------------------------------------- /run_inference_simple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | GPU=$1 4 | SEED=2024 5 | RANDOM_PORT=$((25100 + GPU)) 6 | 7 | # Parameters: 8 | # ori_model_path: path to the Stable Diffusion folder (fused with webvid lora) 9 | # pose_adaptor_ckpt: path to the CameraCtrl's pose module checkpoint 10 | # motion_module_ckpt: path to the AnimateDiff's motion module checkpoint 11 | # epi_module_ckpt: path to the trained CVD's module checkpoint 12 | # civitai_base_model (optional): Stable Diffusion's LoRA checkpoint. The webvid LoRA from AnimateDiff will be used if not specified. 
13 | # caption_file: Text prompt file 14 | # pose_file_0: path to the first pose file 15 | # pose_file_1: path to the second pose file 16 | 17 | CUDA_VISIBLE_DEVICES=${GPU} python -m torch.distributed.launch --nproc_per_node=1 --master_port=${RANDOM_PORT} inference_epi.py \ 18 | --out_root ./results/pair_${GPU}/ \ 19 | --ori_model_path ./models/StableDiffusion --unet_subfolder unet_webvidlora_v3 \ 20 | --pose_adaptor_ckpt ./models/CameraCtrl.ckpt \ 21 | --motion_module_ckpt ./models/animatediff_mm.ckpt \ 22 | --epi_module_ckpt ./models/CVD.ckpt \ 23 | --model_config ./configs/inference_config.yaml \ 24 | --civitai_base_model ./models/realisticVisionV60B1_v51VAE.safetensors \ 25 | --caption_file ./assets/cameractrl_prompts.json \ 26 | --zero_first_frame_scale \ 27 | --image_height 256 \ 28 | --image_width 256 \ 29 | --no_lora_validation \ 30 | --guidance_scale 8.5 \ 31 | --pose_adaptor_scale 1.0 \ 32 | --global_seed ${SEED} \ 33 | --use_negative_prompt \ 34 | --num_videos 8 \ 35 | --pose_file_0 ./assets/pose_files/2f25826f0d0ef09a.txt \ 36 | --pose_file_1 ./assets/pose_files/2c80f9eb0d3b2bb4.txt \ 37 | 38 | # Other poses options: 39 | # --pose_file_0 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0bf152ef84195293.txt \ 40 | # --pose_file_1 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0c9b371cc6225682.txt \ 41 | 42 | # --pose_file_0 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0bf152ef84195293.txt \ 43 | # --pose_file_1 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0bf152ef84195293.txt \ 44 | 45 | # --pose_file_0 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0bf152ef84195293.txt \ 46 | # --pose_file_1 /home/shengqu/repos/Epi_CameraCtrl/assets/pose_files/0c9b371cc6225682.txt \ 47 | 48 | -------------------------------------------------------------------------------- /tools/merge_lora2unet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import shutil 5 | from diffusers.models import UNet2DConditionModel 6 | from safetensors.torch import save_file 7 | from diffusers.utils import SAFETENSORS_WEIGHTS_NAME 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--lora_scale', type=float, default=1.0) 13 | parser.add_argument('--lora_ckpt_path', type=str, required=True) 14 | parser.add_argument('--unet_ckpt_path', type=str, required=True, help='root path of the sd1.5 model') 15 | parser.add_argument('--save_path', type=str, required=True, help='args.unet_ckpt_path + a new subfolder name') 16 | parser.add_argument('--unet_config_path', type=str, required=True, help='path to unet config, in the `unet` subfolder of args.unet_ckpt_path') 17 | parser.add_argument('--lora_keys', nargs='*', type=str, default=['to_q', 'to_k', 'to_v', 'to_out']) 18 | parser.add_argument('--negative_lora_keys', type=str, default="bias") 19 | 20 | return parser.parse_args() 21 | 22 | 23 | if __name__ == '__main__': 24 | args = get_args() 25 | os.makedirs(args.save_path, exist_ok=True) 26 | unet = UNet2DConditionModel.from_pretrained(args.unet_ckpt_path, subfolder='unet') 27 | fused_state_dict = unet.state_dict() 28 | 29 | print(f'Loading the lora weights from {args.lora_ckpt_path}') 30 | lora_state_dict = torch.load(args.lora_ckpt_path, map_location='cpu') 31 | if 'state_dict' in lora_state_dict: 32 | lora_state_dict = lora_state_dict['state_dict'] 33 | print(f'Loading done') 34 | print(f'Fusing the lora weight to unet weight') 35 | used_lora_key = [] 36 | for 
lora_key in args.lora_keys: 37 | unet_keys = [x for x in fused_state_dict.keys() if lora_key in x and args.negative_lora_keys not in x] 38 | print(f'There are {len(unet_keys)} unet keys for lora key: {lora_key}') 39 | for unet_key in unet_keys: 40 | prefixes = unet_key.split('.') 41 | idx = prefixes.index(lora_key) 42 | lora_down_key = ".".join(prefixes[:idx]) + f".processor.{lora_key}_lora.down" + f".{prefixes[-1]}" 43 | lora_up_key = ".".join(prefixes[:idx]) + f".processor.{lora_key}_lora.up" + f".{prefixes[-1]}" 44 | assert lora_down_key in lora_state_dict and lora_up_key in lora_state_dict 45 | print(f'Fusing lora weight for {unet_key}') 46 | fused_state_dict[unet_key] = fused_state_dict[unet_key] + torch.bmm(lora_state_dict[lora_up_key][None, ...], lora_state_dict[lora_down_key][None, ...])[0] * args.lora_scale 47 | used_lora_key.append(lora_down_key) 48 | used_lora_key.append(lora_up_key) 49 | assert len(set(used_lora_key) - set(lora_state_dict.keys())) == 0 50 | print(f'Fusing done') 51 | save_path = os.path.join(args.save_path, SAFETENSORS_WEIGHTS_NAME) 52 | print(f'Saving the fused state dict to {save_path}') 53 | save_file(fused_state_dict, save_path) 54 | config_dst_path = os.path.join(args.save_path, 'config.json') 55 | print(f'Copying the unet config to {config_dst_path}') 56 | shutil.copy(args.unet_config_path, config_dst_path) 57 | print('Done!') 58 | -------------------------------------------------------------------------------- /tools/visualize_trajectory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import numpy as np 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | from matplotlib.patches import Patch 7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 8 | 9 | 10 | class CameraPoseVisualizer: 11 | def __init__(self, xlim, ylim, zlim): 12 | self.fig = plt.figure(figsize=(18, 7)) 13 | self.ax = self.fig.add_subplot(projection='3d') 14 | self.plotly_data = None # plotly data traces 15 | self.ax.set_aspect("auto") 16 | self.ax.set_xlim(xlim) 17 | self.ax.set_ylim(ylim) 18 | self.ax.set_zlim(zlim) 19 | self.ax.set_xlabel('x') 20 | self.ax.set_ylabel('y') 21 | self.ax.set_zlabel('z') 22 | print('initialize camera pose visualizer') 23 | 24 | def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3): 25 | vertex_std = np.array([[0, 0, 0, 1], 26 | [base_xval, -base_xval * hw_ratio, zval, 1], 27 | [base_xval, base_xval * hw_ratio, zval, 1], 28 | [-base_xval, base_xval * hw_ratio, zval, 1], 29 | [-base_xval, -base_xval * hw_ratio, zval, 1]]) 30 | vertex_transformed = vertex_std @ extrinsic.T 31 | meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]], 32 | [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], 33 | [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], 34 | [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], 35 | [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] 36 | 37 | color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map) 38 | 39 | self.ax.add_collection3d( 40 | Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35)) 41 | 42 | def customize_legend(self, list_label): 43 | list_handle = [] 44 | for idx, label in enumerate(list_label): 45 | color = plt.cm.rainbow(idx / 
len(list_label)) 46 | patch = Patch(color=color, label=label) 47 | list_handle.append(patch) 48 | plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle) 49 | 50 | def colorbar(self, max_frame_length): 51 | cmap = mpl.cm.rainbow 52 | norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length) 53 | self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical', label='Frame Number') 54 | 55 | def show(self, save_path="pose.png"): 56 | plt.title('Extrinsic Parameters') 57 | # plt.show() 58 | plt.savefig(save_path)#, transparent=True) 59 | plt.close() 60 | 61 | def get_args(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--pose_file_path', required=True) 64 | parser.add_argument('--hw_ratio', default=9/16, type=float) 65 | parser.add_argument('--sample_stride', type=int, default=4) 66 | parser.add_argument('--num_frames', type=int, default=16) 67 | parser.add_argument('--all_frames', action='store_true') 68 | parser.add_argument('--base_xval', type=float, default=0.1) 69 | parser.add_argument('--zval', type=float, default=0.2) 70 | parser.add_argument('--use_exact_fx', action='store_true') 71 | parser.add_argument('--relative_c2w', action='store_true') 72 | parser.add_argument('--x_min', type=float, default=-2) 73 | parser.add_argument('--x_max', type=float, default=2) 74 | parser.add_argument('--y_min', type=float, default=-2) 75 | parser.add_argument('--y_max', type=float, default=2) 76 | parser.add_argument('--z_min', type=float, default=-2) 77 | parser.add_argument('--z_max', type=float, default=2) 78 | return parser.parse_args() 79 | 80 | 81 | def get_c2w(w2cs, transform_matrix, relative_c2w): 82 | if relative_c2w: 83 | target_cam_c2w = np.array([ 84 | [1, 0, 0, 0], 85 | [0, 1, 0, 0], 86 | [0, 0, 1, 0], 87 | [0, 0, 0, 1] 88 | ]) 89 | abs2rel = target_cam_c2w @ w2cs[0] 90 | ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]] 91 | print(ret_poses) 92 | else: 93 | ret_poses = [np.linalg.inv(w2c) for w2c in w2cs] 94 | ret_poses = [transform_matrix @ x for x in ret_poses] 95 | return np.array(ret_poses, dtype=np.float32) 96 | 97 | 98 | if __name__ == '__main__': 99 | args = get_args() 100 | with open(args.pose_file_path, 'r') as f: 101 | poses = f.readlines() 102 | w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]] 103 | fxs = [float(pose.strip().split(' ')[1]) for pose in poses[1:]] 104 | if args.all_frames: 105 | args.num_frames = len(fxs) 106 | args.sample_stride = 1 107 | cropped_length = args.num_frames * args.sample_stride 108 | total_frames = len(w2cs) 109 | start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1)) 110 | end_frame_ind = min(start_frame_ind + cropped_length, total_frames) 111 | frame_ind = np.linspace(start_frame_ind, end_frame_ind - 1, args.num_frames, dtype=int) 112 | w2cs = [w2cs[x] for x in frame_ind] 113 | transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4) 114 | last_row = np.zeros((1, 4)) 115 | last_row[0, -1] = 1.0 116 | w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs] 117 | c2ws = get_c2w(w2cs, transform_matrix, args.relative_c2w) 118 | 119 | 120 | visualizer = CameraPoseVisualizer([args.x_min, args.x_max], [args.y_min, args.y_max], [args.z_min, args.z_max]) 121 | for frame_idx, c2w in enumerate(c2ws): 122 | visualizer.extrinsic2pyramid(c2w, frame_idx / args.num_frames, hw_ratio=args.hw_ratio, base_xval=args.base_xval, 
123 | zval=(fxs[frame_idx] if args.use_exact_fx else args.zval)) 124 | 125 | visualizer.colorbar(args.num_frames) 126 | pose_file_name = args.pose_file_path.split('/')[-1].split('.')[0] 127 | visualizer.show(f"{pose_file_name}.png") --------------------------------------------------------------------------------
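The trajectory visualizer can be pointed directly at one of the bundled RealEstate10K pose files to sanity-check a camera path before launching inference: as parsed in its __main__ block, field 1 of each pose line is fx and fields 7 onward are the flattened 3x4 w2c matrix. A minimal invocation sketch, assuming the repository root as the working directory and reusing the pyramid scale that inference_epi.py passes to the same visualizer (base_xval=0.035, zval=0.04):

python tools/visualize_trajectory.py \
    --pose_file_path assets/pose_files/2f25826f0d0ef09a.txt \
    --relative_c2w \
    --num_frames 16 \
    --sample_stride 4 \
    --base_xval 0.035 \
    --zval 0.04

With --relative_c2w, get_c2w maps the first camera to the identity pose and expresses the remaining frames relative to it, similar in spirit to get_relative_pose defined in inference_epi_advanced.py.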