├── .DS_Store ├── LICENSE ├── README.md ├── assets └── teaser.png ├── cam_utils.py ├── configs ├── .DS_Store ├── eval.yaml └── inference │ └── inference_v2.yaml ├── depth_anything_v2 ├── depth_anything_v2 │ ├── __pycache__ │ │ ├── dinov2.cpython-310.pyc │ │ └── dpt.cpython-310.pyc │ ├── dinov2.py │ ├── dinov2_layers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── attention.cpython-310.pyc │ │ │ ├── block.cpython-310.pyc │ │ │ ├── drop_path.cpython-310.pyc │ │ │ ├── layer_scale.cpython-310.pyc │ │ │ ├── mlp.cpython-310.pyc │ │ │ ├── patch_embed.cpython-310.pyc │ │ │ └── swiglu_ffn.cpython-310.pyc │ │ ├── attention.py │ │ ├── block.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── dpt.py │ └── util │ │ ├── __pycache__ │ │ ├── blocks.cpython-310.pyc │ │ └── transform.cpython-310.pyc │ │ ├── blocks.py │ │ └── transform.py └── metric_depth_estimation.py ├── examples ├── DollyOut.txt ├── Still.txt ├── TiltDown.txt ├── backview.jpeg ├── backview.json ├── balloon.json ├── balloon.png ├── balloons.json ├── balloons.png ├── bamboo.json └── bamboo.png ├── inference.py ├── mesh ├── world_envelope.mtl ├── world_envelope.obj └── world_envelope.png ├── requirements.txt └── src ├── .DS_Store ├── __init__.py ├── __pycache__ └── __init__.cpython-310.pyc ├── models ├── __pycache__ │ ├── attention.cpython-310.pyc │ ├── fusion_module.cpython-310.pyc │ ├── gated_self_attention.cpython-310.pyc │ ├── motion_encoder.cpython-310.pyc │ ├── motion_module.cpython-310.pyc │ ├── mutual_self_attention.cpython-310.pyc │ ├── pose_guider.cpython-310.pyc │ ├── resnet.cpython-310.pyc │ ├── transformer_2d.cpython-310.pyc │ ├── transformer_3d.cpython-310.pyc │ ├── unet_2d_blocks.cpython-310.pyc │ ├── unet_2d_condition.cpython-310.pyc │ ├── unet_3d.cpython-310.pyc │ └── unet_3d_blocks.cpython-310.pyc ├── attention.py ├── fusion_module.py ├── motion_encoder.py ├── motion_module.py ├── mutual_self_attention.py ├── resnet.py ├── transformer_2d.py ├── transformer_3d.py ├── unet_2d_blocks.py ├── unet_2d_condition.py ├── unet_3d.py └── unet_3d_blocks.py ├── pipelines ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── pipeline_motion2vid_merge_infer.cpython-310.pyc │ └── utils.cpython-310.pyc ├── pipeline_motion2vid_merge_infer.py └── utils.py └── utils ├── __pycache__ ├── util.cpython-310.pyc ├── utils.cpython-310.pyc └── visualizer.cpython-310.pyc ├── util.py ├── utils.py └── visualizer.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Perception-as-Control 2 | Official implementation of "Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation" 3 | 4 |

5 | [teaser image: assets/teaser.png] 6 |

7 | 8 | ### [Project page](https://chen-yingjie.github.io/projects/Perception-as-Control/index.html) | [Paper](https://arxiv.org/abs/2501.05020) | [Video](https://comming_soon) | [Online Demo](https://comming_soon) 9 | 10 | **Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation**
11 | [Yingjie Chen](https://chen-yingjie.github.io/), 12 | [Yifang Men](https://menyifang.github.io/), 13 | [Yuan Yao](mailto:yaoy92@gmail.com), 14 | [Miaomiao Cui](mailto:miaomiao.cmm@alibaba-inc.com), 15 | [Liefeng Bo](https://scholar.google.com/citations?user=FJwtMf0AAAAJ&hl=en)
16 | 17 | ## 💡 Abstract 18 | Motion-controllable image animation is a fundamental task with a wide range of potential applications. Recent works have made progress in controlling camera or object motion via the same 2D motion representations or different control signals, but they still struggle to support collaborative camera and object motion control with adaptive control granularity. To this end, we introduce a 3D-aware motion representation and propose an image animation framework, called Perception-as-Control, to achieve fine-grained collaborative motion control. Specifically, we construct the 3D-aware motion representation from a reference image, manipulate it based on interpreted user intentions, and perceive it from different viewpoints. In this way, camera and object motions are transformed into intuitive, consistent visual changes. Then, the proposed framework leverages the perception results as motion control signals, enabling it to support various motion-related video synthesis tasks in a unified and flexible way. Experiments demonstrate the superiority of the proposed method. 19 | 20 | ## 🔥 Updates 21 | - (2025-03-31) We release the inference code and model weights of Perception-as-Control. 22 | - (2025-03-10) We release an updated version of the paper with more details. 23 | - (2025-01-09) The project page, demo video and technical report are released. The full paper version with more details is in progress. 24 | 25 | ## Usage 26 | ### Environment 27 | ```shell 28 | $ pip install -r requirements.txt 29 | ``` 30 | ### Pretrained Weights 31 | 1. Download [pretrained weights](https://drive.google.com/drive/folders/1ZncmHG9K_n1BjGhVzQemomWzxQ60bYqg?usp=drive_link) and put them in `$INSTALL_DIR/pretrained_weights`. 32 | 33 | 2. Download the pretrained weights of the base models and put them in `$INSTALL_DIR/pretrained_weights`: 34 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 35 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder) 36 | 37 | The pretrained weights are organized as follows: 38 | ```text 39 | ./pretrained_weights/ 40 | |-- denoising_unet.pth 41 | |-- reference_unet.pth 42 | |-- cam_encoder.pth 43 | |-- obj_encoder.pth 44 | |-- sd-vae-ft-mse 45 | | |-- ... 46 | |-- sd-image-variations-diffusers 47 | | |-- ... 48 | ``` 49 | 50 | ### Inference 51 | ```shell 52 | $ python inference.py 53 | ``` 54 | The results will be saved in `$INSTALL_DIR/outputs`. 55 | 56 | ## 🎥 Demo 57 | 58 | ### Fine-grained collaborative motion control 59 |

[demo videos: Camera Motion Control, Object Motion Control, Collaborative Motion Control]
91 | ### Potential applications 92 |
[demo videos: Motion Generation, Motion Clone, Motion Transfer, Motion Editing]
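Camera-control signals like these can also be prepared standalone with `cam_utils.py`, which renders the world-envelope mesh and projects depth-lifted query points along a camera trajectory. Below is a minimal sketch rather than an official command: the flag names follow the argparse block at the bottom of `cam_utils.py`, the image and camera-pattern paths are example files from `examples/`, and any `--cam_type` value other than `generated` selects the user-provided trajectory file (the `generated` branch currently raises `NotImplementedError`). `ffmpeg` must be available on `PATH`.

```shell
# Illustrative invocation; with the script as written, outputs land in ./tmp/ as
# rendered_cam.mp4 (rendered camera preview) and rendered_tracks.npy (projected 3D tracks).
$ python cam_utils.py \
    --gpu 0 \
    --imagepath ./examples/backview.jpeg \
    --vo_path ./examples/DollyOut.txt \
    --query_path none \
    --cam_type provided \
    --video_length 16 \
    --output_dir ./tmp
```

For the full pipeline, `configs/eval.yaml` lists `test_cases` that pair each reference image with a motion `.json`, a camera-pattern `.txt`, and a seed.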
118 | 119 | For more details, please refer to our [project page](https://chen-yingjie.github.io/projects/Perception-as-Control/index.html). 120 | 121 | 122 | ## 📑 TODO List 123 | - [x] Release inference code and model weights 124 | - [ ] Provide a Gradio demo 125 | - [ ] Release training code 126 | 127 | ## 🔗 Citation 128 | 129 | If you find this code useful for your research, please use the following BibTeX entry. 130 | 131 | ```bibtex 132 | @inproceedings{chen2025perception, 133 | title={Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation}, 134 | author={Chen, Yingjie and Men, Yifang and Yao, Yuan and Cui, Miaomiao and Bo, Liefeng}, 135 | journal={arXiv preprint arXiv:2501.05020}, 136 | website={https://chen-yingjie.github.io/projects/Perception-as-Control/index.html}, 137 | year={2025}} 138 | ``` 139 | 140 | ## Acknowledgements 141 | 142 | We would like to thank the contributors to [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone), [Depth-Anything-V2](https://github.com/DepthAnything/Depth-Anything-V2), [SpaTracker](https://github.com/henry123-boy/SpaTracker), [Tartanvo](https://github.com/castacks/tartanvo), [diffusers](https://github.com/huggingface/diffusers) for their open research and exploration. -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/assets/teaser.png -------------------------------------------------------------------------------- /cam_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import copy 5 | import math 6 | import shutil 7 | import random 8 | import argparse 9 | import subprocess 10 | from PIL import Image 11 | from scipy.spatial.transform import Rotation as R 12 | from scipy.interpolate import interp1d 13 | 14 | import torch 15 | import numpy as np 16 | # import open3d as o3d 17 | import pytorch3d 18 | from pytorch3d.structures import Meshes 19 | from pytorch3d.renderer import (look_at_view_transform, FoVPerspectiveCameras, 20 | PointLights, RasterizationSettings, PointsRasterizer, PointsRasterizationSettings, 21 | MeshRenderer, MeshRasterizer, SoftPhongShader, PointsRenderer, AlphaCompositor, 22 | blending, Textures) 23 | from pytorch3d.io import load_obj 24 | 25 | from depth_anything_v2.metric_depth_estimation import depth_estimation 26 | 27 | 28 | def draw_points(image, points, output_path): 29 | image = image.copy() 30 | for point in points: 31 | u, v = map(int, point[:2]) 32 | image = cv2.circle(image, (u, v), 5, (0, 0, 255), -1) 33 | 34 | return image 35 | 36 | 37 | class Render(): 38 | def __init__(self, img_size=(320, 576), focal=100, device='cpu', 39 | ply_path=None, uvmap_obj_path=None, is_pointcloud=False): 40 | 41 | self.img_size = img_size 42 | self.focal = focal 43 | self.device = device 44 | self.is_pointcloud = is_pointcloud 45 | 46 | if self.is_pointcloud: 47 | if ply_path != None: 48 | self.model = self.load_point_cloud(ply_path) 49 | else: 50 | if uvmap_obj_path != None: 51 | self.model = self.load_mesh(uvmap_obj_path) 52 | 53 | self.set_renderer() 54 | self.get_renderer() 55 | 56 | def get_points(self): 57 | return self.vts.detach().clone().cpu().numpy() 58 | 59 | def load_point_cloud(self, ply_path): 60 | ply = 
o3d.io.read_point_cloud(ply_path) 61 | self.vts = torch.Tensor(np.asarray(ply.points)).to(self.device) 62 | rgbs = np.asarray(ply.colors)[:, ::-1].copy() 63 | self.rgbs = torch.Tensor(rgbs).to(self.device) * 255 64 | 65 | colored_pointclouds = pytorch3d.structures.Pointclouds(points=[self.vts], features=[self.rgbs]) 66 | colored_pointclouds.center = torch.tensor([0, 0, 0]) 67 | 68 | return colored_pointclouds 69 | 70 | def load_mesh(self, uvmap_obj_path): 71 | batch_size = 1 72 | 73 | verts, faces, aux = load_obj(uvmap_obj_path) 74 | verts_uvs = aux.verts_uvs[None, ...].repeat(batch_size, 1, 1) 75 | faces_uvs = faces.textures_idx[None, ...].repeat(batch_size, 1, 1) 76 | 77 | tex_maps = aux.texture_images 78 | texture_image = list(tex_maps.values())[0] 79 | texture_image = texture_image * 255.0 80 | texture_maps = torch.tensor(texture_image).unsqueeze(0).repeat( 81 | batch_size, 1, 1, 1).float() 82 | tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_maps).to(self.device) 83 | 84 | mesh = Meshes( 85 | verts=torch.tensor(verts[None, ...]).float().to(self.device), 86 | faces=torch.tensor(faces.verts_idx[None, ...]).float().to(self.device), 87 | textures=tex 88 | ) 89 | 90 | return mesh 91 | 92 | def update_view(self, distance=1.0, pitch=0.0, azimuth=0.0, roll=0.0, target_point=[0.0, 0.0, 1.0]): 93 | target_point = torch.tensor(target_point, device=self.device) 94 | 95 | pitch_rad = math.radians(pitch) 96 | azimuth_rad = math.radians(azimuth) 97 | 98 | z = distance * math.cos(pitch_rad) * math.cos(azimuth_rad) 99 | x = distance * math.cos(pitch_rad) * math.sin(azimuth_rad) 100 | y = distance * math.sin(pitch_rad) 101 | self.camera_position = target_point - torch.tensor([x, y, z], device=self.device) 102 | 103 | R, T = look_at_view_transform( 104 | eye=self.camera_position.unsqueeze(0), 105 | at=target_point.unsqueeze(0), 106 | up=torch.tensor([0.0, 1.0, 0.0], device=self.device).unsqueeze(0), 107 | device=self.device 108 | ) 109 | 110 | def rotate_roll(R, roll): 111 | roll_rad = math.radians(roll) 112 | roll_matrix = torch.tensor([ 113 | [math.cos(roll_rad), -math.sin(roll_rad), 0], 114 | [math.sin(roll_rad), math.cos(roll_rad), 0], 115 | [0, 0, 1], 116 | ], device=self.device) 117 | return torch.matmul(R, roll_matrix) 118 | 119 | if roll != 0: 120 | R = rotate_roll(R, roll) 121 | 122 | self.cameras.R = R 123 | self.cameras.T = T 124 | 125 | return R.cpu().squeeze().numpy(), T.cpu().squeeze().numpy() 126 | 127 | def update_RT(self, RT, S=np.eye(3)): 128 | 129 | if self.is_pointcloud: 130 | R_norm = torch.tensor([[[1., 0., 0.], 131 | [0., 1., 0.], 132 | [0., 0., 1.]]]) 133 | T_norm = torch.tensor([[0., 0., 0.]]) 134 | else: 135 | R_norm, T_norm = look_at_view_transform(-50, 0, 0) 136 | 137 | RT_norm = torch.cat((R_norm.squeeze(0), T_norm.squeeze(0)[..., None]), 138 | 1).numpy() 139 | RT_norm = np.concatenate((RT_norm, np.array([0, 0, 0, 1])[None, ...]), 140 | 0) 141 | 142 | R = RT[:, :3] 143 | T = RT[:, 3] 144 | R = S @ R @ S 145 | T = S @ T 146 | 147 | RT = np.concatenate((R, T[..., None]), 1) 148 | RT = np.concatenate((RT, np.array([0, 0, 0, 1])[None, ...]), 0) 149 | 150 | RT = RT @ RT_norm 151 | R = RT[:3, :3][None, ...] 152 | T = RT[:3, 3][None, ...] 
153 | 154 | self.cameras.R = torch.tensor(R).to(self.device) 155 | self.cameras.T = torch.tensor(T).to(self.device) 156 | 157 | return R, T 158 | 159 | def interpolate_slerp(self, RTs, T=16): 160 | Rs = [RT[0] for RT in RTs] 161 | Ts = [RT[1] for RT in RTs] 162 | 163 | times = np.linspace(0, 1, len(RTs)) 164 | interp_times = np.linspace(0, 1, T) 165 | quaternions = [R.from_matrix(R_).as_quat() for R_ in Rs] 166 | interp_Ts = interp1d(times, Ts, axis=0)(interp_times) 167 | 168 | def slerp(q1, q2, t): 169 | q1 = R.from_quat(q1) 170 | q2 = R.from_quat(q2) 171 | return (q1 * (q2 * q1.inv()) ** t).as_quat() 172 | 173 | interp_quaternions = [] 174 | for i in range(len(quaternions) - 1): 175 | t = (interp_times - times[i]) / (times[i + 1] - times[i]) 176 | valid_t = t[(t >= 0) & (t <= 1)] 177 | for t_val in valid_t: 178 | interp_quaternions.append(slerp(quaternions[i], quaternions[i + 1], t_val).squeeze()) 179 | 180 | if len(interp_quaternions) < T: 181 | interp_quaternions.extend([quaternions[-1]] * (T - len(interp_quaternions))) 182 | 183 | interp_quaternions = np.array(interp_quaternions) 184 | interp_Rs = [R.from_quat(q).as_matrix() for q in interp_quaternions] 185 | interpolated_poses = [] 186 | for t in range(T): 187 | interpolated_poses.append((interp_Rs[t].squeeze(), interp_Ts[t].squeeze())) 188 | 189 | return interpolated_poses 190 | 191 | def set_renderer(self): 192 | 193 | self.lights = PointLights(device=self.device, 194 | location=[[0.0, 0.0, 1e5]], 195 | ambient_color=[[1, 1, 1]], 196 | specular_color=[[0., 0., 0.]], 197 | diffuse_color=[[0., 0., 0.]]) 198 | 199 | if self.is_pointcloud: 200 | self.raster_settings = PointsRasterizationSettings( 201 | image_size=self.img_size, 202 | radius=0.01, 203 | points_per_pixel=10, 204 | bin_size=0, 205 | ) 206 | else: 207 | self.raster_settings = RasterizationSettings( 208 | image_size=self.img_size, 209 | blur_radius=0.0, 210 | faces_per_pixel=10, 211 | bin_size=0, 212 | ) 213 | self.blend_params = blending.BlendParams(background_color=[0, 0, 0]) 214 | 215 | R_norm = torch.tensor([[[1., 0., 0.], 216 | [0., 1., 0.], 217 | [0., 0., 1.]]]) 218 | T_norm = torch.tensor([[0., 0., 0.]]) 219 | 220 | self.cameras = FoVPerspectiveCameras( 221 | device=self.device, 222 | R=R_norm, 223 | T=T_norm, 224 | znear=0.01, 225 | zfar=200, 226 | fov=2 * np.arctan(self.img_size[0] // 2 / self.focal) * 180. 
/ np.pi, 227 | aspect_ratio=1.0, 228 | ) 229 | 230 | return 231 | 232 | def get_renderer(self): 233 | if self.is_pointcloud: 234 | 235 | self.renderer = PointsRenderer( 236 | rasterizer=PointsRasterizer(cameras=self.cameras, raster_settings=self.raster_settings), 237 | compositor=AlphaCompositor() 238 | ) 239 | else: 240 | self.renderer = MeshRenderer(rasterizer=MeshRasterizer( 241 | cameras=self.cameras, raster_settings=self.raster_settings), 242 | shader=SoftPhongShader(device=self.device, 243 | cameras=self.cameras, 244 | lights=self.lights, 245 | blend_params=self.blend_params)) 246 | 247 | return 248 | 249 | def render(self): 250 | rendered_img = self.renderer(self.model).cpu() 251 | return rendered_img 252 | 253 | def project_vs(self, coords_3d=None): 254 | if coords_3d is None: 255 | coords_3d = self.model.points_padded()[0] 256 | else: 257 | coords_3d = torch.tensor(coords_3d).to(self.device).float() 258 | 259 | coords_3d = self.cameras.transform_points_screen(coords_3d, image_size=self.img_size) 260 | 261 | return coords_3d 262 | 263 | def get_vertical_distances(self, points_world): 264 | R = self.cameras.R.cpu().numpy()[0] 265 | T = self.cameras.T.cpu().numpy()[0] 266 | T = T[:, np.newaxis] 267 | points_camera = (R @ points_world.T + T).T 268 | vertical_distances = points_camera[:, 2] 269 | 270 | return vertical_distances 271 | 272 | 273 | def generate_query(img_size, video_length): 274 | 275 | height, width = img_size[0], img_size[1] 276 | query_points = [] 277 | for i in range(0, width, width // 30): 278 | for j in range(0, height, height // 30): 279 | query_points.append([i, j]) 280 | query_points = np.array(query_points)[None, ...] 281 | query_points = np.repeat(query_points, video_length, 0) 282 | return query_points 283 | 284 | 285 | def estimate_depth(imagepath, query_points, img_size, focal, output_dir): 286 | 287 | height, width = img_size[0], img_size[1] 288 | imagename = imagepath.split('/')[-1].split('.')[0] 289 | 290 | metric_depth, relative_depth, _ = depth_estimation(imagepath) 291 | 292 | query_points = np.round(query_points) 293 | relative_points_depth = [relative_depth[int(p[1]), int(p[0])] if 0 <= int(p[1]) < relative_depth.shape[0] and 0 <= int(p[0]) < relative_depth.shape[1] else 0.5 for p in query_points[0]] 294 | relative_points_depth = np.array(relative_points_depth)[None, ...] 295 | relative_points_depth = np.repeat(relative_points_depth, query_points.shape[0], 0)[..., None] 296 | 297 | points_depth = [metric_depth[int(p[1]), int(p[0])] if 0 <= int(p[1]) < metric_depth.shape[0] and 0 <= int(p[0]) < metric_depth.shape[1] else 0.5 for p in query_points[0]] 298 | points_depth = np.array(points_depth)[None, ...] 
299 | points_depth = np.repeat(points_depth, query_points.shape[0], 0)[..., None] 300 | points_3d = np.concatenate([query_points, points_depth], -1) 301 | 302 | points_3d[:, :, 0] = -(points_3d[:, :, 0] - width / 2) / focal 303 | points_3d[:, :, 1] = -(points_3d[:, :, 1] - height / 2) / focal 304 | 305 | points_3d[:, :, 0] *= points_3d[:, :, 2] 306 | points_3d[:, :, 1] *= points_3d[:, :, 2] 307 | 308 | return points_3d, relative_points_depth 309 | 310 | 311 | def cam_adaptation(imagepath, vo_path, query_path, cam_type, video_length, output_dir): 312 | 313 | os.makedirs(output_dir, exist_ok=True) 314 | 315 | tmp_folder = './tmp_cam' 316 | if os.path.exists(tmp_folder): 317 | shutil.rmtree(tmp_folder) 318 | os.makedirs(tmp_folder, exist_ok=False) 319 | 320 | img = Image.open(imagepath).convert('RGB') 321 | img_size = (img.size[1], img.size[0]) 322 | 323 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 324 | 325 | S = np.eye(3) 326 | focal_point = 500 327 | uvmap_obj_path = './mesh/world_envelope.obj' 328 | 329 | renderer = Render(img_size=img_size, device=device, uvmap_obj_path=uvmap_obj_path, is_pointcloud=False) 330 | renderer_point = Render(img_size=img_size, focal=focal_point, device=device, is_pointcloud=True) 331 | 332 | if cam_type == 'generated': 333 | pattern_name = vo_path.split('/')[-1].split('.')[0] 334 | RTs = generate_vo(renderer_point, pattern_name) 335 | else: 336 | camera_poses = np.loadtxt(vo_path).reshape(-1, 3, 4) 337 | RTs = [(pose[:, :3], pose[:, 3]) for pose in camera_poses] 338 | 339 | if len(RTs) < video_length: 340 | interpolated_poses = renderer.interpolate_slerp(RTs, T=video_length) 341 | else: 342 | interpolated_poses = RTs 343 | interpolated_camera_poses = [] 344 | for i, pose in enumerate(interpolated_poses): 345 | R = pose[0] 346 | T = pose[1] 347 | RT = np.concatenate((R, T[..., None]), 1) 348 | interpolated_camera_poses.append(RT) 349 | interpolated_camera_poses = np.array(interpolated_camera_poses) 350 | camera_poses = interpolated_camera_poses 351 | 352 | if query_path == 'none': 353 | query_points = generate_query(img_size, video_length) 354 | else: 355 | query_points = np.load(query_path)[:, :, :2] 356 | 357 | points_3d, points_depth = estimate_depth(imagepath, query_points, img_size, focal_point, output_dir) 358 | 359 | projected_3d_points = [] 360 | for idx, RT in enumerate(camera_poses): 361 | renderer.update_RT(RT, S) 362 | rendered_img = renderer.render().squeeze().numpy()[:, :, :3] 363 | 364 | renderer_point.update_RT(RT, S) 365 | projected_3d_point = renderer_point.project_vs(points_3d[idx]).cpu().numpy() 366 | projected_3d_point_depth = renderer.get_vertical_distances(points_3d[idx]) 367 | projected_3d_point[:, 2] = projected_3d_point_depth 368 | projected_3d_points.append(projected_3d_point) 369 | 370 | cv2.imwrite(os.path.join(tmp_folder, str(idx).zfill(5) + ".png"), rendered_img) 371 | 372 | cam_path = "./tmp/rendered_cam.mp4" 373 | track_path = "./tmp/rendered_tracks.npy" 374 | 375 | projected_3d_points = np.array(projected_3d_points) 376 | np.save(track_path, projected_3d_points) 377 | 378 | subprocess.call( 379 | "ffmpeg -loglevel quiet -i {} -c:v h264 -pix_fmt yuv420p -y {}".format( 380 | os.path.join(tmp_folder, "%05d.png"), cam_path), 381 | shell=True) 382 | shutil.rmtree(tmp_folder) 383 | 384 | return projected_3d_points, points_depth, cam_path 385 | 386 | 387 | if __name__ == '__main__': 388 | parser = argparse.ArgumentParser() 389 | 390 | # set the gpu 391 | parser.add_argument('--gpu', type=int, default=0, help='gpu id') 392 
| # set start idx & end idx 393 | parser.add_argument('--imagepath', type=str, default='./assets/boat3.jpg', help='image path') 394 | parser.add_argument('--vo_path', type=str, default='./cam_patterns/PanDown.txt', help='video path') 395 | parser.add_argument('--query_path', type=str, default='none', help='quey point path') 396 | parser.add_argument('--output_dir', type=str, default='./tmp', help='output dir') 397 | parser.add_argument('--video_length', type=int, default=16, help='camera pose length') 398 | parser.add_argument('--cam_type', type=str, default='generated', help='generate or user-provided') 399 | 400 | args = parser.parse_args() 401 | 402 | # set the gpu 403 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 404 | 405 | os.makedirs(args.output_dir, exist_ok=True) 406 | 407 | tmp_folder = './tmp_cam' 408 | if os.path.exists(tmp_folder): 409 | shutil.rmtree(tmp_folder) 410 | os.makedirs(tmp_folder, exist_ok=False) 411 | 412 | img = Image.open(args.imagepath).convert('RGB') 413 | img_size = (img.size[1], img.size[0]) 414 | 415 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 416 | 417 | S = np.eye(3) 418 | focal_point = 500 419 | uvmap_obj_path = './mesh/world_envelope.obj' 420 | 421 | renderer = Render(img_size=img_size, device=device, uvmap_obj_path=uvmap_obj_path, is_pointcloud=False) 422 | renderer_point = Render(img_size=img_size, focal=focal_point, device=device, is_pointcloud=True) 423 | 424 | if args.cam_type == 'generated': 425 | pattern_name = args.vo_path.split('/')[-1].split('.')[0] 426 | # RTs = generate_vo(renderer_point, pattern_name) 427 | raise NotImplementedError 428 | else: 429 | camera_poses = np.loadtxt(args.vo_path).reshape(-1, 3, 4) 430 | RTs = [(pose[:, :3], pose[:, 3]) for pose in camera_poses] 431 | 432 | if len(RTs) < args.video_length: 433 | interpolated_poses = renderer.interpolate_slerp(RTs, T=args.video_length) 434 | else: 435 | interpolated_poses = RTs 436 | interpolated_camera_poses = [] 437 | for i, pose in enumerate(interpolated_poses): 438 | R = pose[0] 439 | T = pose[1] 440 | RT = np.concatenate((R, T[..., None]), 1) 441 | interpolated_camera_poses.append(RT) 442 | interpolated_camera_poses = np.array(interpolated_camera_poses) 443 | camera_poses = interpolated_camera_poses 444 | 445 | if args.query_path == 'none': 446 | query_points = generate_query(img_size, args.video_length) 447 | else: 448 | query_points = np.load(args.query_path)[:, :, :2] 449 | 450 | points_3d, _ = estimate_depth(args.imagepath, query_points, img_size, focal_point, args.output_dir) 451 | 452 | projected_3d_points = [] 453 | for idx, RT in enumerate(camera_poses): 454 | renderer.update_RT(RT, S) 455 | rendered_img = renderer.render().squeeze().numpy()[:, :, :3] 456 | 457 | renderer_point.update_RT(RT, S) 458 | projected_3d_point = renderer_point.project_vs(points_3d[idx]).cpu().numpy() 459 | projected_3d_point_depth = renderer.get_vertical_distances(points_3d[idx]) 460 | projected_3d_point[:, 2] = projected_3d_point_depth 461 | projected_3d_points.append(projected_3d_point) 462 | 463 | cv2.imwrite(os.path.join(tmp_folder, str(idx).zfill(5) + ".png"), rendered_img) 464 | 465 | cam_path = "./tmp/rendered_cam.mp4" 466 | track_path = "./tmp/rendered_tracks.npy" 467 | 468 | projected_3d_points = np.array(projected_3d_points) 469 | np.save(track_path, projected_3d_points) 470 | 471 | subprocess.call( 472 | "ffmpeg -loglevel quiet -i {} -c:v h264 -pix_fmt yuv420p -y {}".format( 473 | os.path.join(tmp_folder, "%05d.png"), cam_path), 474 | shell=True) 475 | 
shutil.rmtree(tmp_folder) -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/configs/.DS_Store -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | base_model_path: './pretrained_weights/sd-image-variations-diffusers' 2 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' 4 | mm_path: '' 5 | 6 | denoising_unet_path: ./pretrained_weights/denoising_unet.pth 7 | reference_unet_path: ./pretrained_weights/reference_unet.pth 8 | obj_encoder_path: ./pretrained_weights/obj_encoder.pth 9 | cam_encoder_path: ./pretrained_weights/cam_encoder.pth 10 | 11 | use_lora: false 12 | lora_rank: 64 13 | 14 | inference_config: "./configs/inference/inference_v2.yaml" 15 | weight_dtype: 'fp32' 16 | 17 | is_obj: true 18 | is_cam: true 19 | is_depth: true 20 | is_adapted: true 21 | fusion_type: 'max' 22 | is_pad: true 23 | 24 | W: 768 25 | H: 512 26 | circle_scale: 10 27 | 28 | sample_n_frames: 16 29 | sample_stride: 4 30 | guidance_scale: 3.5 31 | steps: 20 32 | seed: 12580 33 | 34 | save_dir: './outputs' 35 | 36 | sample_n_trajs: -1 37 | cam_only: false 38 | obj_only: false 39 | 40 | remove_tmp_results: true 41 | 42 | test_cases: 43 | 44 | - "./examples/balloons.png": 45 | - "./examples/balloons.json" 46 | - "./examples/Still.txt" 47 | - 12597 48 | 49 | - "./examples/backview.jpeg": 50 | - "./examples/backview.json" 51 | - "./examples/DollyOut.txt" 52 | - 12597 53 | 54 | - "./examples/balloon.png": 55 | - "./examples/balloon.json" 56 | - "./examples/TiltDown.txt" 57 | - 12580 58 | 59 | - "./examples/bamboo.png": 60 | - "./examples/bamboo.json" 61 | - "./examples/Still.txt" 62 | - 12587 -------------------------------------------------------------------------------- /configs/inference/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 32 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | from functools import partial 11 | import math 12 | import logging 13 | from typing import Sequence, Tuple, Union, Callable 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.utils.checkpoint 18 | from torch.nn.init import trunc_normal_ 19 | 20 | from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 27 | if not depth_first and include_root: 28 | fn(module=module, name=name) 29 | for child_name, child_module in module.named_children(): 30 | child_name = ".".join((name, child_name)) if name else child_name 31 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 32 | if depth_first and include_root: 33 | fn(module=module, name=name) 34 | return module 35 | 36 | 37 | class BlockChunk(nn.ModuleList): 38 | def forward(self, x): 39 | for b in self: 40 | x = b(x) 41 | return x 42 | 43 | 44 | class DinoVisionTransformer(nn.Module): 45 | def __init__( 46 | self, 47 | img_size=224, 48 | patch_size=16, 49 | in_chans=3, 50 | embed_dim=768, 51 | depth=12, 52 | num_heads=12, 53 | mlp_ratio=4.0, 54 | qkv_bias=True, 55 | ffn_bias=True, 56 | proj_bias=True, 57 | drop_path_rate=0.0, 58 | drop_path_uniform=False, 59 | init_values=None, # for layerscale: None or 0 => no layerscale 60 | embed_layer=PatchEmbed, 61 | act_layer=nn.GELU, 62 | block_fn=Block, 63 | ffn_layer="mlp", 64 | block_chunks=1, 65 | num_register_tokens=0, 66 | interpolate_antialias=False, 67 | interpolate_offset=0.1, 68 | ): 69 | """ 70 | Args: 71 | img_size (int, tuple): input image size 72 | patch_size (int, tuple): patch size 73 | in_chans (int): number of input channels 74 | embed_dim (int): embedding dimension 75 | depth (int): depth of transformer 76 | num_heads (int): number of attention heads 77 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 78 | qkv_bias (bool): enable bias for qkv if True 79 | proj_bias (bool): enable bias for proj in attn if True 80 | ffn_bias (bool): enable bias for ffn if True 81 | drop_path_rate (float): stochastic depth rate 82 | drop_path_uniform (bool): apply uniform drop rate across blocks 83 | weight_init 
(str): weight init scheme 84 | init_values (float): layer-scale init values 85 | embed_layer (nn.Module): patch embedding layer 86 | act_layer (nn.Module): MLP activation layer 87 | block_fn (nn.Module): transformer block class 88 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 89 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 90 | num_register_tokens: (int) number of extra cls tokens (so-called "registers") 91 | interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings 92 | interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings 93 | """ 94 | super().__init__() 95 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 96 | 97 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 98 | self.num_tokens = 1 99 | self.n_blocks = depth 100 | self.num_heads = num_heads 101 | self.patch_size = patch_size 102 | self.num_register_tokens = num_register_tokens 103 | self.interpolate_antialias = interpolate_antialias 104 | self.interpolate_offset = interpolate_offset 105 | 106 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 107 | num_patches = self.patch_embed.num_patches 108 | 109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 111 | assert num_register_tokens >= 0 112 | self.register_tokens = ( 113 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None 114 | ) 115 | 116 | if drop_path_uniform is True: 117 | dpr = [drop_path_rate] * depth 118 | else: 119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 120 | 121 | if ffn_layer == "mlp": 122 | logger.info("using MLP layer as FFN") 123 | ffn_layer = Mlp 124 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 125 | logger.info("using SwiGLU layer as FFN") 126 | ffn_layer = SwiGLUFFNFused 127 | elif ffn_layer == "identity": 128 | logger.info("using Identity layer as FFN") 129 | 130 | def f(*args, **kwargs): 131 | return nn.Identity() 132 | 133 | ffn_layer = f 134 | else: 135 | raise NotImplementedError 136 | 137 | blocks_list = [ 138 | block_fn( 139 | dim=embed_dim, 140 | num_heads=num_heads, 141 | mlp_ratio=mlp_ratio, 142 | qkv_bias=qkv_bias, 143 | proj_bias=proj_bias, 144 | ffn_bias=ffn_bias, 145 | drop_path=dpr[i], 146 | norm_layer=norm_layer, 147 | act_layer=act_layer, 148 | ffn_layer=ffn_layer, 149 | init_values=init_values, 150 | ) 151 | for i in range(depth) 152 | ] 153 | if block_chunks > 0: 154 | self.chunked_blocks = True 155 | chunked_blocks = [] 156 | chunksize = depth // block_chunks 157 | for i in range(0, depth, chunksize): 158 | # this is to keep the block index consistent if we chunk the block list 159 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) 160 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 161 | else: 162 | self.chunked_blocks = False 163 | self.blocks = nn.ModuleList(blocks_list) 164 | 165 | self.norm = norm_layer(embed_dim) 166 | self.head = nn.Identity() 167 | 168 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 169 | 170 | self.init_weights() 171 | 172 | def init_weights(self): 173 | trunc_normal_(self.pos_embed, std=0.02) 174 | nn.init.normal_(self.cls_token, std=1e-6) 175 | if self.register_tokens is not None: 176 | 
nn.init.normal_(self.register_tokens, std=1e-6) 177 | named_apply(init_weights_vit_timm, self) 178 | 179 | def interpolate_pos_encoding(self, x, w, h): 180 | previous_dtype = x.dtype 181 | npatch = x.shape[1] - 1 182 | N = self.pos_embed.shape[1] - 1 183 | if npatch == N and w == h: 184 | return self.pos_embed 185 | pos_embed = self.pos_embed.float() 186 | class_pos_embed = pos_embed[:, 0] 187 | patch_pos_embed = pos_embed[:, 1:] 188 | dim = x.shape[-1] 189 | w0 = w // self.patch_size 190 | h0 = h // self.patch_size 191 | # we add a small number to avoid floating point error in the interpolation 192 | # see discussion at https://github.com/facebookresearch/dino/issues/8 193 | # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 194 | w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset 195 | # w0, h0 = w0 + 0.1, h0 + 0.1 196 | 197 | sqrt_N = math.sqrt(N) 198 | sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N 199 | patch_pos_embed = nn.functional.interpolate( 200 | patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), 201 | scale_factor=(sx, sy), 202 | # (int(w0), int(h0)), # to solve the upsampling shape issue 203 | mode="bicubic", 204 | antialias=self.interpolate_antialias 205 | ) 206 | 207 | assert int(w0) == patch_pos_embed.shape[-2] 208 | assert int(h0) == patch_pos_embed.shape[-1] 209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 211 | 212 | def prepare_tokens_with_masks(self, x, masks=None): 213 | B, nc, w, h = x.shape 214 | x = self.patch_embed(x) 215 | if masks is not None: 216 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 217 | 218 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 219 | x = x + self.interpolate_pos_encoding(x, w, h) 220 | 221 | if self.register_tokens is not None: 222 | x = torch.cat( 223 | ( 224 | x[:, :1], 225 | self.register_tokens.expand(x.shape[0], -1, -1), 226 | x[:, 1:], 227 | ), 228 | dim=1, 229 | ) 230 | 231 | return x 232 | 233 | def forward_features_list(self, x_list, masks_list): 234 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] 235 | for blk in self.blocks: 236 | x = blk(x) 237 | 238 | all_x = x 239 | output = [] 240 | for x, masks in zip(all_x, masks_list): 241 | x_norm = self.norm(x) 242 | output.append( 243 | { 244 | "x_norm_clstoken": x_norm[:, 0], 245 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], 246 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 247 | "x_prenorm": x, 248 | "masks": masks, 249 | } 250 | ) 251 | return output 252 | 253 | def forward_features(self, x, masks=None): 254 | if isinstance(x, list): 255 | return self.forward_features_list(x, masks) 256 | 257 | x = self.prepare_tokens_with_masks(x, masks) 258 | 259 | for blk in self.blocks: 260 | x = blk(x) 261 | 262 | x_norm = self.norm(x) 263 | return { 264 | "x_norm_clstoken": x_norm[:, 0], 265 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], 266 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 267 | "x_prenorm": x, 268 | "masks": masks, 269 | } 270 | 271 | def _get_intermediate_layers_not_chunked(self, x, n=1): 272 | x = self.prepare_tokens_with_masks(x) 273 | # If n is an int, take the n last blocks. 
If it's a list, take them 274 | output, total_block_len = [], len(self.blocks) 275 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 276 | for i, blk in enumerate(self.blocks): 277 | x = blk(x) 278 | if i in blocks_to_take: 279 | output.append(x) 280 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 281 | return output 282 | 283 | def _get_intermediate_layers_chunked(self, x, n=1): 284 | x = self.prepare_tokens_with_masks(x) 285 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 286 | # If n is an int, take the n last blocks. If it's a list, take them 287 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 288 | for block_chunk in self.blocks: 289 | for blk in block_chunk[i:]: # Passing the nn.Identity() 290 | x = blk(x) 291 | if i in blocks_to_take: 292 | output.append(x) 293 | i += 1 294 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 295 | return output 296 | 297 | def get_intermediate_layers( 298 | self, 299 | x: torch.Tensor, 300 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 301 | reshape: bool = False, 302 | return_class_token: bool = False, 303 | norm=True 304 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 305 | if self.chunked_blocks: 306 | outputs = self._get_intermediate_layers_chunked(x, n) 307 | else: 308 | outputs = self._get_intermediate_layers_not_chunked(x, n) 309 | if norm: 310 | outputs = [self.norm(out) for out in outputs] 311 | class_tokens = [out[:, 0] for out in outputs] 312 | outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] 313 | if reshape: 314 | B, _, w, h = x.shape 315 | outputs = [ 316 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() 317 | for out in outputs 318 | ] 319 | if return_class_token: 320 | return tuple(zip(outputs, class_tokens)) 321 | return tuple(outputs) 322 | 323 | def forward(self, *args, is_training=False, **kwargs): 324 | ret = self.forward_features(*args, **kwargs) 325 | if is_training: 326 | return ret 327 | else: 328 | return self.head(ret["x_norm_clstoken"]) 329 | 330 | 331 | def init_weights_vit_timm(module: nn.Module, name: str = ""): 332 | """ViT weight initialization, original timm impl (for reproducibility)""" 333 | if isinstance(module, nn.Linear): 334 | trunc_normal_(module.weight, std=0.02) 335 | if module.bias is not None: 336 | nn.init.zeros_(module.bias) 337 | 338 | 339 | def vit_small(patch_size=16, num_register_tokens=0, **kwargs): 340 | model = DinoVisionTransformer( 341 | patch_size=patch_size, 342 | embed_dim=384, 343 | depth=12, 344 | num_heads=6, 345 | mlp_ratio=4, 346 | block_fn=partial(Block, attn_class=MemEffAttention), 347 | num_register_tokens=num_register_tokens, 348 | **kwargs, 349 | ) 350 | return model 351 | 352 | 353 | def vit_base(patch_size=16, num_register_tokens=0, **kwargs): 354 | model = DinoVisionTransformer( 355 | patch_size=patch_size, 356 | embed_dim=768, 357 | depth=12, 358 | num_heads=12, 359 | mlp_ratio=4, 360 | block_fn=partial(Block, attn_class=MemEffAttention), 361 | num_register_tokens=num_register_tokens, 362 | **kwargs, 363 | ) 364 | return model 365 | 366 | 367 | def vit_large(patch_size=16, num_register_tokens=0, **kwargs): 368 | model = DinoVisionTransformer( 369 | patch_size=patch_size, 370 | embed_dim=1024, 371 | depth=24, 372 | num_heads=16, 373 | mlp_ratio=4, 374 | block_fn=partial(Block, 
attn_class=MemEffAttention), 375 | num_register_tokens=num_register_tokens, 376 | **kwargs, 377 | ) 378 | return model 379 | 380 | 381 | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): 382 | """ 383 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 384 | """ 385 | model = DinoVisionTransformer( 386 | patch_size=patch_size, 387 | embed_dim=1536, 388 | depth=40, 389 | num_heads=24, 390 | mlp_ratio=4, 391 | block_fn=partial(Block, attn_class=MemEffAttention), 392 | num_register_tokens=num_register_tokens, 393 | **kwargs, 394 | ) 395 | return model 396 | 397 | 398 | def DINOv2(model_name): 399 | model_zoo = { 400 | "vits": vit_small, 401 | "vitb": vit_base, 402 | "vitl": vit_large, 403 | "vitg": vit_giant2 404 | } 405 | 406 | return model_zoo[model_name]( 407 | img_size=518, 408 | patch_size=14, 409 | init_values=1.0, 410 | ffn_layer="mlp" if model_name != "vitg" else "swiglufused", 411 | block_chunks=0, 412 | num_register_tokens=0, 413 | interpolate_antialias=False, 414 | interpolate_offset=0.1 415 | ) -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc 
-------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | 83 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 
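# --- Editorial comment, not in the original file ---------------------------
# drop_add_residual_stochastic_depth(_list) above evaluates the residual branch
# for only a random subset of the batch and adds it back via index_add with
# alpha = residual_scale_factor = b / sample_subset_size, so the update matches
# the full-batch residual in expectation. For example, with b = 8 samples and
# sample_drop_ratio = 0.25, the residual is computed for
# max(int(8 * (1 - 0.25)), 1) = 6 samples and scaled by 8 / 6 ≈ 1.33.
# ---------------------------------------------------------------------------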
203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dinov2_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/dpt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision.transforms import Compose 6 | 7 | from .dinov2 import DINOv2 8 | from .util.blocks import FeatureFusionBlock, _make_scratch 9 | from .util.transform import Resize, NormalizeImage, PrepareForNet 10 | 11 | 12 | def _make_fusion_block(features, use_bn, size=None): 13 | return FeatureFusionBlock( 14 | features, 15 | nn.ReLU(False), 16 | deconv=False, 17 | bn=use_bn, 18 | expand=False, 19 | align_corners=True, 20 | size=size, 21 | ) 22 | 23 | 24 | class ConvBlock(nn.Module): 25 | def __init__(self, in_feature, out_feature): 26 | super().__init__() 27 | 28 | self.conv_block = nn.Sequential( 29 | nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), 30 | nn.BatchNorm2d(out_feature), 31 | nn.ReLU(True) 32 | ) 33 | 34 | def forward(self, x): 35 | return self.conv_block(x) 36 | 37 | 38 | class DPTHead(nn.Module): 39 | def __init__( 40 | self, 41 | in_channels, 42 | features=256, 43 | use_bn=False, 44 | out_channels=[256, 512, 1024, 1024], 45 | use_clstoken=False 46 | ): 47 | super(DPTHead, self).__init__() 48 | 49 | self.use_clstoken = use_clstoken 50 | 51 | self.projects = nn.ModuleList([ 52 | nn.Conv2d( 53 | in_channels=in_channels, 54 | out_channels=out_channel, 55 | kernel_size=1, 56 | stride=1, 57 | padding=0, 58 | ) for out_channel in out_channels 59 | ]) 60 | 61 | self.resize_layers = nn.ModuleList([ 62 | nn.ConvTranspose2d( 63 | in_channels=out_channels[0], 64 | out_channels=out_channels[0], 65 | kernel_size=4, 66 | stride=4, 67 | padding=0), 68 | nn.ConvTranspose2d( 69 | in_channels=out_channels[1], 70 | out_channels=out_channels[1], 71 | kernel_size=2, 
72 | stride=2, 73 | padding=0), 74 | nn.Identity(), 75 | nn.Conv2d( 76 | in_channels=out_channels[3], 77 | out_channels=out_channels[3], 78 | kernel_size=3, 79 | stride=2, 80 | padding=1) 81 | ]) 82 | 83 | if use_clstoken: 84 | self.readout_projects = nn.ModuleList() 85 | for _ in range(len(self.projects)): 86 | self.readout_projects.append( 87 | nn.Sequential( 88 | nn.Linear(2 * in_channels, in_channels), 89 | nn.GELU())) 90 | 91 | self.scratch = _make_scratch( 92 | out_channels, 93 | features, 94 | groups=1, 95 | expand=False, 96 | ) 97 | 98 | self.scratch.stem_transpose = None 99 | 100 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 101 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 102 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 103 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 104 | 105 | head_features_1 = features 106 | head_features_2 = 32 107 | 108 | self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) 109 | self.scratch.output_conv2 = nn.Sequential( 110 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), 111 | nn.ReLU(True), 112 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), 113 | nn.Sigmoid() 114 | ) 115 | 116 | def forward(self, out_features, patch_h, patch_w): 117 | out = [] 118 | for i, x in enumerate(out_features): 119 | if self.use_clstoken: 120 | x, cls_token = x[0], x[1] 121 | readout = cls_token.unsqueeze(1).expand_as(x) 122 | x = self.readout_projects[i](torch.cat((x, readout), -1)) 123 | else: 124 | x = x[0] 125 | 126 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) 127 | 128 | x = self.projects[i](x) 129 | x = self.resize_layers[i](x) 130 | 131 | out.append(x) 132 | 133 | layer_1, layer_2, layer_3, layer_4 = out 134 | 135 | layer_1_rn = self.scratch.layer1_rn(layer_1) 136 | layer_2_rn = self.scratch.layer2_rn(layer_2) 137 | layer_3_rn = self.scratch.layer3_rn(layer_3) 138 | layer_4_rn = self.scratch.layer4_rn(layer_4) 139 | 140 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) 141 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) 142 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) 143 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 144 | 145 | out = self.scratch.output_conv1(path_1) 146 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) 147 | out = self.scratch.output_conv2(out) 148 | 149 | return out 150 | 151 | 152 | class DepthAnythingV2(nn.Module): 153 | def __init__( 154 | self, 155 | encoder='vitl', 156 | features=256, 157 | out_channels=[256, 512, 1024, 1024], 158 | use_bn=False, 159 | use_clstoken=False, 160 | max_depth=20.0 161 | ): 162 | super(DepthAnythingV2, self).__init__() 163 | 164 | self.intermediate_layer_idx = { 165 | 'vits': [2, 5, 8, 11], 166 | 'vitb': [2, 5, 8, 11], 167 | 'vitl': [4, 11, 17, 23], 168 | 'vitg': [9, 19, 29, 39] 169 | } 170 | 171 | self.max_depth = max_depth 172 | 173 | self.encoder = encoder 174 | self.pretrained = DINOv2(model_name=encoder) 175 | 176 | self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) 177 | 178 | def forward(self, x): 179 | patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 180 | 181 | features = self.pretrained.get_intermediate_layers(x, 
self.intermediate_layer_idx[self.encoder], return_class_token=True) 182 | 183 | depth = self.depth_head(features, patch_h, patch_w) * self.max_depth 184 | 185 | return depth.squeeze(1) 186 | 187 | @torch.no_grad() 188 | def infer_image(self, raw_image, input_size=518): 189 | image, (h, w) = self.image2tensor(raw_image, input_size) 190 | 191 | depth = self.forward(image) 192 | 193 | depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] 194 | 195 | return depth.cpu().numpy() 196 | 197 | def image2tensor(self, raw_image, input_size=518): 198 | transform = Compose([ 199 | Resize( 200 | width=input_size, 201 | height=input_size, 202 | resize_target=False, 203 | keep_aspect_ratio=True, 204 | ensure_multiple_of=14, 205 | resize_method='lower_bound', 206 | image_interpolation_method=cv2.INTER_CUBIC, 207 | ), 208 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 209 | PrepareForNet(), 210 | ]) 211 | 212 | h, w = raw_image.shape[:2] 213 | 214 | image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 215 | 216 | image = transform({'image': image})['image'] 217 | image = torch.from_numpy(image).unsqueeze(0) 218 | 219 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 220 | image = image.to(DEVICE) 221 | 222 | return image, (h, w) 223 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/util/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 5 | scratch = nn.Module() 6 | 7 | out_shape1 = out_shape 8 | out_shape2 = out_shape 9 | out_shape3 = out_shape 10 | if len(in_shape) >= 4: 11 | out_shape4 = out_shape 12 | 13 | if expand: 14 | out_shape1 = out_shape 15 | out_shape2 = out_shape * 2 16 | out_shape3 = out_shape * 4 17 | if len(in_shape) >= 4: 18 | out_shape4 = out_shape * 8 19 | 20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 23 | if len(in_shape) >= 4: 24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 25 | 26 | return scratch 27 | 28 | 29 | class ResidualConvUnit(nn.Module): 30 | """Residual convolution module. 
31 | """ 32 | 33 | def __init__(self, features, activation, bn): 34 | """Init. 35 | 36 | Args: 37 | features (int): number of features 38 | """ 39 | super().__init__() 40 | 41 | self.bn = bn 42 | 43 | self.groups=1 44 | 45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 46 | 47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 48 | 49 | if self.bn == True: 50 | self.bn1 = nn.BatchNorm2d(features) 51 | self.bn2 = nn.BatchNorm2d(features) 52 | 53 | self.activation = activation 54 | 55 | self.skip_add = nn.quantized.FloatFunctional() 56 | 57 | def forward(self, x): 58 | """Forward pass. 59 | 60 | Args: 61 | x (tensor): input 62 | 63 | Returns: 64 | tensor: output 65 | """ 66 | 67 | out = self.activation(x) 68 | out = self.conv1(out) 69 | if self.bn == True: 70 | out = self.bn1(out) 71 | 72 | out = self.activation(out) 73 | out = self.conv2(out) 74 | if self.bn == True: 75 | out = self.bn2(out) 76 | 77 | if self.groups > 1: 78 | out = self.conv_merge(out) 79 | 80 | return self.skip_add.add(out, x) 81 | 82 | 83 | class FeatureFusionBlock(nn.Module): 84 | """Feature fusion block. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | features, 90 | activation, 91 | deconv=False, 92 | bn=False, 93 | expand=False, 94 | align_corners=True, 95 | size=None 96 | ): 97 | """Init. 98 | 99 | Args: 100 | features (int): number of features 101 | """ 102 | super(FeatureFusionBlock, self).__init__() 103 | 104 | self.deconv = deconv 105 | self.align_corners = align_corners 106 | 107 | self.groups=1 108 | 109 | self.expand = expand 110 | out_features = features 111 | if self.expand == True: 112 | out_features = features // 2 113 | 114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 115 | 116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 118 | 119 | self.skip_add = nn.quantized.FloatFunctional() 120 | 121 | self.size=size 122 | 123 | def forward(self, *xs, size=None): 124 | """Forward pass. 125 | 126 | Returns: 127 | tensor: output 128 | """ 129 | output = xs[0] 130 | 131 | if len(xs) == 2: 132 | res = self.resConfUnit1(xs[1]) 133 | output = self.skip_add.add(output, res) 134 | 135 | output = self.resConfUnit2(output) 136 | 137 | if (size is None) and (self.size is None): 138 | modifier = {"scale_factor": 2} 139 | elif size is None: 140 | modifier = {"size": self.size} 141 | else: 142 | modifier = {"size": size} 143 | 144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 145 | 146 | output = self.out_conv(output) 147 | 148 | return output 149 | -------------------------------------------------------------------------------- /depth_anything_v2/depth_anything_v2/util/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | class Resize(object): 6 | """Resize sample to given size (width, height). 7 | """ 8 | 9 | def __init__( 10 | self, 11 | width, 12 | height, 13 | resize_target=True, 14 | keep_aspect_ratio=False, 15 | ensure_multiple_of=1, 16 | resize_method="lower_bound", 17 | image_interpolation_method=cv2.INTER_AREA, 18 | ): 19 | """Init. 
20 | 21 | Args: 22 | width (int): desired output width 23 | height (int): desired output height 24 | resize_target (bool, optional): 25 | True: Resize the full sample (image, mask, target). 26 | False: Resize image only. 27 | Defaults to True. 28 | keep_aspect_ratio (bool, optional): 29 | True: Keep the aspect ratio of the input sample. 30 | Output sample might not have the given width and height, and 31 | resize behaviour depends on the parameter 'resize_method'. 32 | Defaults to False. 33 | ensure_multiple_of (int, optional): 34 | Output width and height is constrained to be multiple of this parameter. 35 | Defaults to 1. 36 | resize_method (str, optional): 37 | "lower_bound": Output will be at least as large as the given size. 38 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 39 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 40 | Defaults to "lower_bound". 41 | """ 42 | self.__width = width 43 | self.__height = height 44 | 45 | self.__resize_target = resize_target 46 | self.__keep_aspect_ratio = keep_aspect_ratio 47 | self.__multiple_of = ensure_multiple_of 48 | self.__resize_method = resize_method 49 | self.__image_interpolation_method = image_interpolation_method 50 | 51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 53 | 54 | if max_val is not None and y > max_val: 55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 56 | 57 | if y < min_val: 58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 59 | 60 | return y 61 | 62 | def get_size(self, width, height): 63 | # determine new height and width 64 | scale_height = self.__height / height 65 | scale_width = self.__width / width 66 | 67 | if self.__keep_aspect_ratio: 68 | if self.__resize_method == "lower_bound": 69 | # scale such that output size is lower bound 70 | if scale_width > scale_height: 71 | # fit width 72 | scale_height = scale_width 73 | else: 74 | # fit height 75 | scale_width = scale_height 76 | elif self.__resize_method == "upper_bound": 77 | # scale such that output size is upper bound 78 | if scale_width < scale_height: 79 | # fit width 80 | scale_height = scale_width 81 | else: 82 | # fit height 83 | scale_width = scale_height 84 | elif self.__resize_method == "minimal": 85 | # scale as least as possbile 86 | if abs(1 - scale_width) < abs(1 - scale_height): 87 | # fit width 88 | scale_height = scale_width 89 | else: 90 | # fit height 91 | scale_width = scale_height 92 | else: 93 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 94 | 95 | if self.__resize_method == "lower_bound": 96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) 97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) 98 | elif self.__resize_method == "upper_bound": 99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) 100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) 101 | elif self.__resize_method == "minimal": 102 | new_height = self.constrain_to_multiple_of(scale_height * height) 103 | new_width = self.constrain_to_multiple_of(scale_width * width) 104 | else: 105 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 106 | 107 | return (new_width, new_height) 108 | 109 | def __call__(self, sample): 
110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) 111 | 112 | # resize sample 113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) 114 | 115 | if self.__resize_target: 116 | if "depth" in sample: 117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) 118 | 119 | if "mask" in sample: 120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) 121 | 122 | return sample 123 | 124 | 125 | class NormalizeImage(object): 126 | """Normlize image by given mean and std. 127 | """ 128 | 129 | def __init__(self, mean, std): 130 | self.__mean = mean 131 | self.__std = std 132 | 133 | def __call__(self, sample): 134 | sample["image"] = (sample["image"] - self.__mean) / self.__std 135 | 136 | return sample 137 | 138 | 139 | class PrepareForNet(object): 140 | """Prepare sample for usage as network input. 141 | """ 142 | 143 | def __init__(self): 144 | pass 145 | 146 | def __call__(self, sample): 147 | image = np.transpose(sample["image"], (2, 0, 1)) 148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 149 | 150 | if "depth" in sample: 151 | depth = sample["depth"].astype(np.float32) 152 | sample["depth"] = np.ascontiguousarray(depth) 153 | 154 | if "mask" in sample: 155 | sample["mask"] = sample["mask"].astype(np.float32) 156 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 157 | 158 | return sample -------------------------------------------------------------------------------- /depth_anything_v2/metric_depth_estimation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import matplotlib 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | import sys 9 | sys.path.append('.') 10 | from .depth_anything_v2.dpt import DepthAnythingV2 11 | 12 | def estimate(depth_anything, input_size, imagepath, output_dir, is_vis=False): 13 | cmap = matplotlib.colormaps.get_cmap('Spectral') 14 | 15 | imagename = imagepath.split('/')[-1].split('.')[0] 16 | 17 | raw_image = cv2.imread(imagepath) 18 | 19 | metric_depth = depth_anything.infer_image(raw_image, input_size) 20 | 21 | output_path = os.path.join(output_dir, imagename + '_depth_meter.npy') 22 | np.save(output_path, metric_depth) 23 | 24 | relative_depth = (metric_depth - metric_depth.min()) / (metric_depth.max() - metric_depth.min()) 25 | 26 | vis_depth = None 27 | if is_vis: 28 | vis_depth = relative_depth * 255.0 29 | vis_depth = vis_depth.astype(np.uint8) 30 | vis_depth = (cmap(vis_depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8) 31 | 32 | output_path = os.path.join(output_dir, imagename + '_depth.png') 33 | cv2.imwrite(output_path, vis_depth) 34 | 35 | return metric_depth, relative_depth, vis_depth 36 | 37 | 38 | def depth_estimation(imagepath, 39 | input_size=518, 40 | encoder='vitl', 41 | load_from='./depth_anything_v2/ckpts/depth_anything_v2_metric_hypersim_vitl.pth', 42 | max_depth=20, 43 | output_dir='./tmp', 44 | is_vis=False): 45 | 46 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 47 | 48 | model_configs = { 49 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 50 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 51 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 52 | 'vitg': 
{'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 53 | } 54 | 55 | depth_anything = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth}) 56 | depth_anything.load_state_dict(torch.load(load_from, map_location='cpu')) 57 | depth_anything = depth_anything.to(DEVICE).eval() 58 | 59 | metric_depth, relative_depth, vis_depth = estimate(depth_anything, input_size, imagepath, output_dir, is_vis) 60 | 61 | return metric_depth, relative_depth, vis_depth 62 | 63 | 64 | if __name__ == '__main__': 65 | 66 | parser = argparse.ArgumentParser() 67 | 68 | parser.add_argument('--input-size', type=int, default=518) 69 | 70 | parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg']) 71 | parser.add_argument('--load-from', type=str, default='../Depth-Anything-V2/checkpoints/depth_anything_v2_metric_hypersim_vitl.pth') 72 | parser.add_argument('--max-depth', type=float, default=20) 73 | 74 | # set the gpu 75 | parser.add_argument('--gpu', type=int, default=0, help='gpu id') 76 | parser.add_argument('--output_dir', type=str, default='./tmp', help='output dir') 77 | parser.add_argument('--imagepath', type=str, default='tmp.jpg', help='imagepath') 78 | 79 | args = parser.parse_args() 80 | 81 | output_dir = args.output_dir 82 | os.makedirs(output_dir, exist_ok=True) 83 | 84 | # set the gpu 85 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 86 | 87 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 88 | 89 | model_configs = { 90 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 91 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 92 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 93 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 94 | } 95 | 96 | depth_anything = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth}) 97 | depth_anything.load_state_dict(torch.load(args.load_from, map_location='cpu')) 98 | depth_anything = depth_anything.to(DEVICE).eval() 99 | 100 | estimate(depth_anything, args.input_size, args.imagepath, args.output_dir) -------------------------------------------------------------------------------- /examples/DollyOut.txt: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 2 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 6.666666666666666574e-02 3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 1.333333333333333315e-01 4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 
0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 2.000000000000000111e-01 5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 2.666666666666666630e-01 6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 3.333333333333333148e-01 7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 4.000000000000000222e-01 8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 4.666666666666666741e-01 9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5.333333333333333259e-01 10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5.999999999999999778e-01 11 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 6.666666666666666297e-01 12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 7.333333333333332815e-01 13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 8.000000000000000444e-01 14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 8.666666666666666963e-01 15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 9.333333333333333481e-01 16 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 
0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 1.000000000000000000e+00 17 | -------------------------------------------------------------------------------- /examples/Still.txt: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 2 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 11 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 
0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 16 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 17 | -------------------------------------------------------------------------------- /examples/TiltDown.txt: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 -0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 2 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.999323080022165522e-01 -1.163526593424524941e-02 -2.327371078232923999e-04 0.000000000000000000e+00 1.163526593424524941e-02 9.999323080022165522e-01 4.103782897194226538e-05 3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.997292411732788819e-01 -2.326895663969883593e-02 -4.654742156465847998e-04 0.000000000000000000e+00 2.326895663969883593e-02 9.997292411732788819e-01 8.207565794388453075e-05 4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.993908270505680314e-01 -3.489949580125043666e-02 -6.982113234698773081e-04 0.000000000000000000e+00 3.489949580125043666e-02 9.993908270505680314e-01 1.231134869158268029e-04 5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.989171113137995661e-01 -4.652531272860008921e-02 -9.309484312931695996e-04 0.000000000000000000e+00 4.652531272860008921e-02 9.989171113137995661e-01 1.641513158877690615e-04 6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 
0.000000000000000000e+00 0.000000000000000000e+00 9.983081583082533683e-01 -5.814482827546655491e-02 -1.163685539116461999e-03 0.000000000000000000e+00 5.814482827546655491e-02 9.983081583082533683e-01 2.051891448597113201e-04 7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.975640503856371133e-01 -6.975647194491900460e-02 -1.396422646939754616e-03 0.000000000000000000e+00 6.975647194491900460e-02 9.975640503856371133e-01 2.462269738316536058e-04 8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.966848882862320291e-01 -8.135867170659374925e-02 -1.629159754763047016e-03 0.000000000000000000e+00 8.135867170659374925e-02 9.966848882862320291e-01 2.872648028035958644e-04 9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.956707905510119305e-01 -9.294986198764870755e-02 -1.861896862586339199e-03 0.000000000000000000e+00 9.294986198764870755e-02 9.956707905510119305e-01 3.283026317755381230e-04 10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.945218953792129835e-01 -1.045284631635702705e-01 -2.094633970409631816e-03 0.000000000000000000e+00 1.045284631635702705e-01 9.945218953792129835e-01 3.693404607474803816e-04 11 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.932383578896332166e-01 -1.160929128616613598e-01 -2.327371078232923999e-03 0.000000000000000000e+00 1.160929128616613598e-01 9.932383578896332166e-01 4.103782897194226402e-04 12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.918203518526672591e-01 -1.276416454408651757e-01 -2.560108186056216182e-03 0.000000000000000000e+00 1.276416454408651757e-01 9.918203518526672591e-01 4.514161186913648988e-04 13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.902680692435954501e-01 -1.391730973879709010e-01 -2.792845293879509232e-03 0.000000000000000000e+00 1.391730973879709010e-01 9.902680692435954501e-01 4.924539476633072116e-04 14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.885817202165927409e-01 -1.506857075292882542e-01 -3.025582401702801415e-03 0.000000000000000000e+00 1.506857075292882542e-01 9.885817202165927409e-01 5.334917766352494702e-04 15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.867615330762774528e-01 -1.621779172420052539e-01 -3.258319509526094032e-03 0.000000000000000000e+00 1.621779172420052539e-01 9.867615330762774528e-01 5.745296056071917288e-04 16 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.848077542468020029e-01 -1.736481706652007739e-01 -3.491056617349386215e-03 0.000000000000000000e+00 1.736481706652007739e-01 9.848077542468020029e-01 6.155674345791339874e-04 17 | -------------------------------------------------------------------------------- /examples/backview.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/backview.jpeg -------------------------------------------------------------------------------- /examples/backview.json: -------------------------------------------------------------------------------- 1 | [[[548, 423, 0.5], [562, 412, 0.5], [555, 428, 0.5], [545, 413, 0.5], [553, 433, 0.5]], [[247, 242, 0.5], [264, 233, 0.5], [287, 229, 0.5]], [[411, 185, 0.5]], [[65, 234, 0.5], [77, 229, 0.5], [81, 222, 0.5], [69, 228, 0.5], [64, 229, 0.5], [54, 229, 0.5]], [[148, 107, 0.5], [160, 102, 0.5], [164, 95, 0.5], [152, 101, 0.5], [147, 102, 0.5], [137, 102, 0.5]], [[103, 142, 0.5], [111, 138, 0.5], [114, 133, 0.5], [105, 137, 0.5], [102, 138, 0.5], [95, 138, 0.5]], [[235, 23, 0.5], [243, 19, 0.5], [246, 14, 0.5], [237, 18, 0.5], [234, 19, 0.5], [227, 19, 0.5]], [[745, 166, 0.5], [746, 159, 0.5], [743, 172, 0.5], [748, 179, 0.5], [752, 185, 0.5]], [[729, 70, 0.5], [729, 65, 0.5], [727, 74, 0.5], [731, 79, 0.5], [733, 83, 0.5]], [[366, 287, 0.5], [356, 281, 0.5], [368, 288, 0.5]], [[453, 288, 0.5], [467, 280, 0.5], [450, 290, 0.5]], [[440, 423, 0.5], [449, 420, 0.5], [438, 424, 0.5]], [[368, 407, 0.5], [365, 405, 0.5], [364, 402, 0.5], [372, 408, 0.5]], [[63, 441, 0.5], [71, 441, 0.5], [54, 438, 0.5]], [[140, 444, 0.5], [145, 444, 0.5], [133, 441, 0.5]], [[196, 457, 0.5], [199, 457, 0.5], [191, 454, 0.5]], [[744, 442, 0.5], [746, 442, 0.5], [740, 439, 0.5]], [[463, 70, 0.5]], [[298, 149, 0.5]], [[205, 202, 0.5]], [[83, 337, 0.5]], [[564, 362, 0.5]], [[408, 359, 0.5]], [[404, 232, 0.5]], [[611, 285, 0.5]], [[722, 308, 0.5]], [[644, 327, 0.5]], [[612, 90, 0.5]], [[34, 26, 0.5]], [[3, 102, 0.5]], [[10, 169, 0.5]], [[198, 395, 0.5], [254, 404, 0.5]], [[18, 38, 0.5]], [[424, 486, 0.5]], [[99, 310, 0.5], [23, 316, 0.5]], [[306, 289, 0.5], [234, 289, 0.5]], [[608, 273, 0.5], [561, 269, 0.5]], [[199, 398, 0.5], [273, 377, 0.5]], [[213, 404, 0.5], [287, 383, 0.5]]] -------------------------------------------------------------------------------- /examples/balloon.json: -------------------------------------------------------------------------------- 1 | [[[560, 147, 0.5], [352, 151, 0.5]], [[359, 283, 0.5], [490, 288, 0.5]], [[124, 99, 0.5]], [[95, 397, 0.5]], [[653, 347, 0.5]], [[642, 465, 0.5]]] -------------------------------------------------------------------------------- /examples/balloon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/balloon.png -------------------------------------------------------------------------------- /examples/balloons.json: -------------------------------------------------------------------------------- 1 | [[[520, 90, 0.5], [338, 84, 0.5]], [[557, 197, 0.5], [666, 197, 0.5]], [[334, 246, 0.5], [402, 307, 0.5]], [[604, 283, 0.5], [629, 272, 0.5], [619, 270, 0.5], [600, 270, 0.5], [584, 270, 0.5], [574, 268, 0.5], [574, 252, 0.5], [598, 248, 0.5], [610, 248, 0.5]], [[524, 319, 0.5]]] -------------------------------------------------------------------------------- /examples/balloons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/balloons.png -------------------------------------------------------------------------------- /examples/bamboo.json: 
-------------------------------------------------------------------------------- 1 | [[[179, 146, 0.5], [209, 146, 0.5], [157, 145, 0.5], [204, 146, 0.5], [162, 146, 0.5]], [[158, 306, 0.5], [185, 306, 0.5], [138, 305, 0.5], [180, 306, 0.5], [142, 306, 0.5]], [[336, 236, 0.5], [336, 215, 0.5], [337, 267, 0.5], [335, 213, 0.5], [335, 253, 0.5]], [[305, 361, 0.5], [284, 376, 0.5], [327, 351, 0.5], [289, 371, 0.5], [319, 356, 0.5]]] -------------------------------------------------------------------------------- /examples/bamboo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/bamboo.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | import gradio as gr 6 | from PIL import Image 7 | from pathlib import Path 8 | from datetime import datetime 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | 15 | from diffusers import AutoencoderKL, DDIMScheduler 16 | from einops import repeat 17 | from omegaconf import OmegaConf 18 | from decord import VideoReader 19 | from transformers import CLIPVisionModelWithProjection 20 | 21 | from src.models.motion_encoder import MotionEncoder 22 | from src.models.fusion_module import FusionModule 23 | from src.models.unet_2d_condition import UNet2DConditionModel 24 | from src.models.unet_3d import UNet3DConditionModel 25 | from src.pipelines.pipeline_motion2vid_merge_infer import Motion2VideoPipeline 26 | from src.utils.util import save_videos_grid 27 | from src.utils.utils import interpolate_trajectory, interpolate_trajectory_3d 28 | from src.utils.visualizer import Visualizer 29 | 30 | import safetensors 31 | from cam_utils import cam_adaptation  # missing import; cam_adaptation is assumed to be defined in the repo's cam_utils.py and is called below in get_control_objs 32 | 33 | def visualize_tracks(background_image_path, splited_tracks, video_length, width, height, save_path='./tmp/hint.mp4'): 34 | # Draw the (interpolated) tracks as arrows over a semi-transparent copy of the reference image. 35 | background_image = Image.open(background_image_path).convert('RGBA') 36 | background_image = background_image.resize((width, height)) 37 | w, h = background_image.size 38 | transparent_background = np.array(background_image) 39 | transparent_background[:, :, -1] = 128 40 | transparent_background = Image.fromarray(transparent_background) 41 | 42 | # Create a transparent layer with the same size as the background image 43 | transparent_layer = np.zeros((h, w, 4)) 44 | for splited_track in splited_tracks: 45 | if len(splited_track) > 1: 46 | splited_track = interpolate_trajectory(splited_track, video_length) 47 | for i in range(len(splited_track)-1): 48 | start_point = (int(splited_track[i][0]), int(splited_track[i][1])) 49 | end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) 50 | vx = end_point[0] - start_point[0] 51 | vy = end_point[1] - start_point[1] 52 | arrow_length = np.sqrt(vx**2 + vy**2) + 1e-6 53 | if i == len(splited_track)-2: 54 | cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length) 55 | else: 56 | cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2) 57 | else: 58 | cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1) 59 | 60 | transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) 61 | trajectory_map =
Image.alpha_composite(transparent_background, transparent_layer) 62 | 63 | save_dir = os.path.dirname(save_path) 64 | save_name = os.path.basename(save_path).split('.')[0] 65 | vis_track = Visualizer(save_dir=save_dir, 66 | pad_value=0, 67 | linewidth=10, 68 | mode='optical_flow', 69 | tracks_leave_trace=-1) 70 | video = np.repeat(np.array(background_image.convert('RGB'))[None, None, ...], splited_tracks.shape[1], 1) 71 | video = video.transpose(0, 1, 4, 2, 3) 72 | tracks = splited_tracks[:, :, :2][None].transpose(0, 2, 1, 3) 73 | depths = splited_tracks[:, :, 2][None].transpose(0, 2, 1) 74 | video_tracks = vis_track.visualize(torch.from_numpy(video), 75 | torch.from_numpy(tracks), 76 | depths=torch.from_numpy(depths), 77 | filename=save_name, 78 | is_depth_norm=True, 79 | query_frame=0) 80 | 81 | return trajectory_map, transparent_layer, video_tracks 82 | 83 | 84 | class Net(nn.Module): 85 | def __init__( 86 | self, 87 | reference_unet: UNet2DConditionModel, 88 | denoising_unet: UNet3DConditionModel, 89 | obj_encoder: MotionEncoder, 90 | cam_encoder: MotionEncoder, 91 | fusion_module: FusionModule, 92 | ): 93 | super().__init__() 94 | self.reference_unet = reference_unet 95 | self.denoising_unet = denoising_unet 96 | self.obj_encoder = obj_encoder 97 | self.cam_encoder = cam_encoder 98 | self.fusion_module = fusion_module 99 | 100 | 101 | class Model(): 102 | def __init__(self, config_path): 103 | self.config_path = config_path 104 | self.load_config() 105 | self.init_model() 106 | self.init_savedir() 107 | 108 | self.output_dir = './tmp' 109 | os.makedirs(self.output_dir, exist_ok=True) 110 | 111 | def load_config(self): 112 | self.config = OmegaConf.load(self.config_path) 113 | 114 | def init_model(self): 115 | 116 | if self.config.weight_dtype == "fp16": 117 | weight_dtype = torch.float16 118 | else: 119 | weight_dtype = torch.float32 120 | 121 | vae = AutoencoderKL.from_pretrained( 122 | self.config.vae_model_path, 123 | ).to("cuda", dtype=weight_dtype) 124 | 125 | reference_unet = UNet2DConditionModel.from_pretrained( 126 | self.config.base_model_path, 127 | subfolder="unet", 128 | ).to(dtype=weight_dtype, device="cuda") 129 | 130 | inference_config_path = self.config.inference_config 131 | infer_config = OmegaConf.load(inference_config_path) 132 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 133 | self.config.base_model_path, 134 | self.config.mm_path, 135 | subfolder="unet", 136 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 137 | ).to(dtype=weight_dtype, device="cuda") 138 | 139 | obj_encoder = MotionEncoder( 140 | conditioning_embedding_channels=320, 141 | conditioning_channels = 3, 142 | block_out_channels=(16, 32, 96, 256) 143 | ).to(device="cuda") 144 | 145 | cam_encoder = MotionEncoder( 146 | conditioning_embedding_channels=320, 147 | conditioning_channels = 1, 148 | block_out_channels=(16, 32, 96, 256) 149 | ).to(device="cuda") 150 | 151 | fusion_module = FusionModule( 152 | fusion_type=self.config.fusion_type, 153 | ).to(device="cuda") 154 | 155 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 156 | self.config.image_encoder_path 157 | ).to(dtype=weight_dtype, device="cuda") 158 | 159 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 160 | scheduler = DDIMScheduler(**sched_kwargs) 161 | 162 | self.generator = torch.manual_seed(self.config.seed) 163 | 164 | denoising_unet.load_state_dict( 165 | torch.load(self.config.denoising_unet_path, map_location="cpu"), 166 | strict=False, 167 | ) 168 | 
reference_unet.load_state_dict( 169 | torch.load(self.config.reference_unet_path, map_location="cpu"), 170 | ) 171 | obj_encoder.load_state_dict( 172 | torch.load(self.config.obj_encoder_path, map_location="cpu"), 173 | ) 174 | cam_encoder.load_state_dict( 175 | torch.load(self.config.cam_encoder_path, map_location="cpu"), 176 | ) 177 | 178 | pipe = Motion2VideoPipeline( 179 | vae=vae, 180 | image_encoder=image_enc, 181 | reference_unet=reference_unet, 182 | denoising_unet=denoising_unet, 183 | obj_encoder=obj_encoder, 184 | cam_encoder=cam_encoder, 185 | fusion_module=fusion_module, 186 | scheduler=scheduler, 187 | ) 188 | self.pipe = pipe.to("cuda", dtype=weight_dtype) 189 | 190 | def init_savedir(self): 191 | date_str = datetime.now().strftime("%Y%m%d") 192 | time_str = datetime.now().strftime("%H%M") 193 | save_dir_name = f"{time_str}" 194 | 195 | if 'save_dir' not in self.config: 196 | self.save_dir = Path(f"output/{date_str}/{save_dir_name}") 197 | else: 198 | self.save_dir = Path(self.config.save_dir) 199 | self.save_dir.mkdir(exist_ok=True, parents=True) 200 | 201 | def get_control_objs(self, track_path, ori_sample_size, sample_size, video_length=16, imagepath=None, vo_path=None, cam_type='generated'): 202 | imagename = imagepath.split('/')[-1].split('.')[0] 203 | 204 | vis = Visualizer(save_dir='./tmp', 205 | grayscale=False, 206 | mode='rainbow_all', 207 | pad_value=0, 208 | linewidth=1, 209 | tracks_leave_trace=1) 210 | 211 | with open(track_path, 'r') as f: 212 | track_points_user = json.load(f) 213 | track_points = [] 214 | for splited_track in track_points_user: 215 | splited_track = np.array(splited_track) 216 | if splited_track.shape[-1] == 3: 217 | splited_track = interpolate_trajectory_3d(splited_track, video_length) 218 | else: 219 | splited_track = interpolate_trajectory(splited_track, video_length) 220 | splited_track = np.array(splited_track) 221 | track_points.append(splited_track) 222 | track_points = np.array(track_points) 223 | 224 | if track_points.shape[0] == 0 and not self.config.is_adapted: 225 | return None, None, None  # three values, matching the unpacking at the call site in run() 226 | 227 | query_path = os.path.join("./tmp/tmp_query.npy") 228 | if track_points.shape[0] != 0: 229 | query_points = track_points.transpose(1, 0, 2) 230 | np.save(query_path, query_points) 231 | tracks = query_points 232 | 233 | # get adapted tracks 234 | if track_points.shape[0] == 0: 235 | query_path = 'none' 236 | 237 | projected_3d_points, points_depth, cam_path = cam_adaptation(imagepath, vo_path, query_path, cam_type, video_length, self.output_dir) 238 | 239 | if self.config.is_adapted or track_points.shape[0] == 0: 240 | tracks = projected_3d_points[:, :, :2] 241 | tracks = np.concatenate([tracks, np.ones_like(tracks[:, :, 0])[..., None] * 0.5], -1) 242 | 243 | # get depth 244 | if self.config.is_depth: 245 | tracks = np.concatenate([tracks.astype(float)[..., :2], points_depth], -1) 246 | 247 | T, _, _ = tracks.shape 248 | # map track coordinates from the original image resolution to the model input resolution 249 | tracks[:, :, 0] /= ori_sample_size[1] 250 | tracks[:, :, 1] /= ori_sample_size[0] 251 | 252 | tracks[:, :, 0] *= sample_size[1] 253 | tracks[:, :, 1] *= sample_size[0] 254 | 255 | tracks[..., 0] = np.clip(tracks[:, :, 0], 0, sample_size[1] - 1) 256 | tracks[..., 1] = np.clip(tracks[:, :, 1], 0, sample_size[0] - 1) 257 | 258 | tracks = tracks[np.newaxis, :, :, :] 259 | tracks = torch.tensor(tracks) 260 | 261 | pred_tracks = tracks[:, :, :, :3] 262 | 263 | # vis tracks 264 | splited_tracks = [] 265 | for i in range(pred_tracks.shape[2]): 266 | splited_tracks.append(pred_tracks[0, :, i, :3]) 267 |
splited_tracks = np.array(splited_tracks) 268 | 269 | video = torch.zeros(T, 3, sample_size[0], sample_size[1])[None].float() 270 | 271 | vis_objs = vis.visualize(video=video, 272 | tracks=pred_tracks[..., :2], 273 | filename=Path(track_path).stem, 274 | depths=pred_tracks[..., 2], 275 | circle_scale=self.config.circle_scale, 276 | is_blur=False, 277 | is_depth_norm=True, 278 | save_video=False 279 | ) 280 | tracks = tracks.squeeze().numpy() 281 | vis_objs = vis_objs.squeeze().numpy() 282 | guide_value_objs = vis_objs 283 | 284 | return guide_value_objs, splited_tracks, cam_path 285 | 286 | def get_control_cams(self, cam_path, sample_size): 287 | vr = VideoReader(cam_path) 288 | cams = vr.get_batch(list(range(0, len(vr)))).asnumpy()[:, :, :, ::-1] 289 | resized_cams = [] 290 | resized_rgb_cams = [] 291 | for i in range(cams.shape[0]): 292 | if i == 0: 293 | cam_height, cam_width = cams[i].shape[:2] 294 | frame = np.array(Image.fromarray(cams[i]).convert('L').resize([sample_size[1], sample_size[0]])) 295 | resized_cams.append(frame) 296 | frame_rgb = np.array(Image.fromarray(cams[i]).convert('RGB').resize([sample_size[1], sample_size[0]])) 297 | resized_rgb_cams.append(frame_rgb) 298 | guide_value_cams = np.array(resized_cams)[..., None] 299 | del vr 300 | guide_value_cams = guide_value_cams.transpose(0, 3, 1, 2) 301 | return guide_value_cams 302 | 303 | 304 | def run(self, ref_image_path, cam_path, track_path, seed=None): 305 | if not seed: 306 | seed = self.config.seed 307 | self.generator = torch.manual_seed(seed) 308 | video_length = self.config.sample_n_frames 309 | 310 | if os.path.exists(cam_path) and cam_path.endswith('.txt'): 311 | vo_path = cam_path 312 | cam_type = 'user-provided' 313 | else: 314 | vo_path = cam_path 315 | cam_type = 'generated' 316 | 317 | ref_name = Path(ref_image_path).stem 318 | 319 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 320 | ref_image_path = f'{self.save_dir}/{ref_name}_ref.png' 321 | w, h = ref_image_pil.size 322 | ori_width, ori_height = w // 2 * 2, h // 2 * 2 323 | ref_image_pil.resize((ori_width, ori_height)).save(ref_image_path) 324 | 325 | image_transform = transforms.Compose( 326 | [transforms.Resize((self.config.H, self.config.W)), transforms.ToTensor()] 327 | ) 328 | 329 | ref_image_tensor = image_transform(ref_image_pil) # (c, h, w) 330 | ref_image_pil = Image.fromarray((ref_image_tensor * 255).permute(1, 2, 0).numpy().astype(np.uint8)).convert("RGB") 331 | 332 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 333 | ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video_length) 334 | 335 | control_objs, splited_tracks, cam_path = self.get_control_objs( 336 | track_path, 337 | ori_sample_size=(ori_height, ori_width), 338 | sample_size=(self.config.H, self.config.W), 339 | video_length=video_length, 340 | imagepath=ref_image_path, 341 | vo_path=vo_path.replace('.mp4', '.txt', 1), 342 | cam_type=cam_type, 343 | ) 344 | 345 | control_cams = self.get_control_cams(cam_path, (self.config.H, self.config.W)) 346 | 347 | ref_image_path = f'{self.save_dir}/{ref_name}_ref.png' 348 | ref_image_pil.save(ref_image_path) 349 | 350 | if control_objs is None or self.config.cam_only: 351 | control_objs = np.zeros_like(control_cams).repeat(3, 1) 352 | if self.config.obj_only: 353 | control_cams = np.zeros_like(control_cams) 354 | 355 | control_objs = torch.from_numpy(control_objs).unsqueeze(0).float() 356 | control_objs = control_objs.transpose( 357 | 1, 2 358 | ) # (bs, c, f, 
H, W) 359 | control_cams = torch.from_numpy(control_cams).unsqueeze(0).float() 360 | control_cams = control_cams.transpose( 361 | 1, 2 362 | ) # (bs, c, f, H, W) 363 | 364 | video = self.pipe( 365 | reference_image=ref_image_pil, 366 | control_objs=control_objs, 367 | control_cams=control_cams, 368 | width=self.config.W, 369 | height=self.config.H, 370 | video_length=video_length, 371 | num_inference_steps=self.config.steps, 372 | guidance_scale=self.config.guidance_scale, 373 | generator=self.generator, 374 | is_obj=self.config.is_obj, 375 | is_cam=self.config.is_cam 376 | ).videos 377 | 378 | cam_pattern_postfix = vo_path.split('/')[-1].split('.')[0] 379 | video_path = f"{self.save_dir}/{ref_name}_{cam_pattern_postfix}_gen.mp4" 380 | hint_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_hint.mp4' 381 | vis_obj_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_obj.mp4' 382 | vis_cam_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_cam.mp4' 383 | 384 | _, _, _ = visualize_tracks(ref_image_path, splited_tracks, video_length, self.config.W, self.config.H, save_path=hint_path) 385 | 386 | torchvision.io.write_video(vis_obj_path, control_objs[0].permute(1, 2, 3, 0).numpy(), fps=8, video_codec='h264', options={'crf': '10'}) 387 | torchvision.io.write_video(vis_cam_path, control_cams[0].permute(1, 2, 3, 0).numpy().repeat(3, 3), fps=8, video_codec='h264', options={'crf': '10'}) 388 | 389 | save_videos_grid( 390 | video, 391 | video_path, 392 | n_rows=1, 393 | fps=8, 394 | ) 395 | 396 | return video_path, hint_path, vis_obj_path, vis_cam_path 397 | 398 | 399 | if __name__ == '__main__': 400 | 401 | config_path = 'configs/eval.yaml' 402 | config = OmegaConf.load(config_path) 403 | Model = Model(config_path) 404 | 405 | path_sets = [] 406 | for test_case in config["test_cases"]: 407 | ref_image_path = list(test_case.keys())[0] 408 | track_path = test_case[ref_image_path][0] 409 | cam_path = test_case[ref_image_path][1] 410 | seed = test_case[ref_image_path][2] 411 | path_set = {'ref_image_path': ref_image_path, 412 | 'track_path': track_path, 413 | 'cam_path': cam_path, 414 | 'seed': seed, 415 | } 416 | path_sets.append(path_set) 417 | 418 | for path_set in path_sets: 419 | ref_image_path = path_set['ref_image_path'] 420 | track_path = path_set['track_path'] 421 | cam_path = path_set['cam_path'] 422 | seed = path_set['seed'] 423 | 424 | Model.run(ref_image_path, cam_path, track_path, seed) -------------------------------------------------------------------------------- /mesh/world_envelope.mtl: -------------------------------------------------------------------------------- 1 | # Blender 4.2.0 MTL File: 'None' 2 | # www.blender.org 3 | 4 | newmtl Material 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd world_envelope.png 14 | -------------------------------------------------------------------------------- /mesh/world_envelope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/mesh/world_envelope.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | av==11.0.0 3 | clip @ 
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a 4 | decord==0.6.0 5 | diffusers==0.24.0 6 | einops==0.4.1 7 | gradio==3.41.2 8 | gradio_client==0.5.0 9 | imageio==2.33.0 10 | imageio-ffmpeg==0.4.9 11 | numpy==1.23.5 12 | omegaconf==2.2.3 13 | onnxruntime-gpu==1.16.3 14 | open3d==0.19.0 15 | open-clip-torch==2.20.0 16 | opencv-contrib-python==4.8.1.78 17 | opencv-python==4.8.1.78 18 | Pillow==9.5.0 19 | scikit-image==0.21.0 20 | scikit-learn==1.3.2 21 | scipy==1.11.4 22 | torch==2.0.1 23 | torchdiffeq==0.2.3 24 | torchmetrics==1.2.1 25 | torchsde==0.2.5 26 | torchvision==0.15.2 27 | tqdm==4.66.1 28 | transformers==4.30.2 29 | mlflow==2.9.2 30 | xformers==0.0.22 31 | controlnet-aux==0.0.7 32 | git+https://github.com/facebookresearch/pytorch3d.git -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/.DS_Store -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/fusion_module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/fusion_module.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/gated_self_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/gated_self_attention.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/motion_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/motion_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/motion_module.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/motion_module.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/mutual_self_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/mutual_self_attention.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/pose_guider.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/pose_guider.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_2d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/transformer_2d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/transformer_3d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_2d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/unet_2d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_2d_condition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/unet_2d_condition.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/unet_3d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_3d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/models/__pycache__/unet_3d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/fusion_module.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from einops import rearrange 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward 10 | 11 | from src.models.motion_module import zero_module 12 | from src.models.resnet import InflatedConv3d 13 | 14 | 15 | class FusionModule(ModelMixin): 16 | def __init__( 17 | self, 18 | in_channels: int = 640, 19 | out_channels: int = 320, 20 | fusion_type: str = 'conv', 21 | fusion_with_norm: bool = True, 22 | ): 23 | super().__init__() 24 | 25 | self.fusion_type = fusion_type 26 | self.fusion_with_norm = fusion_with_norm 27 | self.norm1 = nn.LayerNorm(out_channels) 28 | self.norm2 = nn.LayerNorm(out_channels) 29 | 30 | if self.fusion_type == 'sum': 31 | self.fusion = None 32 | elif self.fusion_type == 'max': 33 | self.fusion = None 34 | elif self.fusion_type == 'conv': 35 | self.fusion = InflatedConv3d( 36 | in_channels, out_channels, kernel_size=3, padding=1 37 | ) 38 | 39 | def forward(self, feat1, feat2): 40 | b, c, t, h, w = feat1.shape 41 | 42 | if self.fusion_type == 'sum': 43 | return feat1 + feat2 44 | elif self.fusion_type == 'max': 45 | return torch.max(feat1, feat2) 46 | elif self.fusion_type == 'conv': 47 | b, c, t, h, w = feat1.shape 48 | if self.fusion_with_norm: 49 | feat1 = rearrange(feat1, "b c t h w -> (b t) (h w) c") 50 | feat2 = rearrange(feat2, "b c t h w -> (b t) (h w) c") 51 | feat1 = self.norm1(feat1) 52 | feat2 = self.norm2(feat2) 53 | feat1 = feat1.view(b, t, h, w, c) 54 | feat2 = feat2.view(b, t, h, w, c) 55 | feat1 = feat1.permute(0, 4, 1, 2, 3).contiguous() 56 | feat2 = feat2.permute(0, 4, 1, 2, 3).contiguous() 57 | feat = torch.concat((feat1, feat2), 1) 58 | feat = self.fusion(feat) 59 | 60 | return feat -------------------------------------------------------------------------------- /src/models/motion_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from diffusers.models.modeling_utils import ModelMixin 7 | 8 | from src.models.motion_module import zero_module 9 | from src.models.resnet import InflatedConv3d 10 | 11 | 12 | class MotionEncoder(ModelMixin): 13 | def __init__( 14 | self, 15 | conditioning_embedding_channels: int, 16 | conditioning_channels: int = 3, 17 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 18 | ): 19 | super().__init__() 20 | self.conv_in = InflatedConv3d( 21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 22 | ) 23 | 24 | self.blocks = nn.ModuleList([]) 25 | 26 | for i in range(len(block_out_channels) - 1): 27 | channel_in = block_out_channels[i] 28 | channel_out = block_out_channels[i + 1] 29 | self.blocks.append( 30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 31 | ) 32 | self.blocks.append( 33 | InflatedConv3d( 34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 35 | ) 36 | ) 37 | 38 | self.conv_out = zero_module( 39 | 
InflatedConv3d( 40 | block_out_channels[-1], 41 | conditioning_embedding_channels, 42 | kernel_size=3, 43 | padding=1, 44 | ) 45 | ) 46 | 47 | def forward(self, conditioning): 48 | embedding = self.conv_in(conditioning) 49 | embedding = F.silu(embedding) 50 | 51 | for block in self.blocks: 52 | embedding = block(embedding) 53 | embedding = F.silu(embedding) 54 | 55 | embedding = self.conv_out(embedding) 56 | 57 | return embedding 58 | -------------------------------------------------------------------------------- /src/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | 
"Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | 
query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 | bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /src/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from src.models.attention import TemporalBasicTransformerBlock 8 | 9 | from .attention import BasicTransformerBlock 10 | 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | 19 | class ReferenceAttentionControl: 20 | def __init__( 21 | self, 22 | unet, 23 | mode="write", 24 | do_classifier_free_guidance=False, 25 | attention_auto_machine_weight=float("inf"), 26 | gn_auto_machine_weight=1.0, 27 | style_fidelity=1.0, 28 | reference_attn=True, 29 | reference_adain=False, 30 | fusion_blocks="midup", 31 | batch_size=1, 32 | ) -> None: 33 | # 10. 
Modify self attention and group norm 34 | self.unet = unet 35 | assert mode in ["read", "write"] 36 | assert fusion_blocks in ["midup", "full"] 37 | self.reference_attn = reference_attn 38 | self.reference_adain = reference_adain 39 | self.fusion_blocks = fusion_blocks 40 | self.register_reference_hooks( 41 | mode, 42 | do_classifier_free_guidance, 43 | attention_auto_machine_weight, 44 | gn_auto_machine_weight, 45 | style_fidelity, 46 | reference_attn, 47 | reference_adain, 48 | fusion_blocks, 49 | batch_size=batch_size, 50 | ) 51 | 52 | def register_reference_hooks( 53 | self, 54 | mode, 55 | do_classifier_free_guidance, 56 | attention_auto_machine_weight, 57 | gn_auto_machine_weight, 58 | style_fidelity, 59 | reference_attn, 60 | reference_adain, 61 | dtype=torch.float16, 62 | batch_size=1, 63 | num_images_per_prompt=1, 64 | device=torch.device("cpu"), 65 | fusion_blocks="midup", 66 | ): 67 | MODE = mode 68 | do_classifier_free_guidance = do_classifier_free_guidance 69 | attention_auto_machine_weight = attention_auto_machine_weight 70 | gn_auto_machine_weight = gn_auto_machine_weight 71 | style_fidelity = style_fidelity 72 | reference_attn = reference_attn 73 | reference_adain = reference_adain 74 | fusion_blocks = fusion_blocks 75 | num_images_per_prompt = num_images_per_prompt 76 | dtype = dtype 77 | if do_classifier_free_guidance: 78 | uc_mask = ( 79 | torch.Tensor( 80 | [1] * batch_size * num_images_per_prompt * 16 81 | + [0] * batch_size * num_images_per_prompt * 16 82 | ) 83 | .to(device) 84 | .bool() 85 | ) 86 | else: 87 | uc_mask = ( 88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 89 | .to(device) 90 | .bool() 91 | ) 92 | 93 | def hacked_basic_transformer_inner_forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | timestep: Optional[torch.LongTensor] = None, 100 | cross_attention_kwargs: Dict[str, Any] = None, 101 | class_labels: Optional[torch.LongTensor] = None, 102 | video_length=None, 103 | self_attention_additional_feats=None, 104 | mode=None, 105 | ): 106 | if self.use_ada_layer_norm: # False 107 | norm_hidden_states = self.norm1(hidden_states, timestep) 108 | elif self.use_ada_layer_norm_zero: 109 | ( 110 | norm_hidden_states, 111 | gate_msa, 112 | shift_mlp, 113 | scale_mlp, 114 | gate_mlp, 115 | ) = self.norm1( 116 | hidden_states, 117 | timestep, 118 | class_labels, 119 | hidden_dtype=hidden_states.dtype, 120 | ) 121 | else: 122 | norm_hidden_states = self.norm1(hidden_states) 123 | 124 | # 1. 
Self-Attention 125 | # self.only_cross_attention = False 126 | cross_attention_kwargs = ( 127 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 128 | ) 129 | if self.only_cross_attention: 130 | attn_output = self.attn1( 131 | norm_hidden_states, 132 | encoder_hidden_states=encoder_hidden_states 133 | if self.only_cross_attention 134 | else None, 135 | attention_mask=attention_mask, 136 | **cross_attention_kwargs, 137 | ) 138 | else: 139 | if MODE == "write": 140 | self.bank.append(norm_hidden_states.clone()) 141 | attn_output = self.attn1( 142 | norm_hidden_states, 143 | encoder_hidden_states=encoder_hidden_states 144 | if self.only_cross_attention 145 | else None, 146 | attention_mask=attention_mask, 147 | **cross_attention_kwargs, 148 | ) 149 | if MODE == "read": 150 | bank_fea = [ 151 | rearrange( 152 | d.unsqueeze(1).repeat(1, video_length, 1, 1), 153 | "b t l c -> (b t) l c", 154 | ) 155 | for d in self.bank 156 | ] 157 | modify_norm_hidden_states = torch.cat( 158 | [norm_hidden_states] + bank_fea, dim=1 159 | ) 160 | hidden_states_uc = ( 161 | self.attn1( 162 | norm_hidden_states, 163 | encoder_hidden_states=modify_norm_hidden_states, 164 | attention_mask=attention_mask, 165 | ) 166 | + hidden_states 167 | ) 168 | if do_classifier_free_guidance: 169 | hidden_states_c = hidden_states_uc.clone() 170 | _uc_mask = uc_mask.clone() 171 | if hidden_states.shape[0] != _uc_mask.shape[0]: 172 | _uc_mask = ( 173 | torch.Tensor( 174 | [1] * (hidden_states.shape[0] // 2) 175 | + [0] * (hidden_states.shape[0] // 2) 176 | ) 177 | .to(device) 178 | .bool() 179 | ) 180 | hidden_states_c[_uc_mask] = ( 181 | self.attn1( 182 | norm_hidden_states[_uc_mask], 183 | encoder_hidden_states=norm_hidden_states[_uc_mask], 184 | attention_mask=attention_mask, 185 | ) 186 | + hidden_states[_uc_mask] 187 | ) 188 | hidden_states = hidden_states_c.clone() 189 | else: 190 | hidden_states = hidden_states_uc 191 | 192 | # self.bank.clear() 193 | if self.attn2 is not None: 194 | # Cross-Attention 195 | norm_hidden_states = ( 196 | self.norm2(hidden_states, timestep) 197 | if self.use_ada_layer_norm 198 | else self.norm2(hidden_states) 199 | ) 200 | hidden_states = ( 201 | self.attn2( 202 | norm_hidden_states, 203 | encoder_hidden_states=encoder_hidden_states, 204 | attention_mask=attention_mask, 205 | ) 206 | + hidden_states 207 | ) 208 | 209 | # Feed-forward 210 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 211 | 212 | # Temporal-Attention 213 | if self.unet_use_temporal_attention: 214 | d = hidden_states.shape[1] 215 | hidden_states = rearrange( 216 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 217 | ) 218 | norm_hidden_states = ( 219 | self.norm_temp(hidden_states, timestep) 220 | if self.use_ada_layer_norm 221 | else self.norm_temp(hidden_states) 222 | ) 223 | hidden_states = ( 224 | self.attn_temp(norm_hidden_states) + hidden_states 225 | ) 226 | hidden_states = rearrange( 227 | hidden_states, "(b d) f c -> (b f) d c", d=d 228 | ) 229 | 230 | return hidden_states 231 | 232 | if self.use_ada_layer_norm_zero: 233 | attn_output = gate_msa.unsqueeze(1) * attn_output 234 | hidden_states = attn_output + hidden_states 235 | 236 | if self.attn2 is not None: 237 | norm_hidden_states = ( 238 | self.norm2(hidden_states, timestep) 239 | if self.use_ada_layer_norm 240 | else self.norm2(hidden_states) 241 | ) 242 | 243 | # 2. 
Cross-Attention 244 | attn_output = self.attn2( 245 | norm_hidden_states, 246 | encoder_hidden_states=encoder_hidden_states, 247 | attention_mask=encoder_attention_mask, 248 | **cross_attention_kwargs, 249 | ) 250 | hidden_states = attn_output + hidden_states 251 | 252 | # 3. Feed-forward 253 | norm_hidden_states = self.norm3(hidden_states) 254 | 255 | if self.use_ada_layer_norm_zero: 256 | norm_hidden_states = ( 257 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 258 | ) 259 | 260 | ff_output = self.ff(norm_hidden_states) 261 | 262 | if self.use_ada_layer_norm_zero: 263 | ff_output = gate_mlp.unsqueeze(1) * ff_output 264 | 265 | hidden_states = ff_output + hidden_states 266 | 267 | return hidden_states 268 | 269 | if self.reference_attn: 270 | if self.fusion_blocks == "midup": 271 | attn_modules = [ 272 | module 273 | for module in ( 274 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 275 | ) 276 | if isinstance(module, BasicTransformerBlock) 277 | or isinstance(module, TemporalBasicTransformerBlock) 278 | ] 279 | elif self.fusion_blocks == "full": 280 | attn_modules = [ 281 | module 282 | for module in torch_dfs(self.unet) 283 | if isinstance(module, BasicTransformerBlock) 284 | or isinstance(module, TemporalBasicTransformerBlock) 285 | ] 286 | attn_modules = sorted( 287 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 288 | ) 289 | 290 | for i, module in enumerate(attn_modules): 291 | module._original_inner_forward = module.forward 292 | if isinstance(module, BasicTransformerBlock): 293 | module.forward = hacked_basic_transformer_inner_forward.__get__( 294 | module, BasicTransformerBlock 295 | ) 296 | if isinstance(module, TemporalBasicTransformerBlock): 297 | module.forward = hacked_basic_transformer_inner_forward.__get__( 298 | module, TemporalBasicTransformerBlock 299 | ) 300 | 301 | module.bank = [] 302 | module.attn_weight = float(i) / float(len(attn_modules)) 303 | 304 | def update(self, writer, dtype=torch.float16): 305 | if self.reference_attn: 306 | if self.fusion_blocks == "midup": 307 | reader_attn_modules = [ 308 | module 309 | for module in ( 310 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 311 | ) 312 | if isinstance(module, TemporalBasicTransformerBlock) 313 | ] 314 | writer_attn_modules = [ 315 | module 316 | for module in ( 317 | torch_dfs(writer.unet.mid_block) 318 | + torch_dfs(writer.unet.up_blocks) 319 | ) 320 | if isinstance(module, BasicTransformerBlock) 321 | ] 322 | elif self.fusion_blocks == "full": 323 | reader_attn_modules = [ 324 | module 325 | for module in torch_dfs(self.unet) 326 | if isinstance(module, TemporalBasicTransformerBlock) 327 | ] 328 | writer_attn_modules = [ 329 | module 330 | for module in torch_dfs(writer.unet) 331 | if isinstance(module, BasicTransformerBlock) 332 | ] 333 | reader_attn_modules = sorted( 334 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 335 | ) 336 | writer_attn_modules = sorted( 337 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 338 | ) 339 | for r, w in zip(reader_attn_modules, writer_attn_modules): 340 | r.bank = [v.clone().to(dtype) for v in w.bank] 341 | # w.bank.clear() 342 | 343 | def clear(self): 344 | if self.reference_attn: 345 | if self.fusion_blocks == "midup": 346 | reader_attn_modules = [ 347 | module 348 | for module in ( 349 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 350 | ) 351 | if isinstance(module, BasicTransformerBlock) 352 | or isinstance(module, 
TemporalBasicTransformerBlock) 353 | ] 354 | elif self.fusion_blocks == "full": 355 | reader_attn_modules = [ 356 | module 357 | for module in torch_dfs(self.unet) 358 | if isinstance(module, BasicTransformerBlock) 359 | or isinstance(module, TemporalBasicTransformerBlock) 360 | ] 361 | reader_attn_modules = sorted( 362 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 363 | ) 364 | for r in reader_attn_modules: 365 | r.bank.clear() 366 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class InflatedGroupNorm(nn.GroupNorm): 21 | def forward(self, x): 22 | video_length = x.shape[2] 23 | 24 | x = rearrange(x, "b c f h w -> (b f) c h w") 25 | x = super().forward(x) 26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 27 | 28 | return x 29 | 30 | 31 | class Upsample3D(nn.Module): 32 | def __init__( 33 | self, 34 | channels, 35 | use_conv=False, 36 | use_conv_transpose=False, 37 | out_channels=None, 38 | name="conv", 39 | ): 40 | super().__init__() 41 | self.channels = channels 42 | self.out_channels = out_channels or channels 43 | self.use_conv = use_conv 44 | self.use_conv_transpose = use_conv_transpose 45 | self.name = name 46 | 47 | conv = None 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate( 72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 73 | ) 74 | else: 75 | hidden_states = F.interpolate( 76 | hidden_states, size=output_size, mode="nearest" 77 | ) 78 | 79 | # If the input is bfloat16, we cast back to bfloat16 80 | if dtype == torch.bfloat16: 81 | hidden_states = hidden_states.to(dtype) 82 | 83 | # if self.use_conv: 84 | # if self.name == "conv": 85 | # hidden_states = self.conv(hidden_states) 86 | # else: 87 | # hidden_states = self.Conv2d_0(hidden_states) 88 | hidden_states = self.conv(hidden_states) 89 | 90 | return hidden_states 91 | 92 | 93 | class Downsample3D(nn.Module): 94 | def __init__( 95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.padding = padding 102 | stride = 2 103 | self.name = name 104 | 105 | if use_conv: 106 | self.conv = InflatedConv3d( 107 | self.channels, self.out_channels, 3, stride=stride, padding=padding 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | def forward(self, hidden_states): 113 | assert hidden_states.shape[1] == self.channels 114 | if self.use_conv and self.padding == 0: 115 | raise NotImplementedError 116 | 117 | assert hidden_states.shape[1] == self.channels 118 | hidden_states = self.conv(hidden_states) 119 | 120 | return hidden_states 121 | 122 | 123 | class ResnetBlock3D(nn.Module): 124 | def __init__( 125 | self, 126 | *, 127 | in_channels, 128 | out_channels=None, 129 | conv_shortcut=False, 130 | dropout=0.0, 131 | temb_channels=512, 132 | groups=32, 133 | groups_out=None, 134 | pre_norm=True, 135 | eps=1e-6, 136 | non_linearity="swish", 137 | time_embedding_norm="default", 138 | output_scale_factor=1.0, 139 | use_in_shortcut=None, 140 | use_inflated_groupnorm=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | assert use_inflated_groupnorm != None 156 | if use_inflated_groupnorm: 157 | self.norm1 = InflatedGroupNorm( 158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 159 | ) 160 | else: 161 | self.norm1 = torch.nn.GroupNorm( 162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 163 | ) 164 | 165 | self.conv1 = InflatedConv3d( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | if self.time_embedding_norm == "default": 171 | time_emb_proj_out_channels = out_channels 172 | elif self.time_embedding_norm == "scale_shift": 173 | time_emb_proj_out_channels = out_channels * 2 174 | else: 175 | raise ValueError( 176 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 177 | ) 178 | 179 | self.time_emb_proj = torch.nn.Linear( 180 | temb_channels, 
time_emb_proj_out_channels 181 | ) 182 | else: 183 | self.time_emb_proj = None 184 | 185 | if use_inflated_groupnorm: 186 | self.norm2 = InflatedGroupNorm( 187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 188 | ) 189 | else: 190 | self.norm2 = torch.nn.GroupNorm( 191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 192 | ) 193 | self.dropout = torch.nn.Dropout(dropout) 194 | self.conv2 = InflatedConv3d( 195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 196 | ) 197 | 198 | if non_linearity == "swish": 199 | self.nonlinearity = lambda x: F.silu(x) 200 | elif non_linearity == "mish": 201 | self.nonlinearity = Mish() 202 | elif non_linearity == "silu": 203 | self.nonlinearity = nn.SiLU() 204 | 205 | self.use_in_shortcut = ( 206 | self.in_channels != self.out_channels 207 | if use_in_shortcut is None 208 | else use_in_shortcut 209 | ) 210 | 211 | self.conv_shortcut = None 212 | if self.use_in_shortcut: 213 | self.conv_shortcut = InflatedConv3d( 214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 215 | ) 216 | 217 | def forward(self, input_tensor, temb): 218 | hidden_states = input_tensor 219 | 220 | hidden_states = self.norm1(hidden_states) 221 | hidden_states = self.nonlinearity(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if temb is not None: 226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 227 | 228 | if temb is not None and self.time_embedding_norm == "default": 229 | hidden_states = hidden_states + temb 230 | 231 | hidden_states = self.norm2(hidden_states) 232 | 233 | if temb is not None and self.time_embedding_norm == "scale_shift": 234 | scale, shift = torch.chunk(temb, 2, dim=1) 235 | hidden_states = hidden_states * (1 + scale) + shift 236 | 237 | hidden_states = self.nonlinearity(hidden_states) 238 | 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.conv2(hidden_states) 241 | 242 | if self.conv_shortcut is not None: 243 | input_tensor = self.conv_shortcut(input_tensor) 244 | 245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 246 | 247 | return output_tensor 248 | 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /src/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | 
num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | name=None, 49 | ): 50 | super().__init__() 51 | self.use_linear_projection = use_linear_projection 52 | self.num_attention_heads = num_attention_heads 53 | self.attention_head_dim = attention_head_dim 54 | inner_dim = num_attention_heads * attention_head_dim 55 | 56 | # Define input layers 57 | self.in_channels = in_channels 58 | self.name=name 59 | 60 | self.norm = torch.nn.GroupNorm( 61 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 62 | ) 63 | if use_linear_projection: 64 | self.proj_in = nn.Linear(in_channels, inner_dim) 65 | else: 66 | self.proj_in = nn.Conv2d( 67 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 68 | ) 69 | 70 | # Define transformers blocks 71 | self.transformer_blocks = nn.ModuleList( 72 | [ 73 | TemporalBasicTransformerBlock( 74 | inner_dim, 75 | num_attention_heads, 76 | attention_head_dim, 77 | dropout=dropout, 78 | cross_attention_dim=cross_attention_dim, 79 | activation_fn=activation_fn, 80 | num_embeds_ada_norm=num_embeds_ada_norm, 81 | attention_bias=attention_bias, 82 | only_cross_attention=only_cross_attention, 83 | upcast_attention=upcast_attention, 84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 85 | unet_use_temporal_attention=unet_use_temporal_attention, 86 | name=f"{self.name}_{d}_TransformerBlock" if self.name else None, 87 | ) 88 | for d in range(num_layers) 89 | ] 90 | ) 91 | 92 | # 4. Define output layers 93 | if use_linear_projection: 94 | self.proj_out = nn.Linear(in_channels, inner_dim) 95 | else: 96 | self.proj_out = nn.Conv2d( 97 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 98 | ) 99 | 100 | self.gradient_checkpointing = False 101 | 102 | def _set_gradient_checkpointing(self, module, value=False): 103 | if hasattr(module, "gradient_checkpointing"): 104 | module.gradient_checkpointing = value 105 | 106 | def forward( 107 | self, 108 | hidden_states, 109 | encoder_hidden_states=None, 110 | self_attention_additional_feats=None, 111 | mode=None, 112 | timestep=None, 113 | return_dict: bool = True, 114 | ): 115 | # Input 116 | assert ( 117 | hidden_states.dim() == 5 118 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
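# Shape bookkeeping for the steps below: `hidden_states` arrives as a 5-D video latent
# (b, c, f, h, w). The frame axis is folded into the batch, (b, c, f, h, w) -> ((b f), c, h, w),
# so the 2-D GroupNorm and input projection are reused per frame, and `encoder_hidden_states`
# is repeated f times to match the enlarged batch. After `proj_in`, tokens are flattened to
# ((b f), h*w, inner_dim) for the TemporalBasicTransformerBlock stack; the inverse reshape is
# applied in the Output section before the residual is added back.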
119 | video_length = hidden_states.shape[2] 120 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 121 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 122 | encoder_hidden_states = repeat( 123 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 124 | ) 125 | 126 | batch, channel, height, weight = hidden_states.shape 127 | residual = hidden_states 128 | 129 | hidden_states = self.norm(hidden_states) 130 | if not self.use_linear_projection: 131 | hidden_states = self.proj_in(hidden_states) 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | else: 137 | inner_dim = hidden_states.shape[1] 138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 139 | batch, height * weight, inner_dim 140 | ) 141 | hidden_states = self.proj_in(hidden_states) 142 | 143 | # Blocks 144 | for i, block in enumerate(self.transformer_blocks): 145 | 146 | if self.training and self.gradient_checkpointing: 147 | 148 | def create_custom_forward(module, return_dict=None): 149 | def custom_forward(*inputs): 150 | if return_dict is not None: 151 | return module(*inputs, return_dict=return_dict) 152 | else: 153 | return module(*inputs) 154 | 155 | return custom_forward 156 | 157 | # if hasattr(self.block, 'bank') and len(self.block.bank) > 0: 158 | # hidden_states 159 | hidden_states = torch.utils.checkpoint.checkpoint( 160 | create_custom_forward(block), 161 | hidden_states, 162 | encoder_hidden_states=encoder_hidden_states, 163 | timestep=timestep, 164 | attention_mask=None, 165 | video_length=video_length, 166 | self_attention_additional_feats=self_attention_additional_feats, 167 | mode=mode, 168 | ) 169 | else: 170 | 171 | hidden_states = block( 172 | hidden_states, 173 | encoder_hidden_states=encoder_hidden_states, 174 | timestep=timestep, 175 | self_attention_additional_feats=self_attention_additional_feats, 176 | mode=mode, 177 | video_length=video_length, 178 | ) 179 | 180 | # Output 181 | if not self.use_linear_projection: 182 | hidden_states = ( 183 | hidden_states.reshape(batch, height, weight, inner_dim) 184 | .permute(0, 3, 1, 2) 185 | .contiguous() 186 | ) 187 | hidden_states = self.proj_out(hidden_states) 188 | else: 189 | hidden_states = self.proj_out(hidden_states) 190 | hidden_states = ( 191 | hidden_states.reshape(batch, height, weight, inner_dim) 192 | .permute(0, 3, 1, 2) 193 | .contiguous() 194 | ) 195 | 196 | output = hidden_states + residual 197 | 198 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 199 | if not return_dict: 200 | return (output,) 201 | 202 | return Transformer3DModelOutput(sample=output) 203 | -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- 
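Both resnet.py and transformer_3d.py above operate on 5-D video latents by folding the frame axis into the batch before calling the underlying 2-D layer, which is what lets plain 2-D UNet weights load into these modules unchanged. A minimal shape check for the inflated layers, assuming the repository root is on PYTHONPATH so that `src.models.resnet` is importable (the import path is an assumption, not stated in the files themselves):

import torch
from src.models.resnet import InflatedConv3d, InflatedGroupNorm

# Video latents follow the (batch, channels, frames, height, width) convention.
x = torch.randn(2, 64, 8, 32, 32)

conv = InflatedConv3d(64, 128, kernel_size=3, padding=1)   # per-frame Conv2d
norm = InflatedGroupNorm(num_groups=32, num_channels=128)  # per-frame GroupNorm

y = norm(conv(x))
print(y.shape)  # torch.Size([2, 128, 8, 32, 32])

Transformer3DModel follows the same convention: its forward flattens (b, c, f, h, w) into ((b f), h*w, inner_dim) tokens, runs the temporal transformer blocks, then restores the 5-D layout before returning.
--------------------------------------------------------------------------------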
/src/pipelines/__pycache__/pipeline_motion2vid_merge_infer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/pipeline_motion2vid_merge_infer.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /src/utils/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/visualizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/visualizer.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import sys 6 | from pathlib import Path 7 | 8 | import av 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | import torch.distributed as dist 13 | from einops import rearrange 14 | from PIL import Image 15 | 16 | 17 | def seed_everything(seed): 18 | import random 19 | 20 | import numpy as np 21 | 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | np.random.seed(seed % (2**32)) 25 | random.seed(seed) 26 | 
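# Note on seed_everything above: np.random.seed only accepts seeds in [0, 2**32),
# hence the `seed % (2**32)` applied to the NumPy seed only. Full CUDA determinism
# would additionally require torch.backends.cudnn.deterministic = True, which this
# helper does not set.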
27 | 28 | def import_filename(filename): 29 | spec = importlib.util.spec_from_file_location("mymodule", filename) 30 | module = importlib.util.module_from_spec(spec) 31 | sys.modules[spec.name] = module 32 | spec.loader.exec_module(module) 33 | return module 34 | 35 | 36 | def delete_additional_ckpt(base_path, num_keep): 37 | dirs = [] 38 | for d in os.listdir(base_path): 39 | if d.startswith("checkpoint-"): 40 | dirs.append(d) 41 | num_tot = len(dirs) 42 | if num_tot <= num_keep: 43 | return 44 | # ensure ckpt is sorted and delete the ealier! 45 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 46 | for d in del_dirs: 47 | path_to_dir = osp.join(base_path, d) 48 | if osp.exists(path_to_dir): 49 | shutil.rmtree(path_to_dir) 50 | 51 | 52 | def save_videos_from_pil(pil_images, path, fps=8): 53 | import av 54 | 55 | save_fmt = Path(path).suffix 56 | os.makedirs(os.path.dirname(path), exist_ok=True) 57 | width, height = pil_images[0].size 58 | 59 | if save_fmt == ".mp4": 60 | codec = "libx264" 61 | container = av.open(path, "w") 62 | stream = container.add_stream(codec, rate=fps) 63 | 64 | stream.width = width 65 | stream.height = height 66 | 67 | for pil_image in pil_images: 68 | # pil_image = Image.fromarray(image_arr).convert("RGB") 69 | av_frame = av.VideoFrame.from_image(pil_image) 70 | container.mux(stream.encode(av_frame)) 71 | container.mux(stream.encode()) 72 | container.close() 73 | 74 | elif save_fmt == ".gif": 75 | pil_images[0].save( 76 | fp=path, 77 | format="GIF", 78 | append_images=pil_images[1:], 79 | save_all=True, 80 | duration=(1 / fps * 1000), 81 | loop=0, 82 | ) 83 | else: 84 | raise ValueError("Unsupported file type. Use .mp4 or .gif.") 85 | 86 | 87 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=2, fps=8): 88 | videos = rearrange(videos, "b c t h w -> t b c h w") 89 | height, width = videos.shape[-2:] 90 | outputs = [] 91 | 92 | for x in videos: 93 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 94 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 95 | if rescale: 96 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 97 | x = (x * 255).numpy().astype(np.uint8) 98 | x = Image.fromarray(x) 99 | 100 | outputs.append(x) 101 | 102 | os.makedirs(os.path.dirname(path), exist_ok=True) 103 | 104 | save_videos_from_pil(outputs, path, fps) 105 | 106 | 107 | def read_frames(video_path): 108 | container = av.open(video_path) 109 | 110 | video_stream = next(s for s in container.streams if s.type == "video") 111 | frames = [] 112 | for packet in container.demux(video_stream): 113 | for frame in packet.decode(): 114 | image = Image.frombytes( 115 | "RGB", 116 | (frame.width, frame.height), 117 | frame.to_rgb().to_ndarray(), 118 | ) 119 | frames.append(image) 120 | 121 | return frames 122 | 123 | 124 | def get_fps(video_path): 125 | container = av.open(video_path) 126 | video_stream = next(s for s in container.streams if s.type == "video") 127 | fps = video_stream.average_rate 128 | container.close() 129 | return fps 130 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from scipy.interpolate import PchipInterpolator 13 | from einops import 
rearrange 14 | from decord import VideoReader 15 | 16 | 17 | 18 | def interpolate_trajectory(points, n_points): 19 | if len(points) == n_points: 20 | return points 21 | 22 | if len(points) == 1: 23 | points = [points[0] for _ in range(n_points)] 24 | return points 25 | 26 | x = [point[0] for point in points] 27 | y = [point[1] for point in points] 28 | 29 | t = np.linspace(0, 1, len(points)) 30 | 31 | fx = PchipInterpolator(t, x) 32 | fy = PchipInterpolator(t, y) 33 | 34 | new_t = np.linspace(0, 1, n_points) 35 | 36 | new_x = fx(new_t) 37 | new_y = fy(new_t) 38 | new_points = list(zip(new_x, new_y)) 39 | 40 | return new_points 41 | 42 | 43 | def interpolate_trajectory_3d(points, n_points): 44 | if len(points) == n_points: 45 | return points 46 | 47 | if len(points) == 1: 48 | points = [points[0] for _ in range(n_points)] 49 | return points 50 | 51 | x = [point[0] for point in points] 52 | y = [point[1] for point in points] 53 | z = [point[2] for point in points] 54 | 55 | t = np.linspace(0, 1, len(points)) 56 | 57 | fx = PchipInterpolator(t, x) 58 | fy = PchipInterpolator(t, y) 59 | fz = PchipInterpolator(t, z) 60 | 61 | new_t = np.linspace(0, 1, n_points) 62 | 63 | new_x = fx(new_t) 64 | new_y = fy(new_t) 65 | new_z = fz(new_t) 66 | new_points = list(zip(new_x, new_y, new_z)) 67 | 68 | return new_points 69 | 70 | 71 | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): 72 | """Generate a bivariate isotropic or anisotropic Gaussian kernel. 73 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 74 | Args: 75 | kernel_size (int): 76 | sig_x (float): 77 | sig_y (float): 78 | theta (float): Radian measurement. 79 | grid (ndarray, optional): generated by :func:`mesh_grid`, 80 | with the shape (K, K, 2), K is the kernel size. Default: None 81 | isotropic (bool): 82 | Returns: 83 | kernel (ndarray): normalized kernel. 84 | """ 85 | if grid is None: 86 | grid, _, _ = mesh_grid(kernel_size) 87 | if isotropic: 88 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 89 | else: 90 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 91 | kernel = pdf2(sigma_matrix, grid) 92 | kernel = kernel / np.sum(kernel) 93 | return kernel 94 | 95 | 96 | def mesh_grid(kernel_size): 97 | """Generate the mesh grid, centering at zero. 98 | Args: 99 | kernel_size (int): 100 | Returns: 101 | xy (ndarray): with the shape (kernel_size, kernel_size, 2) 102 | xx (ndarray): with the shape (kernel_size, kernel_size) 103 | yy (ndarray): with the shape (kernel_size, kernel_size) 104 | """ 105 | ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) 106 | xx, yy = np.meshgrid(ax, ax) 107 | xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, 108 | 1))).reshape(kernel_size, kernel_size, 2) 109 | return xy, xx, yy 110 | 111 | 112 | def pdf2(sigma_matrix, grid): 113 | """Calculate PDF of the bivariate Gaussian distribution. 114 | Args: 115 | sigma_matrix (ndarray): with the shape (2, 2) 116 | grid (ndarray): generated by :func:`mesh_grid`, 117 | with the shape (K, K, 2), K is the kernel size. 118 | Returns: 119 | kernel (ndarrray): un-normalized kernel. 120 | """ 121 | inverse_sigma = np.linalg.inv(sigma_matrix) 122 | kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) 123 | return kernel 124 | --------------------------------------------------------------------------------
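For reference, a minimal usage sketch of the trajectory and kernel helpers defined in src/utils/utils.py above. The `src.utils.utils` import path assumes the repository root is on PYTHONPATH; the isotropic path is used so that only the functions shown (mesh_grid, pdf2) are exercised:

from src.utils.utils import interpolate_trajectory, bivariate_Gaussian

# Densify a sparse 2-D point list to a fixed number of samples via PCHIP interpolation.
sparse = [(10.0, 20.0), (40.0, 35.0), (80.0, 30.0)]
dense = interpolate_trajectory(sparse, n_points=16)       # 16 (x, y) tuples

# Isotropic Gaussian kernel normalized to sum to 1; sig_y and theta are ignored
# when isotropic=True.
kernel = bivariate_Gaussian(kernel_size=99, sig_x=10.0, sig_y=10.0, theta=0.0, isotropic=True)
print(len(dense), kernel.shape, round(kernel.sum(), 6))   # 16 (99, 99) 1.0
--------------------------------------------------------------------------------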