├── .DS_Store
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── cam_utils.py
├── configs
│   ├── .DS_Store
│   ├── eval.yaml
│   └── inference
│       └── inference_v2.yaml
├── depth_anything_v2
│   ├── depth_anything_v2
│   │   ├── __pycache__
│   │   │   ├── dinov2.cpython-310.pyc
│   │   │   └── dpt.cpython-310.pyc
│   │   ├── dinov2.py
│   │   ├── dinov2_layers
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── attention.cpython-310.pyc
│   │   │   │   ├── block.cpython-310.pyc
│   │   │   │   ├── drop_path.cpython-310.pyc
│   │   │   │   ├── layer_scale.cpython-310.pyc
│   │   │   │   ├── mlp.cpython-310.pyc
│   │   │   │   ├── patch_embed.cpython-310.pyc
│   │   │   │   └── swiglu_ffn.cpython-310.pyc
│   │   │   ├── attention.py
│   │   │   ├── block.py
│   │   │   ├── drop_path.py
│   │   │   ├── layer_scale.py
│   │   │   ├── mlp.py
│   │   │   ├── patch_embed.py
│   │   │   └── swiglu_ffn.py
│   │   ├── dpt.py
│   │   └── util
│   │       ├── __pycache__
│   │       │   ├── blocks.cpython-310.pyc
│   │       │   └── transform.cpython-310.pyc
│   │       ├── blocks.py
│   │       └── transform.py
│   └── metric_depth_estimation.py
├── examples
│   ├── DollyOut.txt
│   ├── Still.txt
│   ├── TiltDown.txt
│   ├── backview.jpeg
│   ├── backview.json
│   ├── balloon.json
│   ├── balloon.png
│   ├── balloons.json
│   ├── balloons.png
│   ├── bamboo.json
│   └── bamboo.png
├── inference.py
├── mesh
│   ├── world_envelope.mtl
│   ├── world_envelope.obj
│   └── world_envelope.png
├── requirements.txt
└── src
    ├── .DS_Store
    ├── __init__.py
    ├── __pycache__
    │   └── __init__.cpython-310.pyc
    ├── models
    │   ├── __pycache__
    │   │   ├── attention.cpython-310.pyc
    │   │   ├── fusion_module.cpython-310.pyc
    │   │   ├── gated_self_attention.cpython-310.pyc
    │   │   ├── motion_encoder.cpython-310.pyc
    │   │   ├── motion_module.cpython-310.pyc
    │   │   ├── mutual_self_attention.cpython-310.pyc
    │   │   ├── pose_guider.cpython-310.pyc
    │   │   ├── resnet.cpython-310.pyc
    │   │   ├── transformer_2d.cpython-310.pyc
    │   │   ├── transformer_3d.cpython-310.pyc
    │   │   ├── unet_2d_blocks.cpython-310.pyc
    │   │   ├── unet_2d_condition.cpython-310.pyc
    │   │   ├── unet_3d.cpython-310.pyc
    │   │   └── unet_3d_blocks.cpython-310.pyc
    │   ├── attention.py
    │   ├── fusion_module.py
    │   ├── motion_encoder.py
    │   ├── motion_module.py
    │   ├── mutual_self_attention.py
    │   ├── resnet.py
    │   ├── transformer_2d.py
    │   ├── transformer_3d.py
    │   ├── unet_2d_blocks.py
    │   ├── unet_2d_condition.py
    │   ├── unet_3d.py
    │   └── unet_3d_blocks.py
    ├── pipelines
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── pipeline_motion2vid_merge_infer.cpython-310.pyc
    │   │   └── utils.cpython-310.pyc
    │   ├── pipeline_motion2vid_merge_infer.py
    │   └── utils.py
    └── utils
        ├── __pycache__
        │   ├── util.cpython-310.pyc
        │   ├── utils.cpython-310.pyc
        │   └── visualizer.cpython-310.pyc
        ├── util.py
        ├── utils.py
        └── visualizer.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Perception-as-Control
2 | Official implementation of "Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation"
3 |
4 |
5 |
6 |
7 |
8 | ### [Project page](https://chen-yingjie.github.io/projects/Perception-as-Control/index.html) | [Paper](https://arxiv.org/abs/2501.05020) | [Video](https://comming_soon) | [Online Demo](https://comming_soon)
9 |
10 | **Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation**
11 | [Yingjie Chen](https://chen-yingjie.github.io/),
12 | [Yifang Men](https://menyifang.github.io/),
13 | [Yuan Yao](mailto:yaoy92@gmail.com),
14 | [Miaomiao Cui](mailto:miaomiao.cmm@alibaba-inc.com),
15 | [Liefeng Bo](https://scholar.google.com/citations?user=FJwtMf0AAAAJ&hl=en)
16 |
17 | ## 💡 Abstract
 18 | Motion-controllable image animation is a fundamental task with a wide range of potential applications. Recent works have made progress in controlling camera or object motion via shared 2D motion representations or distinct control signals, but they still struggle to support collaborative camera and object motion control with adaptive control granularity. To this end, we introduce a 3D-aware motion representation and propose an image animation framework, called Perception-as-Control, to achieve fine-grained collaborative motion control. Specifically, we construct the 3D-aware motion representation from a reference image, manipulate it based on interpreted user intentions, and perceive it from different viewpoints. In this way, camera and object motions are transformed into intuitive, consistent visual changes. The proposed framework then leverages these perception results as motion control signals, enabling it to support various motion-related video synthesis tasks in a unified and flexible way. Experiments demonstrate the superiority of the proposed method.
19 |
20 | ## 🔥 Updates
21 | - (2025-03-31) We release the inference code and model weights of Perception-as-Control.
 22 | - (2025-03-10) We release an updated version of the paper with more details.
 23 | - (2025-01-09) The project page, demo video, and technical report are released. The full paper with more details is in progress.
24 |
25 | ## Usage
26 | ### Environment
27 | ```shell
28 | $ pip install -r requirements.txt
29 | ```
30 | ### Pretrained Weights
31 | 1. Download [pretrained weights](https://drive.google.com/drive/folders/1ZncmHG9K_n1BjGhVzQemomWzxQ60bYqg?usp=drive_link) and put them in `$INSTALL_DIR/pretrained_weights`.
32 |
 33 | 2. Download the pretrained weights of the base models and put them in `$INSTALL_DIR/pretrained_weights`:
34 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
35 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
36 |
37 | The pretrained weights are organized as follows:
38 | ```text
39 | ./pretrained_weights/
40 | |-- denoising_unet.pth
41 | |-- reference_unet.pth
42 | |-- cam_encoder.pth
43 | |-- obj_encoder.pth
44 | |-- sd-vae-ft-mse
45 | | |-- ...
46 | |-- sd-image-variations-diffusers
47 | | |-- ...
48 | ```
49 |
50 | ### Inference
51 | ```shell
52 | $ python inference.py
53 | ```
54 | The results will be saved in `$INSTALL_DIR/outputs`.
55 |
56 | ## 🎥 Demo
57 |
58 | ### Fine-grained collaborative motion control
 59 |
 60 | Demo videos: Camera Motion Control | Object Motion Control | Collaborative Motion Control
 61 |
91 | ### Potential applications
 92 |
 93 | Demo videos: Motion Generation | Motion Clone | Motion Transfer | Motion Editing
 94 |
119 | For more details, please refer to our [project page](https://chen-yingjie.github.io/projects/Perception-as-Control/index.html).
120 |
121 |
122 | ## 📑 TODO List
123 | - [x] Release inference code and model weights
124 | - [ ] Provide a Gradio demo
125 | - [ ] Release training code
126 |
127 | ## 🔗 Citation
128 |
129 | If you find this code useful for your research, please use the following BibTeX entry.
130 |
131 | ```bibtex
132 | @article{chen2025perception,
133 | title={Perception-as-Control: Fine-grained Controllable Image Animation with 3D-aware Motion Representation},
134 | author={Chen, Yingjie and Men, Yifang and Yao, Yuan and Cui, Miaomiao and Bo, Liefeng},
135 | journal={arXiv preprint arXiv:2501.05020},
136 | website={https://chen-yingjie.github.io/projects/Perception-as-Control/index.html},
137 | year={2025}}
138 | ```
139 |
140 | ## Acknowledgements
141 |
142 | We would like to thank the contributors to [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone), [Depth-Anything-V2](https://github.com/DepthAnything/Depth-Anything-V2), [SpaTracker](https://github.com/henry123-boy/SpaTracker), [Tartanvo](https://github.com/castacks/tartanvo), [diffusers](https://github.com/huggingface/diffusers) for their open research and exploration.
--------------------------------------------------------------------------------
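As a quick sanity check before running `python inference.py`, the weight layout described in the README can be verified with a short script. This is a minimal sketch, not part of the repository; the expected file names are taken from the README tree above.

```python
# Sketch: verify the pretrained_weights layout listed in the README before
# running `python inference.py`. Illustrative helper, not part of the repo.
from pathlib import Path

EXPECTED = [
    "pretrained_weights/denoising_unet.pth",
    "pretrained_weights/reference_unet.pth",
    "pretrained_weights/cam_encoder.pth",
    "pretrained_weights/obj_encoder.pth",
    "pretrained_weights/sd-vae-ft-mse",
    "pretrained_weights/sd-image-variations-diffusers",
]

def check_weights(root: str = ".") -> bool:
    """Return True if every expected weight file/directory exists under root."""
    missing = [p for p in EXPECTED if not (Path(root) / p).exists()]
    for p in missing:
        print(f"missing: {p}")
    return not missing

if __name__ == "__main__":
    print("weight layout OK" if check_weights() else "weight layout incomplete")
```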
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/assets/teaser.png
--------------------------------------------------------------------------------
/cam_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import copy
5 | import math
6 | import shutil
7 | import random
8 | import argparse
9 | import subprocess
10 | from PIL import Image
11 | from scipy.spatial.transform import Rotation as R
12 | from scipy.interpolate import interp1d
13 |
14 | import torch
15 | import numpy as np
 16 | # import open3d as o3d  # optional: required only by Render.load_point_cloud() (ply_path input)
17 | import pytorch3d
18 | from pytorch3d.structures import Meshes
19 | from pytorch3d.renderer import (look_at_view_transform, FoVPerspectiveCameras,
20 | PointLights, RasterizationSettings, PointsRasterizer, PointsRasterizationSettings,
21 | MeshRenderer, MeshRasterizer, SoftPhongShader, PointsRenderer, AlphaCompositor,
22 | blending, Textures)
23 | from pytorch3d.io import load_obj
24 |
25 | from depth_anything_v2.metric_depth_estimation import depth_estimation
26 |
27 |
28 | def draw_points(image, points, output_path):
29 | image = image.copy()
30 | for point in points:
31 | u, v = map(int, point[:2])
32 | image = cv2.circle(image, (u, v), 5, (0, 0, 255), -1)
33 |
34 | return image
35 |
36 |
37 | class Render():
38 | def __init__(self, img_size=(320, 576), focal=100, device='cpu',
39 | ply_path=None, uvmap_obj_path=None, is_pointcloud=False):
40 |
41 | self.img_size = img_size
42 | self.focal = focal
43 | self.device = device
44 | self.is_pointcloud = is_pointcloud
45 |
46 | if self.is_pointcloud:
47 | if ply_path != None:
48 | self.model = self.load_point_cloud(ply_path)
49 | else:
50 | if uvmap_obj_path != None:
51 | self.model = self.load_mesh(uvmap_obj_path)
52 |
53 | self.set_renderer()
54 | self.get_renderer()
55 |
56 | def get_points(self):
57 | return self.vts.detach().clone().cpu().numpy()
58 |
59 | def load_point_cloud(self, ply_path):
60 | ply = o3d.io.read_point_cloud(ply_path)
61 | self.vts = torch.Tensor(np.asarray(ply.points)).to(self.device)
62 | rgbs = np.asarray(ply.colors)[:, ::-1].copy()
63 | self.rgbs = torch.Tensor(rgbs).to(self.device) * 255
64 |
65 | colored_pointclouds = pytorch3d.structures.Pointclouds(points=[self.vts], features=[self.rgbs])
66 | colored_pointclouds.center = torch.tensor([0, 0, 0])
67 |
68 | return colored_pointclouds
69 |
70 | def load_mesh(self, uvmap_obj_path):
71 | batch_size = 1
72 |
73 | verts, faces, aux = load_obj(uvmap_obj_path)
74 | verts_uvs = aux.verts_uvs[None, ...].repeat(batch_size, 1, 1)
75 | faces_uvs = faces.textures_idx[None, ...].repeat(batch_size, 1, 1)
76 |
77 | tex_maps = aux.texture_images
78 | texture_image = list(tex_maps.values())[0]
79 | texture_image = texture_image * 255.0
80 | texture_maps = torch.tensor(texture_image).unsqueeze(0).repeat(
81 | batch_size, 1, 1, 1).float()
82 | tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_maps).to(self.device)
83 |
84 | mesh = Meshes(
85 | verts=torch.tensor(verts[None, ...]).float().to(self.device),
86 | faces=torch.tensor(faces.verts_idx[None, ...]).float().to(self.device),
87 | textures=tex
88 | )
89 |
90 | return mesh
91 |
92 | def update_view(self, distance=1.0, pitch=0.0, azimuth=0.0, roll=0.0, target_point=[0.0, 0.0, 1.0]):
93 | target_point = torch.tensor(target_point, device=self.device)
94 |
95 | pitch_rad = math.radians(pitch)
96 | azimuth_rad = math.radians(azimuth)
97 |
98 | z = distance * math.cos(pitch_rad) * math.cos(azimuth_rad)
99 | x = distance * math.cos(pitch_rad) * math.sin(azimuth_rad)
100 | y = distance * math.sin(pitch_rad)
101 | self.camera_position = target_point - torch.tensor([x, y, z], device=self.device)
102 |
103 | R, T = look_at_view_transform(
104 | eye=self.camera_position.unsqueeze(0),
105 | at=target_point.unsqueeze(0),
106 | up=torch.tensor([0.0, 1.0, 0.0], device=self.device).unsqueeze(0),
107 | device=self.device
108 | )
109 |
110 | def rotate_roll(R, roll):
111 | roll_rad = math.radians(roll)
112 | roll_matrix = torch.tensor([
113 | [math.cos(roll_rad), -math.sin(roll_rad), 0],
114 | [math.sin(roll_rad), math.cos(roll_rad), 0],
115 | [0, 0, 1],
116 | ], device=self.device)
117 | return torch.matmul(R, roll_matrix)
118 |
119 | if roll != 0:
120 | R = rotate_roll(R, roll)
121 |
122 | self.cameras.R = R
123 | self.cameras.T = T
124 |
125 | return R.cpu().squeeze().numpy(), T.cpu().squeeze().numpy()
126 |
127 | def update_RT(self, RT, S=np.eye(3)):
128 |
129 | if self.is_pointcloud:
130 | R_norm = torch.tensor([[[1., 0., 0.],
131 | [0., 1., 0.],
132 | [0., 0., 1.]]])
133 | T_norm = torch.tensor([[0., 0., 0.]])
134 | else:
135 | R_norm, T_norm = look_at_view_transform(-50, 0, 0)
136 |
137 | RT_norm = torch.cat((R_norm.squeeze(0), T_norm.squeeze(0)[..., None]),
138 | 1).numpy()
139 | RT_norm = np.concatenate((RT_norm, np.array([0, 0, 0, 1])[None, ...]),
140 | 0)
141 |
142 | R = RT[:, :3]
143 | T = RT[:, 3]
144 | R = S @ R @ S
145 | T = S @ T
146 |
147 | RT = np.concatenate((R, T[..., None]), 1)
148 | RT = np.concatenate((RT, np.array([0, 0, 0, 1])[None, ...]), 0)
149 |
150 | RT = RT @ RT_norm
151 | R = RT[:3, :3][None, ...]
152 | T = RT[:3, 3][None, ...]
153 |
154 | self.cameras.R = torch.tensor(R).to(self.device)
155 | self.cameras.T = torch.tensor(T).to(self.device)
156 |
157 | return R, T
158 |
159 | def interpolate_slerp(self, RTs, T=16):
160 | Rs = [RT[0] for RT in RTs]
161 | Ts = [RT[1] for RT in RTs]
162 |
163 | times = np.linspace(0, 1, len(RTs))
164 | interp_times = np.linspace(0, 1, T)
165 | quaternions = [R.from_matrix(R_).as_quat() for R_ in Rs]
166 | interp_Ts = interp1d(times, Ts, axis=0)(interp_times)
167 |
168 | def slerp(q1, q2, t):
169 | q1 = R.from_quat(q1)
170 | q2 = R.from_quat(q2)
171 | return (q1 * (q2 * q1.inv()) ** t).as_quat()
172 |
173 | interp_quaternions = []
174 | for i in range(len(quaternions) - 1):
175 | t = (interp_times - times[i]) / (times[i + 1] - times[i])
176 | valid_t = t[(t >= 0) & (t <= 1)]
177 | for t_val in valid_t:
178 | interp_quaternions.append(slerp(quaternions[i], quaternions[i + 1], t_val).squeeze())
179 |
180 | if len(interp_quaternions) < T:
181 | interp_quaternions.extend([quaternions[-1]] * (T - len(interp_quaternions)))
182 |
183 | interp_quaternions = np.array(interp_quaternions)
184 | interp_Rs = [R.from_quat(q).as_matrix() for q in interp_quaternions]
185 | interpolated_poses = []
186 | for t in range(T):
187 | interpolated_poses.append((interp_Rs[t].squeeze(), interp_Ts[t].squeeze()))
188 |
189 | return interpolated_poses
190 |
191 | def set_renderer(self):
192 |
193 | self.lights = PointLights(device=self.device,
194 | location=[[0.0, 0.0, 1e5]],
195 | ambient_color=[[1, 1, 1]],
196 | specular_color=[[0., 0., 0.]],
197 | diffuse_color=[[0., 0., 0.]])
198 |
199 | if self.is_pointcloud:
200 | self.raster_settings = PointsRasterizationSettings(
201 | image_size=self.img_size,
202 | radius=0.01,
203 | points_per_pixel=10,
204 | bin_size=0,
205 | )
206 | else:
207 | self.raster_settings = RasterizationSettings(
208 | image_size=self.img_size,
209 | blur_radius=0.0,
210 | faces_per_pixel=10,
211 | bin_size=0,
212 | )
213 | self.blend_params = blending.BlendParams(background_color=[0, 0, 0])
214 |
215 | R_norm = torch.tensor([[[1., 0., 0.],
216 | [0., 1., 0.],
217 | [0., 0., 1.]]])
218 | T_norm = torch.tensor([[0., 0., 0.]])
219 |
220 | self.cameras = FoVPerspectiveCameras(
221 | device=self.device,
222 | R=R_norm,
223 | T=T_norm,
224 | znear=0.01,
225 | zfar=200,
226 | fov=2 * np.arctan(self.img_size[0] // 2 / self.focal) * 180. / np.pi,
227 | aspect_ratio=1.0,
228 | )
229 |
230 | return
231 |
232 | def get_renderer(self):
233 | if self.is_pointcloud:
234 |
235 | self.renderer = PointsRenderer(
236 | rasterizer=PointsRasterizer(cameras=self.cameras, raster_settings=self.raster_settings),
237 | compositor=AlphaCompositor()
238 | )
239 | else:
240 | self.renderer = MeshRenderer(rasterizer=MeshRasterizer(
241 | cameras=self.cameras, raster_settings=self.raster_settings),
242 | shader=SoftPhongShader(device=self.device,
243 | cameras=self.cameras,
244 | lights=self.lights,
245 | blend_params=self.blend_params))
246 |
247 | return
248 |
249 | def render(self):
250 | rendered_img = self.renderer(self.model).cpu()
251 | return rendered_img
252 |
253 | def project_vs(self, coords_3d=None):
254 | if coords_3d is None:
255 | coords_3d = self.model.points_padded()[0]
256 | else:
257 | coords_3d = torch.tensor(coords_3d).to(self.device).float()
258 |
259 | coords_3d = self.cameras.transform_points_screen(coords_3d, image_size=self.img_size)
260 |
261 | return coords_3d
262 |
263 | def get_vertical_distances(self, points_world):
264 | R = self.cameras.R.cpu().numpy()[0]
265 | T = self.cameras.T.cpu().numpy()[0]
266 | T = T[:, np.newaxis]
267 | points_camera = (R @ points_world.T + T).T
268 | vertical_distances = points_camera[:, 2]
269 |
270 | return vertical_distances
271 |
272 |
273 | def generate_query(img_size, video_length):
274 |
275 | height, width = img_size[0], img_size[1]
276 | query_points = []
277 | for i in range(0, width, width // 30):
278 | for j in range(0, height, height // 30):
279 | query_points.append([i, j])
280 | query_points = np.array(query_points)[None, ...]
281 | query_points = np.repeat(query_points, video_length, 0)
282 | return query_points
283 |
284 |
285 | def estimate_depth(imagepath, query_points, img_size, focal, output_dir):
286 |
287 | height, width = img_size[0], img_size[1]
288 | imagename = imagepath.split('/')[-1].split('.')[0]
289 |
290 | metric_depth, relative_depth, _ = depth_estimation(imagepath)
291 |
292 | query_points = np.round(query_points)
293 | relative_points_depth = [relative_depth[int(p[1]), int(p[0])] if 0 <= int(p[1]) < relative_depth.shape[0] and 0 <= int(p[0]) < relative_depth.shape[1] else 0.5 for p in query_points[0]]
294 | relative_points_depth = np.array(relative_points_depth)[None, ...]
295 | relative_points_depth = np.repeat(relative_points_depth, query_points.shape[0], 0)[..., None]
296 |
297 | points_depth = [metric_depth[int(p[1]), int(p[0])] if 0 <= int(p[1]) < metric_depth.shape[0] and 0 <= int(p[0]) < metric_depth.shape[1] else 0.5 for p in query_points[0]]
298 | points_depth = np.array(points_depth)[None, ...]
299 | points_depth = np.repeat(points_depth, query_points.shape[0], 0)[..., None]
300 | points_3d = np.concatenate([query_points, points_depth], -1)
301 |
302 | points_3d[:, :, 0] = -(points_3d[:, :, 0] - width / 2) / focal
303 | points_3d[:, :, 1] = -(points_3d[:, :, 1] - height / 2) / focal
304 |
305 | points_3d[:, :, 0] *= points_3d[:, :, 2]
306 | points_3d[:, :, 1] *= points_3d[:, :, 2]
307 |
308 | return points_3d, relative_points_depth
309 |
310 |
311 | def cam_adaptation(imagepath, vo_path, query_path, cam_type, video_length, output_dir):
312 |
313 | os.makedirs(output_dir, exist_ok=True)
314 |
315 | tmp_folder = './tmp_cam'
316 | if os.path.exists(tmp_folder):
317 | shutil.rmtree(tmp_folder)
318 | os.makedirs(tmp_folder, exist_ok=False)
319 |
320 | img = Image.open(imagepath).convert('RGB')
321 | img_size = (img.size[1], img.size[0])
322 |
323 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
324 |
325 | S = np.eye(3)
326 | focal_point = 500
327 | uvmap_obj_path = './mesh/world_envelope.obj'
328 |
329 | renderer = Render(img_size=img_size, device=device, uvmap_obj_path=uvmap_obj_path, is_pointcloud=False)
330 | renderer_point = Render(img_size=img_size, focal=focal_point, device=device, is_pointcloud=True)
331 |
332 | if cam_type == 'generated':
333 | pattern_name = vo_path.split('/')[-1].split('.')[0]
334 | RTs = generate_vo(renderer_point, pattern_name)
335 | else:
336 | camera_poses = np.loadtxt(vo_path).reshape(-1, 3, 4)
337 | RTs = [(pose[:, :3], pose[:, 3]) for pose in camera_poses]
338 |
339 | if len(RTs) < video_length:
340 | interpolated_poses = renderer.interpolate_slerp(RTs, T=video_length)
341 | else:
342 | interpolated_poses = RTs
343 | interpolated_camera_poses = []
344 | for i, pose in enumerate(interpolated_poses):
345 | R = pose[0]
346 | T = pose[1]
347 | RT = np.concatenate((R, T[..., None]), 1)
348 | interpolated_camera_poses.append(RT)
349 | interpolated_camera_poses = np.array(interpolated_camera_poses)
350 | camera_poses = interpolated_camera_poses
351 |
352 | if query_path == 'none':
353 | query_points = generate_query(img_size, video_length)
354 | else:
355 | query_points = np.load(query_path)[:, :, :2]
356 |
357 | points_3d, points_depth = estimate_depth(imagepath, query_points, img_size, focal_point, output_dir)
358 |
359 | projected_3d_points = []
360 | for idx, RT in enumerate(camera_poses):
361 | renderer.update_RT(RT, S)
362 | rendered_img = renderer.render().squeeze().numpy()[:, :, :3]
363 |
364 | renderer_point.update_RT(RT, S)
365 | projected_3d_point = renderer_point.project_vs(points_3d[idx]).cpu().numpy()
366 | projected_3d_point_depth = renderer.get_vertical_distances(points_3d[idx])
367 | projected_3d_point[:, 2] = projected_3d_point_depth
368 | projected_3d_points.append(projected_3d_point)
369 |
370 | cv2.imwrite(os.path.join(tmp_folder, str(idx).zfill(5) + ".png"), rendered_img)
371 |
372 | cam_path = "./tmp/rendered_cam.mp4"
373 | track_path = "./tmp/rendered_tracks.npy"
374 |
375 | projected_3d_points = np.array(projected_3d_points)
376 | np.save(track_path, projected_3d_points)
377 |
378 | subprocess.call(
379 | "ffmpeg -loglevel quiet -i {} -c:v h264 -pix_fmt yuv420p -y {}".format(
380 | os.path.join(tmp_folder, "%05d.png"), cam_path),
381 | shell=True)
382 | shutil.rmtree(tmp_folder)
383 |
384 | return projected_3d_points, points_depth, cam_path
385 |
386 |
387 | if __name__ == '__main__':
388 | parser = argparse.ArgumentParser()
389 |
390 | # set the gpu
391 | parser.add_argument('--gpu', type=int, default=0, help='gpu id')
392 | # set start idx & end idx
393 | parser.add_argument('--imagepath', type=str, default='./assets/boat3.jpg', help='image path')
394 |     parser.add_argument('--vo_path', type=str, default='./cam_patterns/PanDown.txt', help='camera pose (visual odometry) path')
395 |     parser.add_argument('--query_path', type=str, default='none', help='query point path')
396 | parser.add_argument('--output_dir', type=str, default='./tmp', help='output dir')
397 | parser.add_argument('--video_length', type=int, default=16, help='camera pose length')
398 |     parser.add_argument('--cam_type', type=str, default='generated', help='generated or user-provided')
399 |
400 | args = parser.parse_args()
401 |
402 | # set the gpu
403 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
404 |
405 | os.makedirs(args.output_dir, exist_ok=True)
406 |
407 | tmp_folder = './tmp_cam'
408 | if os.path.exists(tmp_folder):
409 | shutil.rmtree(tmp_folder)
410 | os.makedirs(tmp_folder, exist_ok=False)
411 |
412 | img = Image.open(args.imagepath).convert('RGB')
413 | img_size = (img.size[1], img.size[0])
414 |
415 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
416 |
417 | S = np.eye(3)
418 | focal_point = 500
419 | uvmap_obj_path = './mesh/world_envelope.obj'
420 |
421 | renderer = Render(img_size=img_size, device=device, uvmap_obj_path=uvmap_obj_path, is_pointcloud=False)
422 | renderer_point = Render(img_size=img_size, focal=focal_point, device=device, is_pointcloud=True)
423 |
424 | if args.cam_type == 'generated':
425 | pattern_name = args.vo_path.split('/')[-1].split('.')[0]
426 | # RTs = generate_vo(renderer_point, pattern_name)
427 | raise NotImplementedError
428 | else:
429 | camera_poses = np.loadtxt(args.vo_path).reshape(-1, 3, 4)
430 | RTs = [(pose[:, :3], pose[:, 3]) for pose in camera_poses]
431 |
432 | if len(RTs) < args.video_length:
433 | interpolated_poses = renderer.interpolate_slerp(RTs, T=args.video_length)
434 | else:
435 | interpolated_poses = RTs
436 | interpolated_camera_poses = []
437 | for i, pose in enumerate(interpolated_poses):
438 | R = pose[0]
439 | T = pose[1]
440 | RT = np.concatenate((R, T[..., None]), 1)
441 | interpolated_camera_poses.append(RT)
442 | interpolated_camera_poses = np.array(interpolated_camera_poses)
443 | camera_poses = interpolated_camera_poses
444 |
445 | if args.query_path == 'none':
446 | query_points = generate_query(img_size, args.video_length)
447 | else:
448 | query_points = np.load(args.query_path)[:, :, :2]
449 |
450 | points_3d, _ = estimate_depth(args.imagepath, query_points, img_size, focal_point, args.output_dir)
451 |
452 | projected_3d_points = []
453 | for idx, RT in enumerate(camera_poses):
454 | renderer.update_RT(RT, S)
455 | rendered_img = renderer.render().squeeze().numpy()[:, :, :3]
456 |
457 | renderer_point.update_RT(RT, S)
458 | projected_3d_point = renderer_point.project_vs(points_3d[idx]).cpu().numpy()
459 | projected_3d_point_depth = renderer.get_vertical_distances(points_3d[idx])
460 | projected_3d_point[:, 2] = projected_3d_point_depth
461 | projected_3d_points.append(projected_3d_point)
462 |
463 | cv2.imwrite(os.path.join(tmp_folder, str(idx).zfill(5) + ".png"), rendered_img)
464 |
465 | cam_path = "./tmp/rendered_cam.mp4"
466 | track_path = "./tmp/rendered_tracks.npy"
467 |
468 | projected_3d_points = np.array(projected_3d_points)
469 | np.save(track_path, projected_3d_points)
470 |
471 | subprocess.call(
472 | "ffmpeg -loglevel quiet -i {} -c:v h264 -pix_fmt yuv420p -y {}".format(
473 | os.path.join(tmp_folder, "%05d.png"), cam_path),
474 | shell=True)
475 | shutil.rmtree(tmp_folder)
--------------------------------------------------------------------------------
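For orientation, here is a sketch of how `cam_adaptation` above could be driven directly, mirroring the defaults in the `__main__` block. The image/pattern pairing is illustrative (taken from the files under `examples/`), and the output paths noted in the comments are the ones hard-coded inside `cam_adaptation`.

```python
# Illustrative driver for cam_utils.cam_adaptation (not part of the repo).
# It estimates depth for a grid of query points, re-renders the world-envelope
# mesh under the (interpolated) camera poses, and writes rendered_cam.mp4 /
# rendered_tracks.npy under ./tmp, as hard-coded inside cam_adaptation.
from cam_utils import cam_adaptation

projected_3d_points, points_depth, cam_path = cam_adaptation(
    imagepath="./examples/balloons.png",   # reference image
    vo_path="./examples/DollyOut.txt",     # camera poses, loaded and reshaped to (-1, 3, 4)
    query_path="none",                     # 'none' -> regular grid of query points
    cam_type="provided",                   # any value other than 'generated' loads poses from vo_path
    video_length=16,                       # number of frames / interpolated poses
    output_dir="./tmp",
)
print(projected_3d_points.shape)           # (frames, num_points, 3): screen u, v and camera-space depth
print(cam_path)                            # ./tmp/rendered_cam.mp4
```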
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/configs/.DS_Store
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
4 | mm_path: ''
5 |
6 | denoising_unet_path: ./pretrained_weights/denoising_unet.pth
7 | reference_unet_path: ./pretrained_weights/reference_unet.pth
8 | obj_encoder_path: ./pretrained_weights/obj_encoder.pth
9 | cam_encoder_path: ./pretrained_weights/cam_encoder.pth
10 |
11 | use_lora: false
12 | lora_rank: 64
13 |
14 | inference_config: "./configs/inference/inference_v2.yaml"
15 | weight_dtype: 'fp32'
16 |
17 | is_obj: true
18 | is_cam: true
19 | is_depth: true
20 | is_adapted: true
21 | fusion_type: 'max'
22 | is_pad: true
23 |
24 | W: 768
25 | H: 512
26 | circle_scale: 10
27 |
28 | sample_n_frames: 16
29 | sample_stride: 4
30 | guidance_scale: 3.5
31 | steps: 20
32 | seed: 12580
33 |
34 | save_dir: './outputs'
35 |
36 | sample_n_trajs: -1
37 | cam_only: false
38 | obj_only: false
39 |
40 | remove_tmp_results: true
41 |
42 | test_cases:
43 |
44 | - "./examples/balloons.png":
45 | - "./examples/balloons.json"
46 | - "./examples/Still.txt"
47 | - 12597
48 |
49 | - "./examples/backview.jpeg":
50 | - "./examples/backview.json"
51 | - "./examples/DollyOut.txt"
52 | - 12597
53 |
54 | - "./examples/balloon.png":
55 | - "./examples/balloon.json"
56 | - "./examples/TiltDown.txt"
57 | - 12580
58 |
59 | - "./examples/bamboo.png":
60 | - "./examples/bamboo.json"
61 | - "./examples/Still.txt"
62 | - 12587
--------------------------------------------------------------------------------
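The `test_cases` block above maps each reference image to a trajectory JSON, a camera-pattern TXT, and a per-case seed. A minimal sketch of reading it with PyYAML, assuming each entry parses as a single-key mapping (variable names are illustrative):

```python
# Sketch: iterate over test_cases in configs/eval.yaml. Each case is assumed to
# parse as {image_path: [trajectory_json, camera_pattern_txt, seed]}.
import yaml

with open("configs/eval.yaml") as f:
    cfg = yaml.safe_load(f)

for case in cfg["test_cases"]:
    image_path, (traj_json, cam_txt, seed) = next(iter(case.items()))
    print(f"{image_path}: traj={traj_json}, cam={cam_txt}, seed={seed}")
```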
/configs/inference/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
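The `noise_scheduler_kwargs` above (v-prediction, zero-terminal-SNR rescaling, trailing timestep spacing) correspond to standard `DDIMScheduler` arguments in diffusers, matching the `sampler: DDIM` entry. A sketch of constructing the scheduler from this config, assuming OmegaConf and diffusers are installed:

```python
# Sketch: build the DDIM scheduler from noise_scheduler_kwargs in
# configs/inference/inference_v2.yaml. Illustrative; not part of the repo.
from diffusers import DDIMScheduler
from omegaconf import OmegaConf

infer_cfg = OmegaConf.load("configs/inference/inference_v2.yaml")
sched_kwargs = OmegaConf.to_container(infer_cfg.noise_scheduler_kwargs)

scheduler = DDIMScheduler(**sched_kwargs)
print(scheduler.config.prediction_type)   # "v_prediction"
print(scheduler.config.timestep_spacing)  # "trailing"
```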
/depth_anything_v2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | from functools import partial
11 | import math
12 | import logging
13 | from typing import Sequence, Tuple, Union, Callable
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.utils.checkpoint
18 | from torch.nn.init import trunc_normal_
19 |
20 | from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27 | if not depth_first and include_root:
28 | fn(module=module, name=name)
29 | for child_name, child_module in module.named_children():
30 | child_name = ".".join((name, child_name)) if name else child_name
31 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32 | if depth_first and include_root:
33 | fn(module=module, name=name)
34 | return module
35 |
36 |
37 | class BlockChunk(nn.ModuleList):
38 | def forward(self, x):
39 | for b in self:
40 | x = b(x)
41 | return x
42 |
43 |
44 | class DinoVisionTransformer(nn.Module):
45 | def __init__(
46 | self,
47 | img_size=224,
48 | patch_size=16,
49 | in_chans=3,
50 | embed_dim=768,
51 | depth=12,
52 | num_heads=12,
53 | mlp_ratio=4.0,
54 | qkv_bias=True,
55 | ffn_bias=True,
56 | proj_bias=True,
57 | drop_path_rate=0.0,
58 | drop_path_uniform=False,
59 | init_values=None, # for layerscale: None or 0 => no layerscale
60 | embed_layer=PatchEmbed,
61 | act_layer=nn.GELU,
62 | block_fn=Block,
63 | ffn_layer="mlp",
64 | block_chunks=1,
65 | num_register_tokens=0,
66 | interpolate_antialias=False,
67 | interpolate_offset=0.1,
68 | ):
69 | """
70 | Args:
71 | img_size (int, tuple): input image size
72 | patch_size (int, tuple): patch size
73 | in_chans (int): number of input channels
74 | embed_dim (int): embedding dimension
75 | depth (int): depth of transformer
76 | num_heads (int): number of attention heads
77 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78 | qkv_bias (bool): enable bias for qkv if True
79 | proj_bias (bool): enable bias for proj in attn if True
80 | ffn_bias (bool): enable bias for ffn if True
81 | drop_path_rate (float): stochastic depth rate
82 | drop_path_uniform (bool): apply uniform drop rate across blocks
83 | weight_init (str): weight init scheme
84 | init_values (float): layer-scale init values
85 | embed_layer (nn.Module): patch embedding layer
86 | act_layer (nn.Module): MLP activation layer
87 | block_fn (nn.Module): transformer block class
88 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90 | num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91 | interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92 | interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93 | """
94 | super().__init__()
95 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
96 |
97 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98 | self.num_tokens = 1
99 | self.n_blocks = depth
100 | self.num_heads = num_heads
101 | self.patch_size = patch_size
102 | self.num_register_tokens = num_register_tokens
103 | self.interpolate_antialias = interpolate_antialias
104 | self.interpolate_offset = interpolate_offset
105 |
106 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107 | num_patches = self.patch_embed.num_patches
108 |
109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111 | assert num_register_tokens >= 0
112 | self.register_tokens = (
113 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114 | )
115 |
116 | if drop_path_uniform is True:
117 | dpr = [drop_path_rate] * depth
118 | else:
119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120 |
121 | if ffn_layer == "mlp":
122 | logger.info("using MLP layer as FFN")
123 | ffn_layer = Mlp
124 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125 | logger.info("using SwiGLU layer as FFN")
126 | ffn_layer = SwiGLUFFNFused
127 | elif ffn_layer == "identity":
128 | logger.info("using Identity layer as FFN")
129 |
130 | def f(*args, **kwargs):
131 | return nn.Identity()
132 |
133 | ffn_layer = f
134 | else:
135 | raise NotImplementedError
136 |
137 | blocks_list = [
138 | block_fn(
139 | dim=embed_dim,
140 | num_heads=num_heads,
141 | mlp_ratio=mlp_ratio,
142 | qkv_bias=qkv_bias,
143 | proj_bias=proj_bias,
144 | ffn_bias=ffn_bias,
145 | drop_path=dpr[i],
146 | norm_layer=norm_layer,
147 | act_layer=act_layer,
148 | ffn_layer=ffn_layer,
149 | init_values=init_values,
150 | )
151 | for i in range(depth)
152 | ]
153 | if block_chunks > 0:
154 | self.chunked_blocks = True
155 | chunked_blocks = []
156 | chunksize = depth // block_chunks
157 | for i in range(0, depth, chunksize):
158 | # this is to keep the block index consistent if we chunk the block list
159 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161 | else:
162 | self.chunked_blocks = False
163 | self.blocks = nn.ModuleList(blocks_list)
164 |
165 | self.norm = norm_layer(embed_dim)
166 | self.head = nn.Identity()
167 |
168 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169 |
170 | self.init_weights()
171 |
172 | def init_weights(self):
173 | trunc_normal_(self.pos_embed, std=0.02)
174 | nn.init.normal_(self.cls_token, std=1e-6)
175 | if self.register_tokens is not None:
176 | nn.init.normal_(self.register_tokens, std=1e-6)
177 | named_apply(init_weights_vit_timm, self)
178 |
179 | def interpolate_pos_encoding(self, x, w, h):
180 | previous_dtype = x.dtype
181 | npatch = x.shape[1] - 1
182 | N = self.pos_embed.shape[1] - 1
183 | if npatch == N and w == h:
184 | return self.pos_embed
185 | pos_embed = self.pos_embed.float()
186 | class_pos_embed = pos_embed[:, 0]
187 | patch_pos_embed = pos_embed[:, 1:]
188 | dim = x.shape[-1]
189 | w0 = w // self.patch_size
190 | h0 = h // self.patch_size
191 | # we add a small number to avoid floating point error in the interpolation
192 | # see discussion at https://github.com/facebookresearch/dino/issues/8
193 | # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194 | w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195 | # w0, h0 = w0 + 0.1, h0 + 0.1
196 |
197 | sqrt_N = math.sqrt(N)
198 | sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199 | patch_pos_embed = nn.functional.interpolate(
200 | patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201 | scale_factor=(sx, sy),
202 | # (int(w0), int(h0)), # to solve the upsampling shape issue
203 | mode="bicubic",
204 | antialias=self.interpolate_antialias
205 | )
206 |
207 | assert int(w0) == patch_pos_embed.shape[-2]
208 | assert int(h0) == patch_pos_embed.shape[-1]
209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211 |
212 | def prepare_tokens_with_masks(self, x, masks=None):
213 | B, nc, w, h = x.shape
214 | x = self.patch_embed(x)
215 | if masks is not None:
216 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217 |
218 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219 | x = x + self.interpolate_pos_encoding(x, w, h)
220 |
221 | if self.register_tokens is not None:
222 | x = torch.cat(
223 | (
224 | x[:, :1],
225 | self.register_tokens.expand(x.shape[0], -1, -1),
226 | x[:, 1:],
227 | ),
228 | dim=1,
229 | )
230 |
231 | return x
232 |
233 | def forward_features_list(self, x_list, masks_list):
234 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235 | for blk in self.blocks:
236 | x = blk(x)
237 |
238 | all_x = x
239 | output = []
240 | for x, masks in zip(all_x, masks_list):
241 | x_norm = self.norm(x)
242 | output.append(
243 | {
244 | "x_norm_clstoken": x_norm[:, 0],
245 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247 | "x_prenorm": x,
248 | "masks": masks,
249 | }
250 | )
251 | return output
252 |
253 | def forward_features(self, x, masks=None):
254 | if isinstance(x, list):
255 | return self.forward_features_list(x, masks)
256 |
257 | x = self.prepare_tokens_with_masks(x, masks)
258 |
259 | for blk in self.blocks:
260 | x = blk(x)
261 |
262 | x_norm = self.norm(x)
263 | return {
264 | "x_norm_clstoken": x_norm[:, 0],
265 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267 | "x_prenorm": x,
268 | "masks": masks,
269 | }
270 |
271 | def _get_intermediate_layers_not_chunked(self, x, n=1):
272 | x = self.prepare_tokens_with_masks(x)
273 | # If n is an int, take the n last blocks. If it's a list, take them
274 | output, total_block_len = [], len(self.blocks)
275 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276 | for i, blk in enumerate(self.blocks):
277 | x = blk(x)
278 | if i in blocks_to_take:
279 | output.append(x)
280 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281 | return output
282 |
283 | def _get_intermediate_layers_chunked(self, x, n=1):
284 | x = self.prepare_tokens_with_masks(x)
285 | output, i, total_block_len = [], 0, len(self.blocks[-1])
286 | # If n is an int, take the n last blocks. If it's a list, take them
287 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288 | for block_chunk in self.blocks:
289 | for blk in block_chunk[i:]: # Passing the nn.Identity()
290 | x = blk(x)
291 | if i in blocks_to_take:
292 | output.append(x)
293 | i += 1
294 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295 | return output
296 |
297 | def get_intermediate_layers(
298 | self,
299 | x: torch.Tensor,
300 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
301 | reshape: bool = False,
302 | return_class_token: bool = False,
303 | norm=True
304 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305 | if self.chunked_blocks:
306 | outputs = self._get_intermediate_layers_chunked(x, n)
307 | else:
308 | outputs = self._get_intermediate_layers_not_chunked(x, n)
309 | if norm:
310 | outputs = [self.norm(out) for out in outputs]
311 | class_tokens = [out[:, 0] for out in outputs]
312 | outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313 | if reshape:
314 | B, _, w, h = x.shape
315 | outputs = [
316 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317 | for out in outputs
318 | ]
319 | if return_class_token:
320 | return tuple(zip(outputs, class_tokens))
321 | return tuple(outputs)
322 |
323 | def forward(self, *args, is_training=False, **kwargs):
324 | ret = self.forward_features(*args, **kwargs)
325 | if is_training:
326 | return ret
327 | else:
328 | return self.head(ret["x_norm_clstoken"])
329 |
330 |
331 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
332 | """ViT weight initialization, original timm impl (for reproducibility)"""
333 | if isinstance(module, nn.Linear):
334 | trunc_normal_(module.weight, std=0.02)
335 | if module.bias is not None:
336 | nn.init.zeros_(module.bias)
337 |
338 |
339 | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340 | model = DinoVisionTransformer(
341 | patch_size=patch_size,
342 | embed_dim=384,
343 | depth=12,
344 | num_heads=6,
345 | mlp_ratio=4,
346 | block_fn=partial(Block, attn_class=MemEffAttention),
347 | num_register_tokens=num_register_tokens,
348 | **kwargs,
349 | )
350 | return model
351 |
352 |
353 | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354 | model = DinoVisionTransformer(
355 | patch_size=patch_size,
356 | embed_dim=768,
357 | depth=12,
358 | num_heads=12,
359 | mlp_ratio=4,
360 | block_fn=partial(Block, attn_class=MemEffAttention),
361 | num_register_tokens=num_register_tokens,
362 | **kwargs,
363 | )
364 | return model
365 |
366 |
367 | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368 | model = DinoVisionTransformer(
369 | patch_size=patch_size,
370 | embed_dim=1024,
371 | depth=24,
372 | num_heads=16,
373 | mlp_ratio=4,
374 | block_fn=partial(Block, attn_class=MemEffAttention),
375 | num_register_tokens=num_register_tokens,
376 | **kwargs,
377 | )
378 | return model
379 |
380 |
381 | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382 | """
383 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384 | """
385 | model = DinoVisionTransformer(
386 | patch_size=patch_size,
387 | embed_dim=1536,
388 | depth=40,
389 | num_heads=24,
390 | mlp_ratio=4,
391 | block_fn=partial(Block, attn_class=MemEffAttention),
392 | num_register_tokens=num_register_tokens,
393 | **kwargs,
394 | )
395 | return model
396 |
397 |
398 | def DINOv2(model_name):
399 | model_zoo = {
400 | "vits": vit_small,
401 | "vitb": vit_base,
402 | "vitl": vit_large,
403 | "vitg": vit_giant2
404 | }
405 |
406 | return model_zoo[model_name](
407 | img_size=518,
408 | patch_size=14,
409 | init_values=1.0,
410 | ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411 | block_chunks=0,
412 | num_register_tokens=0,
413 | interpolate_antialias=False,
414 | interpolate_offset=0.1
415 | )
--------------------------------------------------------------------------------
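For orientation, a sketch of how the `DINOv2` factory above is typically consumed: build a backbone and pull intermediate patch features with `get_intermediate_layers`. The import path assumes the repository root is on `sys.path`, and the tapped layer indices are an illustrative choice.

```python
# Sketch: instantiate the ViT-L backbone defined above and extract intermediate
# patch features (as a DPT-style depth head would). Illustrative only.
import torch
from depth_anything_v2.depth_anything_v2.dinov2 import DINOv2  # path assumes repo root

backbone = DINOv2("vitl")            # img_size=518, patch_size=14, embed_dim=1024
x = torch.randn(1, 3, 518, 518)      # H and W must be multiples of the 14-pixel patch size

with torch.no_grad():
    feats = backbone.get_intermediate_layers(
        x,
        n=[4, 11, 17, 23],           # which blocks to tap (illustrative)
        reshape=True,                # return (B, C, H/14, W/14) feature maps
        return_class_token=True,
    )

for fmap, cls_token in feats:
    print(fmap.shape, cls_token.shape)   # e.g. torch.Size([1, 1024, 37, 37]) torch.Size([1, 1024])
```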
/depth_anything_v2/depth_anything_v2/dinov2_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .mlp import Mlp
8 | from .patch_embed import PatchEmbed
9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10 | from .block import NestedTensorBlock
11 | from .attention import MemEffAttention
12 |
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
83 |
--------------------------------------------------------------------------------
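MemEffAttention transparently falls back to the plain softmax path above when xFormers is not installed (attn_bias is only supported with xFormers). A small shape-check sketch, assuming the same nested-package import path as elsewhere in this repository:

    import torch
    from depth_anything_v2.depth_anything_v2.dinov2_layers.attention import MemEffAttention

    attn = MemEffAttention(dim=384, num_heads=6)   # ViT-S sizing: 384 / 6 = 64 channels per head
    x = torch.randn(2, 197, 384)                   # (batch, tokens, dim)
    y = attn(x)                                    # memory_efficient_attention if available, else Attention.forward
    print(y.shape)                                 # torch.Size([2, 197, 384])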
/depth_anything_v2/depth_anything_v2/dinov2_layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 |             x = x + self.drop_path2(ffn_residual_func(x))  # FFN branch uses its own DropPath
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
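NestedTensorBlock dispatches on its input type: a plain Tensor takes the ordinary Block.forward path, while a list of tensors takes the nested path and requires xFormers. A hedged sketch of the Tensor path, with LayerScale enabled via init_values as in the DINOv2 factory:

    import torch
    from depth_anything_v2.depth_anything_v2.dinov2_layers import NestedTensorBlock, MemEffAttention

    block = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention, init_values=1.0)
    x = torch.randn(2, 197, 384)
    print(block(x).shape)   # residual attention + MLP, shape preserved: torch.Size([2, 197, 384])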
/depth_anything_v2/depth_anything_v2/dinov2_layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
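DropPath zeroes a sample's entire residual branch with probability drop_prob during training and rescales the surviving samples by 1 / (1 - drop_prob) so the expectation is unchanged; at inference it is the identity. A quick numerical check under those assumptions:

    import torch
    from depth_anything_v2.depth_anything_v2.dinov2_layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.1)
    x = torch.ones(10000, 8)
    dp.train()
    print(dp(x).mean().item())    # close to 1.0: dropped rows are zero, kept rows scaled by 1/0.9
    dp.eval()
    print(torch.equal(dp(x), x))  # True: no-op at inference time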
/depth_anything_v2/depth_anything_v2/dinov2_layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/dinov2_layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 |         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
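With the configuration used by the DINOv2 factory above (518-pixel input, 14-pixel patches), PatchEmbed produces a 37 x 37 grid, i.e. 1369 patch tokens. A short sketch confirming the bookkeeping, assuming the nested-package import path:

    import torch
    from depth_anything_v2.depth_anything_v2.dinov2_layers import PatchEmbed

    pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
    print(pe.patches_resolution, pe.num_patches)   # (37, 37) 1369
    x = torch.randn(1, 3, 518, 518)
    print(pe(x).shape)                             # torch.Size([1, 1369, 1024]) -- (B, N, D)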
/depth_anything_v2/depth_anything_v2/dinov2_layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
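To make the hidden-size rounding in SwiGLUFFNFused concrete: the requested hidden width is shrunk by 2/3 and rounded up to a multiple of 8. A plain-Python check mirroring the expression above (the 'vitg' preset, the only one that uses 'swiglufused', requests 1536 * 4 = 6144):

    def fused_hidden(hidden_features: int) -> int:
        # same rounding as in SwiGLUFFNFused.__init__ above
        return (int(hidden_features * 2 / 3) + 7) // 8 * 8

    print(fused_hidden(6144))   # 4096  (vitg: dim 1536, mlp_ratio 4)
    print(fused_hidden(4096))   # 2736  (what a vitl-sized block would get)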
/depth_anything_v2/depth_anything_v2/dpt.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision.transforms import Compose
6 |
7 | from .dinov2 import DINOv2
8 | from .util.blocks import FeatureFusionBlock, _make_scratch
9 | from .util.transform import Resize, NormalizeImage, PrepareForNet
10 |
11 |
12 | def _make_fusion_block(features, use_bn, size=None):
13 | return FeatureFusionBlock(
14 | features,
15 | nn.ReLU(False),
16 | deconv=False,
17 | bn=use_bn,
18 | expand=False,
19 | align_corners=True,
20 | size=size,
21 | )
22 |
23 |
24 | class ConvBlock(nn.Module):
25 | def __init__(self, in_feature, out_feature):
26 | super().__init__()
27 |
28 | self.conv_block = nn.Sequential(
29 | nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
30 | nn.BatchNorm2d(out_feature),
31 | nn.ReLU(True)
32 | )
33 |
34 | def forward(self, x):
35 | return self.conv_block(x)
36 |
37 |
38 | class DPTHead(nn.Module):
39 | def __init__(
40 | self,
41 | in_channels,
42 | features=256,
43 | use_bn=False,
44 | out_channels=[256, 512, 1024, 1024],
45 | use_clstoken=False
46 | ):
47 | super(DPTHead, self).__init__()
48 |
49 | self.use_clstoken = use_clstoken
50 |
51 | self.projects = nn.ModuleList([
52 | nn.Conv2d(
53 | in_channels=in_channels,
54 | out_channels=out_channel,
55 | kernel_size=1,
56 | stride=1,
57 | padding=0,
58 | ) for out_channel in out_channels
59 | ])
60 |
61 | self.resize_layers = nn.ModuleList([
62 | nn.ConvTranspose2d(
63 | in_channels=out_channels[0],
64 | out_channels=out_channels[0],
65 | kernel_size=4,
66 | stride=4,
67 | padding=0),
68 | nn.ConvTranspose2d(
69 | in_channels=out_channels[1],
70 | out_channels=out_channels[1],
71 | kernel_size=2,
72 | stride=2,
73 | padding=0),
74 | nn.Identity(),
75 | nn.Conv2d(
76 | in_channels=out_channels[3],
77 | out_channels=out_channels[3],
78 | kernel_size=3,
79 | stride=2,
80 | padding=1)
81 | ])
82 |
83 | if use_clstoken:
84 | self.readout_projects = nn.ModuleList()
85 | for _ in range(len(self.projects)):
86 | self.readout_projects.append(
87 | nn.Sequential(
88 | nn.Linear(2 * in_channels, in_channels),
89 | nn.GELU()))
90 |
91 | self.scratch = _make_scratch(
92 | out_channels,
93 | features,
94 | groups=1,
95 | expand=False,
96 | )
97 |
98 | self.scratch.stem_transpose = None
99 |
100 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
101 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
102 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
103 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
104 |
105 | head_features_1 = features
106 | head_features_2 = 32
107 |
108 | self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
109 | self.scratch.output_conv2 = nn.Sequential(
110 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
111 | nn.ReLU(True),
112 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
113 | nn.Sigmoid()
114 | )
115 |
116 | def forward(self, out_features, patch_h, patch_w):
117 | out = []
118 | for i, x in enumerate(out_features):
119 | if self.use_clstoken:
120 | x, cls_token = x[0], x[1]
121 | readout = cls_token.unsqueeze(1).expand_as(x)
122 | x = self.readout_projects[i](torch.cat((x, readout), -1))
123 | else:
124 | x = x[0]
125 |
126 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
127 |
128 | x = self.projects[i](x)
129 | x = self.resize_layers[i](x)
130 |
131 | out.append(x)
132 |
133 | layer_1, layer_2, layer_3, layer_4 = out
134 |
135 | layer_1_rn = self.scratch.layer1_rn(layer_1)
136 | layer_2_rn = self.scratch.layer2_rn(layer_2)
137 | layer_3_rn = self.scratch.layer3_rn(layer_3)
138 | layer_4_rn = self.scratch.layer4_rn(layer_4)
139 |
140 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
141 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
142 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
143 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
144 |
145 | out = self.scratch.output_conv1(path_1)
146 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
147 | out = self.scratch.output_conv2(out)
148 |
149 | return out
150 |
151 |
152 | class DepthAnythingV2(nn.Module):
153 | def __init__(
154 | self,
155 | encoder='vitl',
156 | features=256,
157 | out_channels=[256, 512, 1024, 1024],
158 | use_bn=False,
159 | use_clstoken=False,
160 | max_depth=20.0
161 | ):
162 | super(DepthAnythingV2, self).__init__()
163 |
164 | self.intermediate_layer_idx = {
165 | 'vits': [2, 5, 8, 11],
166 | 'vitb': [2, 5, 8, 11],
167 | 'vitl': [4, 11, 17, 23],
168 | 'vitg': [9, 19, 29, 39]
169 | }
170 |
171 | self.max_depth = max_depth
172 |
173 | self.encoder = encoder
174 | self.pretrained = DINOv2(model_name=encoder)
175 |
176 | self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
177 |
178 | def forward(self, x):
179 | patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
180 |
181 | features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
182 |
183 | depth = self.depth_head(features, patch_h, patch_w) * self.max_depth
184 |
185 | return depth.squeeze(1)
186 |
187 | @torch.no_grad()
188 | def infer_image(self, raw_image, input_size=518):
189 | image, (h, w) = self.image2tensor(raw_image, input_size)
190 |
191 | depth = self.forward(image)
192 |
193 | depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
194 |
195 | return depth.cpu().numpy()
196 |
197 | def image2tensor(self, raw_image, input_size=518):
198 | transform = Compose([
199 | Resize(
200 | width=input_size,
201 | height=input_size,
202 | resize_target=False,
203 | keep_aspect_ratio=True,
204 | ensure_multiple_of=14,
205 | resize_method='lower_bound',
206 | image_interpolation_method=cv2.INTER_CUBIC,
207 | ),
208 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
209 | PrepareForNet(),
210 | ])
211 |
212 | h, w = raw_image.shape[:2]
213 |
214 | image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
215 |
216 | image = transform({'image': image})['image']
217 | image = torch.from_numpy(image).unsqueeze(0)
218 |
219 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
220 | image = image.to(DEVICE)
221 |
222 | return image, (h, w)
223 |
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/depth_anything_v2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/util/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5 | scratch = nn.Module()
6 |
7 | out_shape1 = out_shape
8 | out_shape2 = out_shape
9 | out_shape3 = out_shape
10 | if len(in_shape) >= 4:
11 | out_shape4 = out_shape
12 |
13 | if expand:
14 | out_shape1 = out_shape
15 | out_shape2 = out_shape * 2
16 | out_shape3 = out_shape * 4
17 | if len(in_shape) >= 4:
18 | out_shape4 = out_shape * 8
19 |
20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23 | if len(in_shape) >= 4:
24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25 |
26 | return scratch
27 |
28 |
29 | class ResidualConvUnit(nn.Module):
30 | """Residual convolution module.
31 | """
32 |
33 | def __init__(self, features, activation, bn):
34 | """Init.
35 |
36 | Args:
37 | features (int): number of features
38 | """
39 | super().__init__()
40 |
41 | self.bn = bn
42 |
43 | self.groups=1
44 |
45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46 |
47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48 |
49 |         if self.bn:
50 | self.bn1 = nn.BatchNorm2d(features)
51 | self.bn2 = nn.BatchNorm2d(features)
52 |
53 | self.activation = activation
54 |
55 | self.skip_add = nn.quantized.FloatFunctional()
56 |
57 | def forward(self, x):
58 | """Forward pass.
59 |
60 | Args:
61 | x (tensor): input
62 |
63 | Returns:
64 | tensor: output
65 | """
66 |
67 | out = self.activation(x)
68 | out = self.conv1(out)
69 |         if self.bn:
70 | out = self.bn1(out)
71 |
72 | out = self.activation(out)
73 | out = self.conv2(out)
74 |         if self.bn:
75 | out = self.bn2(out)
76 |
77 | if self.groups > 1:
78 | out = self.conv_merge(out)
79 |
80 | return self.skip_add.add(out, x)
81 |
82 |
83 | class FeatureFusionBlock(nn.Module):
84 | """Feature fusion block.
85 | """
86 |
87 | def __init__(
88 | self,
89 | features,
90 | activation,
91 | deconv=False,
92 | bn=False,
93 | expand=False,
94 | align_corners=True,
95 | size=None
96 | ):
97 | """Init.
98 |
99 | Args:
100 | features (int): number of features
101 | """
102 | super(FeatureFusionBlock, self).__init__()
103 |
104 | self.deconv = deconv
105 | self.align_corners = align_corners
106 |
107 | self.groups=1
108 |
109 | self.expand = expand
110 | out_features = features
111 |         if self.expand:
112 | out_features = features // 2
113 |
114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115 |
116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118 |
119 | self.skip_add = nn.quantized.FloatFunctional()
120 |
121 | self.size=size
122 |
123 | def forward(self, *xs, size=None):
124 | """Forward pass.
125 |
126 | Returns:
127 | tensor: output
128 | """
129 | output = xs[0]
130 |
131 | if len(xs) == 2:
132 | res = self.resConfUnit1(xs[1])
133 | output = self.skip_add.add(output, res)
134 |
135 | output = self.resConfUnit2(output)
136 |
137 | if (size is None) and (self.size is None):
138 | modifier = {"scale_factor": 2}
139 | elif size is None:
140 | modifier = {"size": self.size}
141 | else:
142 | modifier = {"size": size}
143 |
144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145 |
146 | output = self.out_conv(output)
147 |
148 | return output
149 |
--------------------------------------------------------------------------------
/depth_anything_v2/depth_anything_v2/util/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | class Resize(object):
6 | """Resize sample to given size (width, height).
7 | """
8 |
9 | def __init__(
10 | self,
11 | width,
12 | height,
13 | resize_target=True,
14 | keep_aspect_ratio=False,
15 | ensure_multiple_of=1,
16 | resize_method="lower_bound",
17 | image_interpolation_method=cv2.INTER_AREA,
18 | ):
19 | """Init.
20 |
21 | Args:
22 | width (int): desired output width
23 | height (int): desired output height
24 | resize_target (bool, optional):
25 | True: Resize the full sample (image, mask, target).
26 | False: Resize image only.
27 | Defaults to True.
28 | keep_aspect_ratio (bool, optional):
29 | True: Keep the aspect ratio of the input sample.
30 | Output sample might not have the given width and height, and
31 | resize behaviour depends on the parameter 'resize_method'.
32 | Defaults to False.
33 | ensure_multiple_of (int, optional):
34 | Output width and height is constrained to be multiple of this parameter.
35 | Defaults to 1.
36 | resize_method (str, optional):
37 | "lower_bound": Output will be at least as large as the given size.
38 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
39 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
40 | Defaults to "lower_bound".
41 | """
42 | self.__width = width
43 | self.__height = height
44 |
45 | self.__resize_target = resize_target
46 | self.__keep_aspect_ratio = keep_aspect_ratio
47 | self.__multiple_of = ensure_multiple_of
48 | self.__resize_method = resize_method
49 | self.__image_interpolation_method = image_interpolation_method
50 |
51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53 |
54 | if max_val is not None and y > max_val:
55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56 |
57 | if y < min_val:
58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59 |
60 | return y
61 |
62 | def get_size(self, width, height):
63 | # determine new height and width
64 | scale_height = self.__height / height
65 | scale_width = self.__width / width
66 |
67 | if self.__keep_aspect_ratio:
68 | if self.__resize_method == "lower_bound":
69 | # scale such that output size is lower bound
70 | if scale_width > scale_height:
71 | # fit width
72 | scale_height = scale_width
73 | else:
74 | # fit height
75 | scale_width = scale_height
76 | elif self.__resize_method == "upper_bound":
77 | # scale such that output size is upper bound
78 | if scale_width < scale_height:
79 | # fit width
80 | scale_height = scale_width
81 | else:
82 | # fit height
83 | scale_width = scale_height
84 | elif self.__resize_method == "minimal":
85 |                 # scale as little as possible
86 | if abs(1 - scale_width) < abs(1 - scale_height):
87 | # fit width
88 | scale_height = scale_width
89 | else:
90 | # fit height
91 | scale_width = scale_height
92 | else:
93 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
94 |
95 | if self.__resize_method == "lower_bound":
96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98 | elif self.__resize_method == "upper_bound":
99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101 | elif self.__resize_method == "minimal":
102 | new_height = self.constrain_to_multiple_of(scale_height * height)
103 | new_width = self.constrain_to_multiple_of(scale_width * width)
104 | else:
105 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
106 |
107 | return (new_width, new_height)
108 |
109 | def __call__(self, sample):
110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111 |
112 | # resize sample
113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114 |
115 | if self.__resize_target:
116 | if "depth" in sample:
117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118 |
119 | if "mask" in sample:
120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121 |
122 | return sample
123 |
124 |
125 | class NormalizeImage(object):
126 |     """Normalize image by given mean and std.
127 | """
128 |
129 | def __init__(self, mean, std):
130 | self.__mean = mean
131 | self.__std = std
132 |
133 | def __call__(self, sample):
134 | sample["image"] = (sample["image"] - self.__mean) / self.__std
135 |
136 | return sample
137 |
138 |
139 | class PrepareForNet(object):
140 | """Prepare sample for usage as network input.
141 | """
142 |
143 | def __init__(self):
144 | pass
145 |
146 | def __call__(self, sample):
147 | image = np.transpose(sample["image"], (2, 0, 1))
148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149 |
150 | if "depth" in sample:
151 | depth = sample["depth"].astype(np.float32)
152 | sample["depth"] = np.ascontiguousarray(depth)
153 |
154 | if "mask" in sample:
155 | sample["mask"] = sample["mask"].astype(np.float32)
156 | sample["mask"] = np.ascontiguousarray(sample["mask"])
157 |
158 | return sample
--------------------------------------------------------------------------------
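With the settings DepthAnythingV2.image2tensor passes above (width=height=518, keep_aspect_ratio=True, ensure_multiple_of=14, 'lower_bound'), the shorter side of the image snaps to 518 and both sides become multiples of 14. A worked sketch for a 720 x 1280 frame (a random float image stands in for the real BGR-to-RGB input):

    import numpy as np
    from depth_anything_v2.depth_anything_v2.util.transform import Resize

    resize = Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
                    ensure_multiple_of=14, resize_method='lower_bound')
    print(resize.get_size(1280, 720))        # (924, 518): (new_width, new_height)
    sample = {'image': np.random.rand(720, 1280, 3).astype(np.float32)}
    print(resize(sample)['image'].shape)     # (518, 924, 3)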
/depth_anything_v2/metric_depth_estimation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import matplotlib
4 | import numpy as np
5 | import os
6 | import torch
7 |
8 | import sys
9 | sys.path.append('.')
10 | from .depth_anything_v2.dpt import DepthAnythingV2
11 |
12 | def estimate(depth_anything, input_size, imagepath, output_dir, is_vis=False):
13 | cmap = matplotlib.colormaps.get_cmap('Spectral')
14 |
15 | imagename = imagepath.split('/')[-1].split('.')[0]
16 |
17 | raw_image = cv2.imread(imagepath)
18 |
19 | metric_depth = depth_anything.infer_image(raw_image, input_size)
20 |
21 | output_path = os.path.join(output_dir, imagename + '_depth_meter.npy')
22 | np.save(output_path, metric_depth)
23 |
24 | relative_depth = (metric_depth - metric_depth.min()) / (metric_depth.max() - metric_depth.min())
25 |
26 | vis_depth = None
27 | if is_vis:
28 | vis_depth = relative_depth * 255.0
29 | vis_depth = vis_depth.astype(np.uint8)
30 | vis_depth = (cmap(vis_depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
31 |
32 | output_path = os.path.join(output_dir, imagename + '_depth.png')
33 | cv2.imwrite(output_path, vis_depth)
34 |
35 | return metric_depth, relative_depth, vis_depth
36 |
37 |
38 | def depth_estimation(imagepath,
39 | input_size=518,
40 | encoder='vitl',
41 | load_from='./depth_anything_v2/ckpts/depth_anything_v2_metric_hypersim_vitl.pth',
42 | max_depth=20,
43 | output_dir='./tmp',
44 | is_vis=False):
45 |
46 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
47 |
48 | model_configs = {
49 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
50 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
51 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
52 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
53 | }
54 |
55 | depth_anything = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
56 | depth_anything.load_state_dict(torch.load(load_from, map_location='cpu'))
57 | depth_anything = depth_anything.to(DEVICE).eval()
58 |
59 | metric_depth, relative_depth, vis_depth = estimate(depth_anything, input_size, imagepath, output_dir, is_vis)
60 |
61 | return metric_depth, relative_depth, vis_depth
62 |
63 |
64 | if __name__ == '__main__':
65 |
66 | parser = argparse.ArgumentParser()
67 |
68 | parser.add_argument('--input-size', type=int, default=518)
69 |
70 | parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
71 | parser.add_argument('--load-from', type=str, default='../Depth-Anything-V2/checkpoints/depth_anything_v2_metric_hypersim_vitl.pth')
72 | parser.add_argument('--max-depth', type=float, default=20)
73 |
74 | # set the gpu
75 | parser.add_argument('--gpu', type=int, default=0, help='gpu id')
76 | parser.add_argument('--output_dir', type=str, default='./tmp', help='output dir')
77 | parser.add_argument('--imagepath', type=str, default='tmp.jpg', help='imagepath')
78 |
79 | args = parser.parse_args()
80 |
81 | output_dir = args.output_dir
82 | os.makedirs(output_dir, exist_ok=True)
83 |
84 | # set the gpu
85 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
86 |
87 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
88 |
89 | model_configs = {
90 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
91 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
92 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
93 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
94 | }
95 |
96 | depth_anything = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})
97 | depth_anything.load_state_dict(torch.load(args.load_from, map_location='cpu'))
98 | depth_anything = depth_anything.to(DEVICE).eval()
99 |
100 | estimate(depth_anything, args.input_size, args.imagepath, args.output_dir)
--------------------------------------------------------------------------------
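The depth_estimation helper above ties the pieces together: it loads a metric checkpoint, runs infer_image, writes <name>_depth_meter.npy and, when is_vis is set, a Spectral-colormapped PNG. A hedged usage sketch from the repository root; the checkpoint is assumed to already sit at the default load_from path, and examples/balloons.png is one of the bundled images:

    import os
    from depth_anything_v2.metric_depth_estimation import depth_estimation

    os.makedirs('./tmp', exist_ok=True)   # estimate() writes into output_dir but does not create it
    metric, relative, vis = depth_estimation('examples/balloons.png', is_vis=True, output_dir='./tmp')
    print(metric.shape, float(metric.max()))   # per-pixel metric depth, at most max_depth (20 by default)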
/examples/DollyOut.txt:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
2 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 6.666666666666666574e-02
3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 1.333333333333333315e-01
4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 2.000000000000000111e-01
5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 2.666666666666666630e-01
6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 3.333333333333333148e-01
7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 4.000000000000000222e-01
8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 4.666666666666666741e-01
9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5.333333333333333259e-01
10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5.999999999999999778e-01
11 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 6.666666666666666297e-01
12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 7.333333333333332815e-01
13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 8.000000000000000444e-01
14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 8.666666666666666963e-01
15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 9.333333333333333481e-01
16 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 1.000000000000000000e+00
17 |
--------------------------------------------------------------------------------
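Each line of these example trajectories is one camera pose: 12 values forming a flattened 3 x 4 extrinsic, which reads naturally as [R | t]. In DollyOut.txt the rotation stays identity while the z-component of the translation ramps from 0 to 1 over the 16 frames; Still.txt below repeats the identity pose, and TiltDown.txt applies a progressive rotation about the x-axis. A minimal parsing sketch (numpy assumed):

    import numpy as np

    poses = np.loadtxt('examples/DollyOut.txt').reshape(-1, 3, 4)   # (16, 3, 4)
    R, t = poses[:, :, :3], poses[:, :, 3]
    print(R[0])    # identity rotation
    print(t[-1])   # [0. 0. 1.] -- translation grows linearly along z across the frames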
/examples/Still.txt:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
2 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
11 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
16 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
17 |
--------------------------------------------------------------------------------
/examples/TiltDown.txt:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 -0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00
2 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.999323080022165522e-01 -1.163526593424524941e-02 -2.327371078232923999e-04 0.000000000000000000e+00 1.163526593424524941e-02 9.999323080022165522e-01 4.103782897194226538e-05
3 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.997292411732788819e-01 -2.326895663969883593e-02 -4.654742156465847998e-04 0.000000000000000000e+00 2.326895663969883593e-02 9.997292411732788819e-01 8.207565794388453075e-05
4 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.993908270505680314e-01 -3.489949580125043666e-02 -6.982113234698773081e-04 0.000000000000000000e+00 3.489949580125043666e-02 9.993908270505680314e-01 1.231134869158268029e-04
5 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.989171113137995661e-01 -4.652531272860008921e-02 -9.309484312931695996e-04 0.000000000000000000e+00 4.652531272860008921e-02 9.989171113137995661e-01 1.641513158877690615e-04
6 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.983081583082533683e-01 -5.814482827546655491e-02 -1.163685539116461999e-03 0.000000000000000000e+00 5.814482827546655491e-02 9.983081583082533683e-01 2.051891448597113201e-04
7 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.975640503856371133e-01 -6.975647194491900460e-02 -1.396422646939754616e-03 0.000000000000000000e+00 6.975647194491900460e-02 9.975640503856371133e-01 2.462269738316536058e-04
8 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.966848882862320291e-01 -8.135867170659374925e-02 -1.629159754763047016e-03 0.000000000000000000e+00 8.135867170659374925e-02 9.966848882862320291e-01 2.872648028035958644e-04
9 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.956707905510119305e-01 -9.294986198764870755e-02 -1.861896862586339199e-03 0.000000000000000000e+00 9.294986198764870755e-02 9.956707905510119305e-01 3.283026317755381230e-04
10 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.945218953792129835e-01 -1.045284631635702705e-01 -2.094633970409631816e-03 0.000000000000000000e+00 1.045284631635702705e-01 9.945218953792129835e-01 3.693404607474803816e-04
11 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.932383578896332166e-01 -1.160929128616613598e-01 -2.327371078232923999e-03 0.000000000000000000e+00 1.160929128616613598e-01 9.932383578896332166e-01 4.103782897194226402e-04
12 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.918203518526672591e-01 -1.276416454408651757e-01 -2.560108186056216182e-03 0.000000000000000000e+00 1.276416454408651757e-01 9.918203518526672591e-01 4.514161186913648988e-04
13 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.902680692435954501e-01 -1.391730973879709010e-01 -2.792845293879509232e-03 0.000000000000000000e+00 1.391730973879709010e-01 9.902680692435954501e-01 4.924539476633072116e-04
14 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.885817202165927409e-01 -1.506857075292882542e-01 -3.025582401702801415e-03 0.000000000000000000e+00 1.506857075292882542e-01 9.885817202165927409e-01 5.334917766352494702e-04
15 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.867615330762774528e-01 -1.621779172420052539e-01 -3.258319509526094032e-03 0.000000000000000000e+00 1.621779172420052539e-01 9.867615330762774528e-01 5.745296056071917288e-04
16 | 1.000000000000000222e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 9.848077542468020029e-01 -1.736481706652007739e-01 -3.491056617349386215e-03 0.000000000000000000e+00 1.736481706652007739e-01 9.848077542468020029e-01 6.155674345791339874e-04
17 |
--------------------------------------------------------------------------------
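Each trajectory .txt under examples/ stores one row of 12 whitespace-separated floats per frame; the rows read as row-major flattened 3x4 camera pose matrices [R | t], and TiltDown.txt above is a progressive rotation about the x-axis. A minimal loading sketch under that assumption (the helper name is illustrative, not part of the repo):

import numpy as np

def load_camera_trajectory(path):
    # Each row holds 12 floats; treat them as a row-major flattened
    # 3x4 pose matrix [R | t], one row per frame.
    rows = np.loadtxt(path)
    return rows.reshape(-1, 3, 4)

poses = load_camera_trajectory("examples/TiltDown.txt")
print(poses.shape)          # (16, 3, 4) for the 16-row file above
print(poses[-1, :3, :3])    # rotation block of the final frame

--------------------------------------------------------------------------------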
/examples/backview.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/backview.jpeg
--------------------------------------------------------------------------------
/examples/backview.json:
--------------------------------------------------------------------------------
1 | [[[548, 423, 0.5], [562, 412, 0.5], [555, 428, 0.5], [545, 413, 0.5], [553, 433, 0.5]], [[247, 242, 0.5], [264, 233, 0.5], [287, 229, 0.5]], [[411, 185, 0.5]], [[65, 234, 0.5], [77, 229, 0.5], [81, 222, 0.5], [69, 228, 0.5], [64, 229, 0.5], [54, 229, 0.5]], [[148, 107, 0.5], [160, 102, 0.5], [164, 95, 0.5], [152, 101, 0.5], [147, 102, 0.5], [137, 102, 0.5]], [[103, 142, 0.5], [111, 138, 0.5], [114, 133, 0.5], [105, 137, 0.5], [102, 138, 0.5], [95, 138, 0.5]], [[235, 23, 0.5], [243, 19, 0.5], [246, 14, 0.5], [237, 18, 0.5], [234, 19, 0.5], [227, 19, 0.5]], [[745, 166, 0.5], [746, 159, 0.5], [743, 172, 0.5], [748, 179, 0.5], [752, 185, 0.5]], [[729, 70, 0.5], [729, 65, 0.5], [727, 74, 0.5], [731, 79, 0.5], [733, 83, 0.5]], [[366, 287, 0.5], [356, 281, 0.5], [368, 288, 0.5]], [[453, 288, 0.5], [467, 280, 0.5], [450, 290, 0.5]], [[440, 423, 0.5], [449, 420, 0.5], [438, 424, 0.5]], [[368, 407, 0.5], [365, 405, 0.5], [364, 402, 0.5], [372, 408, 0.5]], [[63, 441, 0.5], [71, 441, 0.5], [54, 438, 0.5]], [[140, 444, 0.5], [145, 444, 0.5], [133, 441, 0.5]], [[196, 457, 0.5], [199, 457, 0.5], [191, 454, 0.5]], [[744, 442, 0.5], [746, 442, 0.5], [740, 439, 0.5]], [[463, 70, 0.5]], [[298, 149, 0.5]], [[205, 202, 0.5]], [[83, 337, 0.5]], [[564, 362, 0.5]], [[408, 359, 0.5]], [[404, 232, 0.5]], [[611, 285, 0.5]], [[722, 308, 0.5]], [[644, 327, 0.5]], [[612, 90, 0.5]], [[34, 26, 0.5]], [[3, 102, 0.5]], [[10, 169, 0.5]], [[198, 395, 0.5], [254, 404, 0.5]], [[18, 38, 0.5]], [[424, 486, 0.5]], [[99, 310, 0.5], [23, 316, 0.5]], [[306, 289, 0.5], [234, 289, 0.5]], [[608, 273, 0.5], [561, 269, 0.5]], [[199, 398, 0.5], [273, 377, 0.5]], [[213, 404, 0.5], [287, 383, 0.5]]]
--------------------------------------------------------------------------------
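backview.json above (and the other example .json files below) stores user-drawn control tracks as a list of tracks, each track a list of [x, y, d] points in pixel coordinates; the third value is 0.5 throughout these examples and appears to serve as a default per-point depth. A minimal inspection sketch, assuming that layout:

import json
import numpy as np

with open("examples/backview.json") as f:
    tracks = json.load(f)              # list of tracks; each track is a list of [x, y, d] points

print(len(tracks))                     # number of control tracks in this example
first = np.array(tracks[0])            # shape (num_points, 3)
print(first[:, :2].astype(int))        # pixel coordinates of the first track
print(np.unique(first[:, 2]))          # third column: 0.5 in the bundled examples

--------------------------------------------------------------------------------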
/examples/balloon.json:
--------------------------------------------------------------------------------
1 | [[[560, 147, 0.5], [352, 151, 0.5]], [[359, 283, 0.5], [490, 288, 0.5]], [[124, 99, 0.5]], [[95, 397, 0.5]], [[653, 347, 0.5]], [[642, 465, 0.5]]]
--------------------------------------------------------------------------------
/examples/balloon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/balloon.png
--------------------------------------------------------------------------------
/examples/balloons.json:
--------------------------------------------------------------------------------
1 | [[[520, 90, 0.5], [338, 84, 0.5]], [[557, 197, 0.5], [666, 197, 0.5]], [[334, 246, 0.5], [402, 307, 0.5]], [[604, 283, 0.5], [629, 272, 0.5], [619, 270, 0.5], [600, 270, 0.5], [584, 270, 0.5], [574, 268, 0.5], [574, 252, 0.5], [598, 248, 0.5], [610, 248, 0.5]], [[524, 319, 0.5]]]
--------------------------------------------------------------------------------
/examples/balloons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/balloons.png
--------------------------------------------------------------------------------
/examples/bamboo.json:
--------------------------------------------------------------------------------
1 | [[[179, 146, 0.5], [209, 146, 0.5], [157, 145, 0.5], [204, 146, 0.5], [162, 146, 0.5]], [[158, 306, 0.5], [185, 306, 0.5], [138, 305, 0.5], [180, 306, 0.5], [142, 306, 0.5]], [[336, 236, 0.5], [336, 215, 0.5], [337, 267, 0.5], [335, 213, 0.5], [335, 253, 0.5]], [[305, 361, 0.5], [284, 376, 0.5], [327, 351, 0.5], [289, 371, 0.5], [319, 356, 0.5]]]
--------------------------------------------------------------------------------
/examples/bamboo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/examples/bamboo.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import numpy as np
5 | import gradio as gr
6 | from PIL import Image
7 | from pathlib import Path
8 | from datetime import datetime
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torchvision
13 | import torchvision.transforms as transforms
14 |
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | from einops import repeat
17 | from omegaconf import OmegaConf
18 | from decord import VideoReader
19 | from transformers import CLIPVisionModelWithProjection
20 |
21 | from src.models.motion_encoder import MotionEncoder
22 | from src.models.fusion_module import FusionModule
23 | from src.models.unet_2d_condition import UNet2DConditionModel
24 | from src.models.unet_3d import UNet3DConditionModel
25 | from src.pipelines.pipeline_motion2vid_merge_infer import Motion2VideoPipeline
26 | from src.utils.util import save_videos_grid
27 | from src.utils.utils import interpolate_trajectory, interpolate_trajectory_3d
28 | from src.utils.visualizer import Visualizer
29 |
30 | import safetensors
31 |
32 |
33 | def visualize_tracks(background_image_path, splited_tracks, video_length, width, height, save_path='./tmp/hint.mp4'):
34 |
35 | background_image = Image.open(background_image_path).convert('RGBA')
36 | background_image = background_image.resize((width, height))
37 | w, h = background_image.size
38 | transparent_background = np.array(background_image)
39 | transparent_background[:, :, -1] = 128
40 | transparent_background = Image.fromarray(transparent_background)
41 |
42 | # Create a transparent layer with the same size as the background image
43 | transparent_layer = np.zeros((h, w, 4))
44 | for splited_track in splited_tracks:
45 | if len(splited_track) > 1:
46 | splited_track = interpolate_trajectory(splited_track, video_length)
47 | for i in range(len(splited_track)-1):
48 | start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
49 | end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
50 | vx = end_point[0] - start_point[0]
51 | vy = end_point[1] - start_point[1]
52 | arrow_length = np.sqrt(vx**2 + vy**2) + 1e-6
53 | if i == len(splited_track)-2:
54 | cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
55 | else:
56 | cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
57 | else:
58 | cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1)
59 |
60 | transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
61 | trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
62 |
63 | save_dir = os.path.dirname(save_path)
64 | save_name = os.path.basename(save_path).split('.')[0]
65 | vis_track = Visualizer(save_dir=save_dir,
66 | pad_value=0,
67 | linewidth=10,
68 | mode='optical_flow',
69 | tracks_leave_trace=-1)
70 | video = np.repeat(np.array(background_image.convert('RGB'))[None, None, ...], splited_tracks.shape[1], 1)
71 | video = video.transpose(0, 1, 4, 2, 3)
72 | tracks = splited_tracks[:, :, :2][None].transpose(0, 2, 1, 3)
73 | depths = splited_tracks[:, :, 2][None].transpose(0, 2, 1)
74 | video_tracks = vis_track.visualize(torch.from_numpy(video),
75 | torch.from_numpy(tracks),
76 | depths=torch.from_numpy(depths),
77 | filename=save_name,
78 | is_depth_norm=True,
79 | query_frame=0)
80 |
81 | return trajectory_map, transparent_layer, video_tracks
82 |
83 |
84 | class Net(nn.Module):
85 | def __init__(
86 | self,
87 | reference_unet: UNet2DConditionModel,
88 | denoising_unet: UNet3DConditionModel,
89 | obj_encoder: MotionEncoder,
90 | cam_encoder: MotionEncoder,
91 | fusion_module: FusionModule,
92 | ):
93 | super().__init__()
94 | self.reference_unet = reference_unet
95 | self.denoising_unet = denoising_unet
96 | self.obj_encoder = obj_encoder
97 | self.cam_encoder = cam_encoder
98 | self.fusion_module = fusion_module
99 |
100 |
101 | class Model():
102 | def __init__(self, config_path):
103 | self.config_path = config_path
104 | self.load_config()
105 | self.init_model()
106 | self.init_savedir()
107 |
108 | self.output_dir = './tmp'
109 | os.makedirs(self.output_dir, exist_ok=True)
110 |
111 | def load_config(self):
112 | self.config = OmegaConf.load(self.config_path)
113 |
114 | def init_model(self):
115 |
116 | if self.config.weight_dtype == "fp16":
117 | weight_dtype = torch.float16
118 | else:
119 | weight_dtype = torch.float32
120 |
121 | vae = AutoencoderKL.from_pretrained(
122 | self.config.vae_model_path,
123 | ).to("cuda", dtype=weight_dtype)
124 |
125 | reference_unet = UNet2DConditionModel.from_pretrained(
126 | self.config.base_model_path,
127 | subfolder="unet",
128 | ).to(dtype=weight_dtype, device="cuda")
129 |
130 | inference_config_path = self.config.inference_config
131 | infer_config = OmegaConf.load(inference_config_path)
132 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
133 | self.config.base_model_path,
134 | self.config.mm_path,
135 | subfolder="unet",
136 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
137 | ).to(dtype=weight_dtype, device="cuda")
138 |
139 | obj_encoder = MotionEncoder(
140 | conditioning_embedding_channels=320,
141 |             conditioning_channels=3,
142 | block_out_channels=(16, 32, 96, 256)
143 | ).to(device="cuda")
144 |
145 | cam_encoder = MotionEncoder(
146 | conditioning_embedding_channels=320,
147 |             conditioning_channels=1,
148 | block_out_channels=(16, 32, 96, 256)
149 | ).to(device="cuda")
150 |
151 | fusion_module = FusionModule(
152 | fusion_type=self.config.fusion_type,
153 | ).to(device="cuda")
154 |
155 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
156 | self.config.image_encoder_path
157 | ).to(dtype=weight_dtype, device="cuda")
158 |
159 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
160 | scheduler = DDIMScheduler(**sched_kwargs)
161 |
162 | self.generator = torch.manual_seed(self.config.seed)
163 |
164 | denoising_unet.load_state_dict(
165 | torch.load(self.config.denoising_unet_path, map_location="cpu"),
166 | strict=False,
167 | )
168 | reference_unet.load_state_dict(
169 | torch.load(self.config.reference_unet_path, map_location="cpu"),
170 | )
171 | obj_encoder.load_state_dict(
172 | torch.load(self.config.obj_encoder_path, map_location="cpu"),
173 | )
174 | cam_encoder.load_state_dict(
175 | torch.load(self.config.cam_encoder_path, map_location="cpu"),
176 | )
177 |
178 | pipe = Motion2VideoPipeline(
179 | vae=vae,
180 | image_encoder=image_enc,
181 | reference_unet=reference_unet,
182 | denoising_unet=denoising_unet,
183 | obj_encoder=obj_encoder,
184 | cam_encoder=cam_encoder,
185 | fusion_module=fusion_module,
186 | scheduler=scheduler,
187 | )
188 | self.pipe = pipe.to("cuda", dtype=weight_dtype)
189 |
190 | def init_savedir(self):
191 | date_str = datetime.now().strftime("%Y%m%d")
192 | time_str = datetime.now().strftime("%H%M")
193 | save_dir_name = f"{time_str}"
194 |
195 | if 'save_dir' not in self.config:
196 | self.save_dir = Path(f"output/{date_str}/{save_dir_name}")
197 | else:
198 | self.save_dir = Path(self.config.save_dir)
199 | self.save_dir.mkdir(exist_ok=True, parents=True)
200 |
201 | def get_control_objs(self, track_path, ori_sample_size, sample_size, video_length=16, imagepath=None, vo_path=None, cam_type='generated'):
202 | imagename = imagepath.split('/')[-1].split('.')[0]
203 |
204 | vis = Visualizer(save_dir='./tmp',
205 | grayscale=False,
206 | mode='rainbow_all',
207 | pad_value=0,
208 | linewidth=1,
209 | tracks_leave_trace=1)
210 |
211 | with open(track_path, 'r') as f:
212 | track_points_user = json.load(f)
213 | track_points = []
214 | for splited_track in track_points_user:
215 | splited_track = np.array(splited_track)
216 | if splited_track.shape[-1] == 3:
217 | splited_track = interpolate_trajectory_3d(splited_track, video_length)
218 | else:
219 | splited_track = interpolate_trajectory(splited_track, video_length)
220 | splited_track = np.array(splited_track)
221 | track_points.append(splited_track)
222 | track_points = np.array(track_points)
223 |
224 | if track_points.shape[0] == 0 and not self.config.is_adapted:
225 |             return None, None, None
226 |
227 | query_path = os.path.join("./tmp/tmp_query.npy")
228 | if track_points.shape[0] != 0:
229 | query_points = track_points.transpose(1, 0, 2)
230 | np.save(query_path, query_points)
231 | tracks = query_points
232 |
233 | # get adapted tracks
234 | if track_points.shape[0] == 0:
235 | query_path = 'none'
236 |
237 | projected_3d_points, points_depth, cam_path = cam_adaptation(imagepath, vo_path, query_path, cam_type, video_length, self.output_dir)
238 |
239 | if self.config.is_adapted or track_points.shape[0] == 0:
240 | tracks = projected_3d_points[:, :, :2]
241 | tracks = np.concatenate([tracks, np.ones_like(tracks[:, :, 0])[..., None] * 0.5], -1)
242 |
243 | # get depth
244 | if self.config.is_depth:
245 | tracks = np.concatenate([tracks.astype(float)[..., :2], points_depth], -1)
246 |
247 | T, _, _ = tracks.shape
248 |
249 | tracks[:, :, 0] /= ori_sample_size[1]
250 | tracks[:, :, 1] /= ori_sample_size[0]
251 |
252 | tracks[:, :, 0] *= sample_size[1]
253 | tracks[:, :, 1] *= sample_size[0]
254 |
255 | tracks[..., 0] = np.clip(tracks[:, :, 0], 0, sample_size[1] - 1)
256 | tracks[..., 1] = np.clip(tracks[:, :, 1], 0, sample_size[0] - 1)
257 |
258 | tracks = tracks[np.newaxis, :, :, :]
259 | tracks = torch.tensor(tracks)
260 |
261 | pred_tracks = tracks[:, :, :, :3]
262 |
263 | # vis tracks
264 | splited_tracks = []
265 | for i in range(pred_tracks.shape[2]):
266 | splited_tracks.append(pred_tracks[0, :, i, :3])
267 | splited_tracks = np.array(splited_tracks)
268 |
269 | video = torch.zeros(T, 3, sample_size[0], sample_size[1])[None].float()
270 |
271 | vis_objs = vis.visualize(video=video,
272 | tracks=pred_tracks[..., :2],
273 | filename=Path(track_path).stem,
274 | depths=pred_tracks[..., 2],
275 | circle_scale=self.config.circle_scale,
276 | is_blur=False,
277 | is_depth_norm=True,
278 | save_video=False
279 | )
280 | tracks = tracks.squeeze().numpy()
281 | vis_objs = vis_objs.squeeze().numpy()
282 | guide_value_objs = vis_objs
283 |
284 | return guide_value_objs, splited_tracks, cam_path
285 |
286 | def get_control_cams(self, cam_path, sample_size):
287 | vr = VideoReader(cam_path)
288 | cams = vr.get_batch(list(range(0, len(vr)))).asnumpy()[:, :, :, ::-1]
289 | resized_cams = []
290 | resized_rgb_cams = []
291 | for i in range(cams.shape[0]):
292 | if i == 0:
293 | cam_height, cam_width = cams[i].shape[:2]
294 | frame = np.array(Image.fromarray(cams[i]).convert('L').resize([sample_size[1], sample_size[0]]))
295 | resized_cams.append(frame)
296 | frame_rgb = np.array(Image.fromarray(cams[i]).convert('RGB').resize([sample_size[1], sample_size[0]]))
297 | resized_rgb_cams.append(frame_rgb)
298 | guide_value_cams = np.array(resized_cams)[..., None]
299 | del vr
300 | guide_value_cams = guide_value_cams.transpose(0, 3, 1, 2)
301 | return guide_value_cams
302 |
303 |
304 | def run(self, ref_image_path, cam_path, track_path, seed=None):
305 |         if seed is None:
306 | seed = self.config.seed
307 | self.generator = torch.manual_seed(seed)
308 | video_length = self.config.sample_n_frames
309 |
310 | if os.path.exists(cam_path) and cam_path.endswith('.txt'):
311 | vo_path = cam_path
312 | cam_type = 'user-provided'
313 | else:
314 | vo_path = cam_path
315 | cam_type = 'generated'
316 |
317 | ref_name = Path(ref_image_path).stem
318 |
319 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
320 | ref_image_path = f'{self.save_dir}/{ref_name}_ref.png'
321 | w, h = ref_image_pil.size
322 | ori_width, ori_height = w // 2 * 2, h // 2 * 2
323 | ref_image_pil.resize((ori_width, ori_height)).save(ref_image_path)
324 |
325 | image_transform = transforms.Compose(
326 | [transforms.Resize((self.config.H, self.config.W)), transforms.ToTensor()]
327 | )
328 |
329 | ref_image_tensor = image_transform(ref_image_pil) # (c, h, w)
330 | ref_image_pil = Image.fromarray((ref_image_tensor * 255).permute(1, 2, 0).numpy().astype(np.uint8)).convert("RGB")
331 |
332 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
333 | ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video_length)
334 |
335 | control_objs, splited_tracks, cam_path = self.get_control_objs(
336 | track_path,
337 | ori_sample_size=(ori_height, ori_width),
338 | sample_size=(self.config.H, self.config.W),
339 | video_length=video_length,
340 | imagepath=ref_image_path,
341 | vo_path=vo_path.replace('.mp4', '.txt', 1),
342 | cam_type=cam_type,
343 | )
344 |
345 | control_cams = self.get_control_cams(cam_path, (self.config.H, self.config.W))
346 |
347 | ref_image_path = f'{self.save_dir}/{ref_name}_ref.png'
348 | ref_image_pil.save(ref_image_path)
349 |
350 | if control_objs is None or self.config.cam_only:
351 | control_objs = np.zeros_like(control_cams).repeat(3, 1)
352 | if self.config.obj_only:
353 | control_cams = np.zeros_like(control_cams)
354 |
355 | control_objs = torch.from_numpy(control_objs).unsqueeze(0).float()
356 | control_objs = control_objs.transpose(
357 | 1, 2
358 | ) # (bs, c, f, H, W)
359 | control_cams = torch.from_numpy(control_cams).unsqueeze(0).float()
360 | control_cams = control_cams.transpose(
361 | 1, 2
362 | ) # (bs, c, f, H, W)
363 |
364 | video = self.pipe(
365 | reference_image=ref_image_pil,
366 | control_objs=control_objs,
367 | control_cams=control_cams,
368 | width=self.config.W,
369 | height=self.config.H,
370 | video_length=video_length,
371 | num_inference_steps=self.config.steps,
372 | guidance_scale=self.config.guidance_scale,
373 | generator=self.generator,
374 | is_obj=self.config.is_obj,
375 | is_cam=self.config.is_cam
376 | ).videos
377 |
378 | cam_pattern_postfix = vo_path.split('/')[-1].split('.')[0]
379 | video_path = f"{self.save_dir}/{ref_name}_{cam_pattern_postfix}_gen.mp4"
380 | hint_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_hint.mp4'
381 | vis_obj_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_obj.mp4'
382 | vis_cam_path = f'{self.save_dir}/{ref_name}_{cam_pattern_postfix}_cam.mp4'
383 |
384 | _, _, _ = visualize_tracks(ref_image_path, splited_tracks, video_length, self.config.W, self.config.H, save_path=hint_path)
385 |
386 | torchvision.io.write_video(vis_obj_path, control_objs[0].permute(1, 2, 3, 0).numpy(), fps=8, video_codec='h264', options={'crf': '10'})
387 | torchvision.io.write_video(vis_cam_path, control_cams[0].permute(1, 2, 3, 0).numpy().repeat(3, 3), fps=8, video_codec='h264', options={'crf': '10'})
388 |
389 | save_videos_grid(
390 | video,
391 | video_path,
392 | n_rows=1,
393 | fps=8,
394 | )
395 |
396 | return video_path, hint_path, vis_obj_path, vis_cam_path
397 |
398 |
399 | if __name__ == '__main__':
400 |
401 | config_path = 'configs/eval.yaml'
402 | config = OmegaConf.load(config_path)
403 |     model = Model(config_path)
404 |
405 | path_sets = []
406 | for test_case in config["test_cases"]:
407 | ref_image_path = list(test_case.keys())[0]
408 | track_path = test_case[ref_image_path][0]
409 | cam_path = test_case[ref_image_path][1]
410 | seed = test_case[ref_image_path][2]
411 | path_set = {'ref_image_path': ref_image_path,
412 | 'track_path': track_path,
413 | 'cam_path': cam_path,
414 | 'seed': seed,
415 | }
416 | path_sets.append(path_set)
417 |
418 | for path_set in path_sets:
419 | ref_image_path = path_set['ref_image_path']
420 | track_path = path_set['track_path']
421 | cam_path = path_set['cam_path']
422 | seed = path_set['seed']
423 |
424 |         model.run(ref_image_path, cam_path, track_path, seed)
--------------------------------------------------------------------------------
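The __main__ block above reads configs/eval.yaml and expects a test_cases list in which each entry maps a reference-image path to [track_path, cam_path, seed]. A minimal sketch of driving the same pipeline directly with the bundled example assets (the asset pairing and seed are illustrative; checkpoint paths, resolution, and the remaining fields of eval.yaml still have to be configured as the repo expects):

from inference import Model

# Build the wrapper from the same config file the __main__ block uses.
model = Model("configs/eval.yaml")

# One test case assembled from assets under examples/:
# reference image, JSON point tracks, camera trajectory, and a seed.
video_path, hint_path, obj_vis_path, cam_vis_path = model.run(
    ref_image_path="examples/backview.jpeg",
    cam_path="examples/TiltDown.txt",
    track_path="examples/backview.json",
    seed=42,
)
print(video_path)   # generated clip written under the run's save_dir

--------------------------------------------------------------------------------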
/mesh/world_envelope.mtl:
--------------------------------------------------------------------------------
1 | # Blender 4.2.0 MTL File: 'None'
2 | # www.blender.org
3 |
4 | newmtl Material
5 | Ns 250.000000
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.800000 0.800000 0.800000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.450000
11 | d 1.000000
12 | illum 2
13 | map_Kd world_envelope.png
14 |
--------------------------------------------------------------------------------
/mesh/world_envelope.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/mesh/world_envelope.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | av==11.0.0
3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4 | decord==0.6.0
5 | diffusers==0.24.0
6 | einops==0.4.1
7 | gradio==3.41.2
8 | gradio_client==0.5.0
9 | imageio==2.33.0
10 | imageio-ffmpeg==0.4.9
11 | numpy==1.23.5
12 | omegaconf==2.2.3
13 | onnxruntime-gpu==1.16.3
14 | open3d==0.19.0
15 | open-clip-torch==2.20.0
16 | opencv-contrib-python==4.8.1.78
17 | opencv-python==4.8.1.78
18 | Pillow==9.5.0
19 | scikit-image==0.21.0
20 | scikit-learn==1.3.2
21 | scipy==1.11.4
22 | torch==2.0.1
23 | torchdiffeq==0.2.3
24 | torchmetrics==1.2.1
25 | torchsde==0.2.5
26 | torchvision==0.15.2
27 | tqdm==4.66.1
28 | transformers==4.30.2
29 | mlflow==2.9.2
30 | xformers==0.0.22
31 | controlnet-aux==0.0.7
32 | git+https://github.com/facebookresearch/pytorch3d.git
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/__init__.py
--------------------------------------------------------------------------------
/src/models/fusion_module.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from einops import rearrange
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.nn.init as init
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
10 |
11 | from src.models.motion_module import zero_module
12 | from src.models.resnet import InflatedConv3d
13 |
14 |
15 | class FusionModule(ModelMixin):
16 | def __init__(
17 | self,
18 | in_channels: int = 640,
19 | out_channels: int = 320,
20 | fusion_type: str = 'conv',
21 | fusion_with_norm: bool = True,
22 | ):
23 | super().__init__()
24 |
25 | self.fusion_type = fusion_type
26 | self.fusion_with_norm = fusion_with_norm
27 | self.norm1 = nn.LayerNorm(out_channels)
28 | self.norm2 = nn.LayerNorm(out_channels)
29 |
30 | if self.fusion_type == 'sum':
31 | self.fusion = None
32 | elif self.fusion_type == 'max':
33 | self.fusion = None
34 | elif self.fusion_type == 'conv':
35 | self.fusion = InflatedConv3d(
36 | in_channels, out_channels, kernel_size=3, padding=1
37 | )
38 |
39 | def forward(self, feat1, feat2):
40 | b, c, t, h, w = feat1.shape
41 |
42 | if self.fusion_type == 'sum':
43 | return feat1 + feat2
44 | elif self.fusion_type == 'max':
45 | return torch.max(feat1, feat2)
46 | elif self.fusion_type == 'conv':
47 | b, c, t, h, w = feat1.shape
48 | if self.fusion_with_norm:
49 | feat1 = rearrange(feat1, "b c t h w -> (b t) (h w) c")
50 | feat2 = rearrange(feat2, "b c t h w -> (b t) (h w) c")
51 | feat1 = self.norm1(feat1)
52 | feat2 = self.norm2(feat2)
53 | feat1 = feat1.view(b, t, h, w, c)
54 | feat2 = feat2.view(b, t, h, w, c)
55 | feat1 = feat1.permute(0, 4, 1, 2, 3).contiguous()
56 | feat2 = feat2.permute(0, 4, 1, 2, 3).contiguous()
57 | feat = torch.concat((feat1, feat2), 1)
58 | feat = self.fusion(feat)
59 |
60 | return feat
--------------------------------------------------------------------------------
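With the default fusion_type='conv', FusionModule layer-normalizes the two (b, c, t, h, w) feature volumes, concatenates them along the channel axis, and projects 2*c back to c with an InflatedConv3d. A shape-level sketch using the default 640-to-320 configuration (the tensor sizes are illustrative):

import torch
from src.models.fusion_module import FusionModule

fusion = FusionModule(in_channels=640, out_channels=320, fusion_type="conv")

b, c, t, h, w = 1, 320, 16, 32, 32
obj_feat = torch.randn(b, c, t, h, w)   # object-motion branch features
cam_feat = torch.randn(b, c, t, h, w)   # camera-motion branch features

fused = fusion(obj_feat, cam_feat)      # concat to 640 channels, conv back to 320
print(fused.shape)                      # torch.Size([1, 320, 16, 32, 32])

--------------------------------------------------------------------------------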
/src/models/motion_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | from diffusers.models.modeling_utils import ModelMixin
7 |
8 | from src.models.motion_module import zero_module
9 | from src.models.resnet import InflatedConv3d
10 |
11 |
12 | class MotionEncoder(ModelMixin):
13 | def __init__(
14 | self,
15 | conditioning_embedding_channels: int,
16 | conditioning_channels: int = 3,
17 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
18 | ):
19 | super().__init__()
20 | self.conv_in = InflatedConv3d(
21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22 | )
23 |
24 | self.blocks = nn.ModuleList([])
25 |
26 | for i in range(len(block_out_channels) - 1):
27 | channel_in = block_out_channels[i]
28 | channel_out = block_out_channels[i + 1]
29 | self.blocks.append(
30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31 | )
32 | self.blocks.append(
33 | InflatedConv3d(
34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
35 | )
36 | )
37 |
38 | self.conv_out = zero_module(
39 | InflatedConv3d(
40 | block_out_channels[-1],
41 | conditioning_embedding_channels,
42 | kernel_size=3,
43 | padding=1,
44 | )
45 | )
46 |
47 | def forward(self, conditioning):
48 | embedding = self.conv_in(conditioning)
49 | embedding = F.silu(embedding)
50 |
51 | for block in self.blocks:
52 | embedding = block(embedding)
53 | embedding = F.silu(embedding)
54 |
55 | embedding = self.conv_out(embedding)
56 |
57 | return embedding
58 |
--------------------------------------------------------------------------------
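MotionEncoder is a plain strided InflatedConv3d stack: with block_out_channels=(16, 32, 96, 256), as instantiated in inference.py, three stride-2 stages downsample the spatial grid by 8 before a zero-initialized projection to the 320-channel conditioning embedding. A shape sketch under those settings (the 512x512 input size is only an example):

import torch
from src.models.motion_encoder import MotionEncoder

encoder = MotionEncoder(
    conditioning_embedding_channels=320,
    conditioning_channels=3,             # 3 for the object branch, 1 for the camera branch
    block_out_channels=(16, 32, 96, 256),
)

cond = torch.randn(1, 3, 16, 512, 512)   # (b, c, f, H, W) control maps
emb = encoder(cond)
print(emb.shape)                         # torch.Size([1, 320, 16, 64, 64]); H and W shrink by 8
print(emb.abs().max())                   # 0 at initialization because conv_out is zero-initialized

--------------------------------------------------------------------------------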
/src/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 |         batch, channel, height, width = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 |             batch, height * width, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 |             hidden_states.reshape(batch, height, width, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 |             # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
340 |             # In PyTorch 2.0.1, the default AttnProcessor behaves the same as XFormersAttnProcessor did in PyTorch 1.13,
341 |             # so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
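VanillaTemporalModule wraps a stack of temporal-only transformer blocks and, because its output projection is zero-initialized, it behaves as an identity mapping on the (b, c, f, h, w) feature volume until trained. A small sketch that checks this property (the tensor sizes are illustrative):

import torch
from src.models.motion_module import get_motion_module

mm = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={"num_attention_heads": 8, "num_transformer_block": 2},
)

x = torch.randn(1, 320, 16, 32, 32)      # (b, c, f, h, w)
out = mm(x, temb=None, encoder_hidden_states=None)
print(out.shape)                         # same shape as the input
print(torch.allclose(out, x))            # True: proj_out starts at zero, so only the residual passes through

--------------------------------------------------------------------------------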
/src/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from src.models.attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 | # 10. Modify self attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | self_attention_additional_feats=None,
104 | mode=None,
105 | ):
106 | if self.use_ada_layer_norm: # False
107 | norm_hidden_states = self.norm1(hidden_states, timestep)
108 | elif self.use_ada_layer_norm_zero:
109 | (
110 | norm_hidden_states,
111 | gate_msa,
112 | shift_mlp,
113 | scale_mlp,
114 | gate_mlp,
115 | ) = self.norm1(
116 | hidden_states,
117 | timestep,
118 | class_labels,
119 | hidden_dtype=hidden_states.dtype,
120 | )
121 | else:
122 | norm_hidden_states = self.norm1(hidden_states)
123 |
124 | # 1. Self-Attention
125 | # self.only_cross_attention = False
126 | cross_attention_kwargs = (
127 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
128 | )
129 | if self.only_cross_attention:
130 | attn_output = self.attn1(
131 | norm_hidden_states,
132 | encoder_hidden_states=encoder_hidden_states
133 | if self.only_cross_attention
134 | else None,
135 | attention_mask=attention_mask,
136 | **cross_attention_kwargs,
137 | )
138 | else:
139 | if MODE == "write":
140 | self.bank.append(norm_hidden_states.clone())
141 | attn_output = self.attn1(
142 | norm_hidden_states,
143 | encoder_hidden_states=encoder_hidden_states
144 | if self.only_cross_attention
145 | else None,
146 | attention_mask=attention_mask,
147 | **cross_attention_kwargs,
148 | )
149 | if MODE == "read":
150 | bank_fea = [
151 | rearrange(
152 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
153 | "b t l c -> (b t) l c",
154 | )
155 | for d in self.bank
156 | ]
157 | modify_norm_hidden_states = torch.cat(
158 | [norm_hidden_states] + bank_fea, dim=1
159 | )
160 | hidden_states_uc = (
161 | self.attn1(
162 | norm_hidden_states,
163 | encoder_hidden_states=modify_norm_hidden_states,
164 | attention_mask=attention_mask,
165 | )
166 | + hidden_states
167 | )
168 | if do_classifier_free_guidance:
169 | hidden_states_c = hidden_states_uc.clone()
170 | _uc_mask = uc_mask.clone()
171 | if hidden_states.shape[0] != _uc_mask.shape[0]:
172 | _uc_mask = (
173 | torch.Tensor(
174 | [1] * (hidden_states.shape[0] // 2)
175 | + [0] * (hidden_states.shape[0] // 2)
176 | )
177 | .to(device)
178 | .bool()
179 | )
180 | hidden_states_c[_uc_mask] = (
181 | self.attn1(
182 | norm_hidden_states[_uc_mask],
183 | encoder_hidden_states=norm_hidden_states[_uc_mask],
184 | attention_mask=attention_mask,
185 | )
186 | + hidden_states[_uc_mask]
187 | )
188 | hidden_states = hidden_states_c.clone()
189 | else:
190 | hidden_states = hidden_states_uc
191 |
192 | # self.bank.clear()
193 | if self.attn2 is not None:
194 | # Cross-Attention
195 | norm_hidden_states = (
196 | self.norm2(hidden_states, timestep)
197 | if self.use_ada_layer_norm
198 | else self.norm2(hidden_states)
199 | )
200 | hidden_states = (
201 | self.attn2(
202 | norm_hidden_states,
203 | encoder_hidden_states=encoder_hidden_states,
204 | attention_mask=attention_mask,
205 | )
206 | + hidden_states
207 | )
208 |
209 | # Feed-forward
210 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
211 |
212 | # Temporal-Attention
213 | if self.unet_use_temporal_attention:
214 | d = hidden_states.shape[1]
215 | hidden_states = rearrange(
216 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
217 | )
218 | norm_hidden_states = (
219 | self.norm_temp(hidden_states, timestep)
220 | if self.use_ada_layer_norm
221 | else self.norm_temp(hidden_states)
222 | )
223 | hidden_states = (
224 | self.attn_temp(norm_hidden_states) + hidden_states
225 | )
226 | hidden_states = rearrange(
227 | hidden_states, "(b d) f c -> (b f) d c", d=d
228 | )
229 |
230 | return hidden_states
231 |
232 | if self.use_ada_layer_norm_zero:
233 | attn_output = gate_msa.unsqueeze(1) * attn_output
234 | hidden_states = attn_output + hidden_states
235 |
236 | if self.attn2 is not None:
237 | norm_hidden_states = (
238 | self.norm2(hidden_states, timestep)
239 | if self.use_ada_layer_norm
240 | else self.norm2(hidden_states)
241 | )
242 |
243 | # 2. Cross-Attention
244 | attn_output = self.attn2(
245 | norm_hidden_states,
246 | encoder_hidden_states=encoder_hidden_states,
247 | attention_mask=encoder_attention_mask,
248 | **cross_attention_kwargs,
249 | )
250 | hidden_states = attn_output + hidden_states
251 |
252 | # 3. Feed-forward
253 | norm_hidden_states = self.norm3(hidden_states)
254 |
255 | if self.use_ada_layer_norm_zero:
256 | norm_hidden_states = (
257 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
258 | )
259 |
260 | ff_output = self.ff(norm_hidden_states)
261 |
262 | if self.use_ada_layer_norm_zero:
263 | ff_output = gate_mlp.unsqueeze(1) * ff_output
264 |
265 | hidden_states = ff_output + hidden_states
266 |
267 | return hidden_states
268 |
269 | if self.reference_attn:
270 | if self.fusion_blocks == "midup":
271 | attn_modules = [
272 | module
273 | for module in (
274 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
275 | )
276 | if isinstance(module, BasicTransformerBlock)
277 | or isinstance(module, TemporalBasicTransformerBlock)
278 | ]
279 | elif self.fusion_blocks == "full":
280 | attn_modules = [
281 | module
282 | for module in torch_dfs(self.unet)
283 | if isinstance(module, BasicTransformerBlock)
284 | or isinstance(module, TemporalBasicTransformerBlock)
285 | ]
286 | attn_modules = sorted(
287 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
288 | )
289 |
290 | for i, module in enumerate(attn_modules):
291 | module._original_inner_forward = module.forward
292 | if isinstance(module, BasicTransformerBlock):
293 | module.forward = hacked_basic_transformer_inner_forward.__get__(
294 | module, BasicTransformerBlock
295 | )
296 | if isinstance(module, TemporalBasicTransformerBlock):
297 | module.forward = hacked_basic_transformer_inner_forward.__get__(
298 | module, TemporalBasicTransformerBlock
299 | )
300 |
301 | module.bank = []
302 | module.attn_weight = float(i) / float(len(attn_modules))
303 |
304 | def update(self, writer, dtype=torch.float16):
305 | if self.reference_attn:
306 | if self.fusion_blocks == "midup":
307 | reader_attn_modules = [
308 | module
309 | for module in (
310 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
311 | )
312 | if isinstance(module, TemporalBasicTransformerBlock)
313 | ]
314 | writer_attn_modules = [
315 | module
316 | for module in (
317 | torch_dfs(writer.unet.mid_block)
318 | + torch_dfs(writer.unet.up_blocks)
319 | )
320 | if isinstance(module, BasicTransformerBlock)
321 | ]
322 | elif self.fusion_blocks == "full":
323 | reader_attn_modules = [
324 | module
325 | for module in torch_dfs(self.unet)
326 | if isinstance(module, TemporalBasicTransformerBlock)
327 | ]
328 | writer_attn_modules = [
329 | module
330 | for module in torch_dfs(writer.unet)
331 | if isinstance(module, BasicTransformerBlock)
332 | ]
333 | reader_attn_modules = sorted(
334 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
335 | )
336 | writer_attn_modules = sorted(
337 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
338 | )
339 | for r, w in zip(reader_attn_modules, writer_attn_modules):
340 | r.bank = [v.clone().to(dtype) for v in w.bank]
341 | # w.bank.clear()
342 |
343 | def clear(self):
344 | if self.reference_attn:
345 | if self.fusion_blocks == "midup":
346 | reader_attn_modules = [
347 | module
348 | for module in (
349 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
350 | )
351 | if isinstance(module, BasicTransformerBlock)
352 | or isinstance(module, TemporalBasicTransformerBlock)
353 | ]
354 | elif self.fusion_blocks == "full":
355 | reader_attn_modules = [
356 | module
357 | for module in torch_dfs(self.unet)
358 | if isinstance(module, BasicTransformerBlock)
359 | or isinstance(module, TemporalBasicTransformerBlock)
360 | ]
361 | reader_attn_modules = sorted(
362 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
363 | )
364 | for r in reader_attn_modules:
365 | r.bank.clear()
366 |
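Editor's sketch (not part of the repository): a self-contained toy illustrating what `update` and `clear` above do. Writer blocks cache reference features in `bank`, reader blocks receive a dtype-cast clone, and the two sides are paired after sorting by hidden size via `norm1.normalized_shape[0]`; `ToyBlock` and the channel sizes are illustrative assumptions, not names from this repo.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)   # stands in for the transformer block's first norm
        self.bank = []                   # reference features cached during the "write" pass

writer_blocks = [ToyBlock(320), ToyBlock(640), ToyBlock(1280)]   # reference-UNet side
reader_blocks = [ToyBlock(1280), ToyBlock(320), ToyBlock(640)]   # denoising-UNet side

# "write" pass: each writer block stashes its (batch, tokens, dim) hidden states
for blk in writer_blocks:
    blk.bank.append(torch.randn(2, 16, blk.norm1.normalized_shape[0]))

# update(): sort both sides by hidden size (largest first) and clone the banks across
for r, w in zip(
    sorted(reader_blocks, key=lambda m: -m.norm1.normalized_shape[0]),
    sorted(writer_blocks, key=lambda m: -m.norm1.normalized_shape[0]),
):
    r.bank = [v.clone().to(torch.float16) for v in w.bank]

# clear(): drop the cached features once the current clip has been denoised
for blk in reader_blocks + writer_blocks:
    blk.bank.clear()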
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 | assert use_inflated_groupnorm is not None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227 |
228 | if temb is not None and self.time_embedding_norm == "default":
229 | hidden_states = hidden_states + temb
230 |
231 | hidden_states = self.norm2(hidden_states)
232 |
233 | if temb is not None and self.time_embedding_norm == "scale_shift":
234 | scale, shift = torch.chunk(temb, 2, dim=1)
235 | hidden_states = hidden_states * (1 + scale) + shift
236 |
237 | hidden_states = self.nonlinearity(hidden_states)
238 |
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.conv2(hidden_states)
241 |
242 | if self.conv_shortcut is not None:
243 | input_tensor = self.conv_shortcut(input_tensor)
244 |
245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246 |
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
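Editor's sketch (not part of the file): how the inflated layers above operate on a 5-D video tensor (batch, channels, frames, height, width) by folding frames into the batch and applying the 2-D op frame by frame. The import path assumes the repository root is on PYTHONPATH.

import torch
from src.models.resnet import InflatedConv3d, ResnetBlock3D

video = torch.randn(1, 64, 8, 32, 32)        # (b, c, f, h, w)

conv = InflatedConv3d(64, 128, kernel_size=3, padding=1)
print(conv(video).shape)                     # torch.Size([1, 128, 8, 32, 32])

block = ResnetBlock3D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    use_inflated_groupnorm=True,             # required: the assert rejects None
)
temb = torch.randn(1, 512)                   # timestep embedding
print(block(video, temb).shape)              # torch.Size([1, 64, 8, 32, 32])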
--------------------------------------------------------------------------------
/src/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | name=None,
49 | ):
50 | super().__init__()
51 | self.use_linear_projection = use_linear_projection
52 | self.num_attention_heads = num_attention_heads
53 | self.attention_head_dim = attention_head_dim
54 | inner_dim = num_attention_heads * attention_head_dim
55 |
56 | # Define input layers
57 | self.in_channels = in_channels
58 | self.name = name
59 |
60 | self.norm = torch.nn.GroupNorm(
61 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
62 | )
63 | if use_linear_projection:
64 | self.proj_in = nn.Linear(in_channels, inner_dim)
65 | else:
66 | self.proj_in = nn.Conv2d(
67 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
68 | )
69 |
70 | # Define transformers blocks
71 | self.transformer_blocks = nn.ModuleList(
72 | [
73 | TemporalBasicTransformerBlock(
74 | inner_dim,
75 | num_attention_heads,
76 | attention_head_dim,
77 | dropout=dropout,
78 | cross_attention_dim=cross_attention_dim,
79 | activation_fn=activation_fn,
80 | num_embeds_ada_norm=num_embeds_ada_norm,
81 | attention_bias=attention_bias,
82 | only_cross_attention=only_cross_attention,
83 | upcast_attention=upcast_attention,
84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85 | unet_use_temporal_attention=unet_use_temporal_attention,
86 | name=f"{self.name}_{d}_TransformerBlock" if self.name else None,
87 | )
88 | for d in range(num_layers)
89 | ]
90 | )
91 |
92 | # Define output layers
93 | if use_linear_projection:
94 | self.proj_out = nn.Linear(inner_dim, in_channels)
95 | else:
96 | self.proj_out = nn.Conv2d(
97 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
98 | )
99 |
100 | self.gradient_checkpointing = False
101 |
102 | def _set_gradient_checkpointing(self, module, value=False):
103 | if hasattr(module, "gradient_checkpointing"):
104 | module.gradient_checkpointing = value
105 |
106 | def forward(
107 | self,
108 | hidden_states,
109 | encoder_hidden_states=None,
110 | self_attention_additional_feats=None,
111 | mode=None,
112 | timestep=None,
113 | return_dict: bool = True,
114 | ):
115 | # Input
116 | assert (
117 | hidden_states.dim() == 5
118 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
119 | video_length = hidden_states.shape[2]
120 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
121 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
122 | encoder_hidden_states = repeat(
123 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
124 | )
125 |
126 | batch, channel, height, weight = hidden_states.shape
127 | residual = hidden_states
128 |
129 | hidden_states = self.norm(hidden_states)
130 | if not self.use_linear_projection:
131 | hidden_states = self.proj_in(hidden_states)
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | else:
137 | inner_dim = hidden_states.shape[1]
138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
139 | batch, height * weight, inner_dim
140 | )
141 | hidden_states = self.proj_in(hidden_states)
142 |
143 | # Blocks
144 | for i, block in enumerate(self.transformer_blocks):
145 |
146 | if self.training and self.gradient_checkpointing:
147 |
148 | def create_custom_forward(module, return_dict=None):
149 | def custom_forward(*inputs):
150 | if return_dict is not None:
151 | return module(*inputs, return_dict=return_dict)
152 | else:
153 | return module(*inputs)
154 |
155 | return custom_forward
156 |
157 | # if hasattr(self.block, 'bank') and len(self.block.bank) > 0:
158 | # hidden_states
159 | hidden_states = torch.utils.checkpoint.checkpoint(
160 | create_custom_forward(block),
161 | hidden_states,
162 | encoder_hidden_states=encoder_hidden_states,
163 | timestep=timestep,
164 | attention_mask=None,
165 | video_length=video_length,
166 | self_attention_additional_feats=self_attention_additional_feats,
167 | mode=mode,
168 | )
169 | else:
170 |
171 | hidden_states = block(
172 | hidden_states,
173 | encoder_hidden_states=encoder_hidden_states,
174 | timestep=timestep,
175 | self_attention_additional_feats=self_attention_additional_feats,
176 | mode=mode,
177 | video_length=video_length,
178 | )
179 |
180 | # Output
181 | if not self.use_linear_projection:
182 | hidden_states = (
183 | hidden_states.reshape(batch, height, weight, inner_dim)
184 | .permute(0, 3, 1, 2)
185 | .contiguous()
186 | )
187 | hidden_states = self.proj_out(hidden_states)
188 | else:
189 | hidden_states = self.proj_out(hidden_states)
190 | hidden_states = (
191 | hidden_states.reshape(batch, height, weight, inner_dim)
192 | .permute(0, 3, 1, 2)
193 | .contiguous()
194 | )
195 |
196 | output = hidden_states + residual
197 |
198 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
199 | if not return_dict:
200 | return (output,)
201 |
202 | return Transformer3DModelOutput(sample=output)
203 |
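Editor's sketch (not part of the file): the reshaping done at the top of `Transformer3DModel.forward`, isolated so the shapes are easy to follow. Frames are folded into the batch for the 2-D attention blocks, and the clip-level conditioning is repeated per frame. The token count (77) and embedding width (768) are illustrative assumptions.

import torch
from einops import rearrange, repeat

b, c, f, h, w = 1, 320, 8, 32, 32
hidden_states = torch.randn(b, c, f, h, w)          # latent video
encoder_hidden_states = torch.randn(b, 77, 768)     # one conditioning sequence per clip

# fold frames into the batch so each frame is processed like a 2-D feature map
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
# replicate the clip-level conditioning so every frame sees the same tokens
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)

print(hidden_states.shape)           # torch.Size([8, 320, 32, 32])
print(encoder_hidden_states.shape)   # torch.Size([8, 77, 768])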
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__init__.py
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/pipeline_motion2vid_merge_infer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/pipeline_motion2vid_merge_infer.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/pipelines/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
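Editor's sketch (not part of the file): choosing and applying the latent interpolation method defined above. `set_tensor_interpolation_method(True)` selects `slerp`, `False` selects `linear`; the import path assumes the repository root is on PYTHONPATH.

import torch
from src.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

set_tensor_interpolation_method(is_slerp=True)
interp = get_tensor_interpolation_method()

v0 = torch.randn(4, 64)
v1 = torch.randn(4, 64)
mid = interp(v0, v1, 0.5)            # spherical midpoint between the two latents
print(mid.shape)                     # torch.Size([4, 64])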
--------------------------------------------------------------------------------
/src/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/visualizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chen-yingjie/Perception-as-Control/6c8ea213fe0c81281d46eeb4e184a987be3a8eb3/src/utils/__pycache__/visualizer.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 |
8 | import av
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | import torch.distributed as dist
13 | from einops import rearrange
14 | from PIL import Image
15 |
16 |
17 | def seed_everything(seed):
18 | import random
19 |
20 | import numpy as np
21 |
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed_all(seed)
24 | np.random.seed(seed % (2**32))
25 | random.seed(seed)
26 |
27 |
28 | def import_filename(filename):
29 | spec = importlib.util.spec_from_file_location("mymodule", filename)
30 | module = importlib.util.module_from_spec(spec)
31 | sys.modules[spec.name] = module
32 | spec.loader.exec_module(module)
33 | return module
34 |
35 |
36 | def delete_additional_ckpt(base_path, num_keep):
37 | dirs = []
38 | for d in os.listdir(base_path):
39 | if d.startswith("checkpoint-"):
40 | dirs.append(d)
41 | num_tot = len(dirs)
42 | if num_tot <= num_keep:
43 | return
44 | # ensure checkpoints are sorted by step and delete the earliest ones
45 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
46 | for d in del_dirs:
47 | path_to_dir = osp.join(base_path, d)
48 | if osp.exists(path_to_dir):
49 | shutil.rmtree(path_to_dir)
50 |
51 |
52 | def save_videos_from_pil(pil_images, path, fps=8):
53 | import av
54 |
55 | save_fmt = Path(path).suffix
56 | os.makedirs(os.path.dirname(path), exist_ok=True)
57 | width, height = pil_images[0].size
58 |
59 | if save_fmt == ".mp4":
60 | codec = "libx264"
61 | container = av.open(path, "w")
62 | stream = container.add_stream(codec, rate=fps)
63 |
64 | stream.width = width
65 | stream.height = height
66 |
67 | for pil_image in pil_images:
68 | # pil_image = Image.fromarray(image_arr).convert("RGB")
69 | av_frame = av.VideoFrame.from_image(pil_image)
70 | container.mux(stream.encode(av_frame))
71 | container.mux(stream.encode())
72 | container.close()
73 |
74 | elif save_fmt == ".gif":
75 | pil_images[0].save(
76 | fp=path,
77 | format="GIF",
78 | append_images=pil_images[1:],
79 | save_all=True,
80 | duration=(1 / fps * 1000),
81 | loop=0,
82 | )
83 | else:
84 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
85 |
86 |
87 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=2, fps=8):
88 | videos = rearrange(videos, "b c t h w -> t b c h w")
89 | height, width = videos.shape[-2:]
90 | outputs = []
91 |
92 | for x in videos:
93 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
94 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
95 | if rescale:
96 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
97 | x = (x * 255).numpy().astype(np.uint8)
98 | x = Image.fromarray(x)
99 |
100 | outputs.append(x)
101 |
102 | os.makedirs(os.path.dirname(path), exist_ok=True)
103 |
104 | save_videos_from_pil(outputs, path, fps)
105 |
106 |
107 | def read_frames(video_path):
108 | container = av.open(video_path)
109 |
110 | video_stream = next(s for s in container.streams if s.type == "video")
111 | frames = []
112 | for packet in container.demux(video_stream):
113 | for frame in packet.decode():
114 | image = Image.frombytes(
115 | "RGB",
116 | (frame.width, frame.height),
117 | frame.to_rgb().to_ndarray(),
118 | )
119 | frames.append(image)
120 |
121 | return frames
122 |
123 |
124 | def get_fps(video_path):
125 | container = av.open(video_path)
126 | video_stream = next(s for s in container.streams if s.type == "video")
127 | fps = video_stream.average_rate
128 | container.close()
129 | return fps
130 |
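Editor's sketch (not part of the file): writing a grid of clips to disk with `save_videos_grid`. The tensor layout is (b, c, t, h, w) with values in [0, 1] (pass rescale=True for [-1, 1] inputs); the .mp4 branch needs PyAV built with libx264, while a .gif target avoids the codec dependency. The import path assumes the repository root is on PYTHONPATH.

import torch
from src.utils.util import save_videos_grid, seed_everything

seed_everything(42)
videos = torch.rand(2, 3, 16, 64, 64)    # two fake 16-frame RGB clips in [0, 1]
save_videos_grid(videos, "outputs/demo.mp4", rescale=False, n_rows=2, fps=8)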
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 | import torch.distributed as dist
9 |
10 | from safetensors import safe_open
11 | from tqdm import tqdm
12 | from scipy.interpolate import PchipInterpolator
13 | from einops import rearrange
14 | from decord import VideoReader
15 |
16 |
17 |
18 | def interpolate_trajectory(points, n_points):
19 | if len(points) == n_points:
20 | return points
21 |
22 | if len(points) == 1:
23 | points = [points[0] for _ in range(n_points)]
24 | return points
25 |
26 | x = [point[0] for point in points]
27 | y = [point[1] for point in points]
28 |
29 | t = np.linspace(0, 1, len(points))
30 |
31 | fx = PchipInterpolator(t, x)
32 | fy = PchipInterpolator(t, y)
33 |
34 | new_t = np.linspace(0, 1, n_points)
35 |
36 | new_x = fx(new_t)
37 | new_y = fy(new_t)
38 | new_points = list(zip(new_x, new_y))
39 |
40 | return new_points
41 |
42 |
43 | def interpolate_trajectory_3d(points, n_points):
44 | if len(points) == n_points:
45 | return points
46 |
47 | if len(points) == 1:
48 | points = [points[0] for _ in range(n_points)]
49 | return points
50 |
51 | x = [point[0] for point in points]
52 | y = [point[1] for point in points]
53 | z = [point[2] for point in points]
54 |
55 | t = np.linspace(0, 1, len(points))
56 |
57 | fx = PchipInterpolator(t, x)
58 | fy = PchipInterpolator(t, y)
59 | fz = PchipInterpolator(t, z)
60 |
61 | new_t = np.linspace(0, 1, n_points)
62 |
63 | new_x = fx(new_t)
64 | new_y = fy(new_t)
65 | new_z = fz(new_t)
66 | new_points = list(zip(new_x, new_y, new_z))
67 |
68 | return new_points
69 |
70 |
71 | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
72 | """Generate a bivariate isotropic or anisotropic Gaussian kernel.
73 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
74 | Args:
75 | kernel_size (int):
76 | sig_x (float):
77 | sig_y (float):
78 | theta (float): Radian measurement.
79 | grid (ndarray, optional): generated by :func:`mesh_grid`,
80 | with the shape (K, K, 2), K is the kernel size. Default: None
81 | isotropic (bool):
82 | Returns:
83 | kernel (ndarray): normalized kernel.
84 | """
85 | if grid is None:
86 | grid, _, _ = mesh_grid(kernel_size)
87 | if isotropic:
88 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
89 | else:
90 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
91 | kernel = pdf2(sigma_matrix, grid)
92 | kernel = kernel / np.sum(kernel)
93 | return kernel
94 |
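# NOTE (editor-added, not in the dumped file): the anisotropic branch above calls
# sigma_matrix2(), which is not defined anywhere in this listing, so isotropic=False
# would raise a NameError. The sketch below follows the standard BasicSR
# degradations helper and is an assumption, not the repository's own code.
def sigma_matrix2(sig_x, sig_y, theta):
    """Build the 2x2 covariance matrix of a rotated anisotropic Gaussian."""
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))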
95 |
96 | def mesh_grid(kernel_size):
97 | """Generate the mesh grid, centering at zero.
98 | Args:
99 | kernel_size (int):
100 | Returns:
101 | xy (ndarray): with the shape (kernel_size, kernel_size, 2)
102 | xx (ndarray): with the shape (kernel_size, kernel_size)
103 | yy (ndarray): with the shape (kernel_size, kernel_size)
104 | """
105 | ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
106 | xx, yy = np.meshgrid(ax, ax)
107 | xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
108 | 1))).reshape(kernel_size, kernel_size, 2)
109 | return xy, xx, yy
110 |
111 |
112 | def pdf2(sigma_matrix, grid):
113 | """Calculate PDF of the bivariate Gaussian distribution.
114 | Args:
115 | sigma_matrix (ndarray): with the shape (2, 2)
116 | grid (ndarray): generated by :func:`mesh_grid`,
117 | with the shape (K, K, 2), K is the kernel size.
118 | Returns:
119 | kernel (ndarray): un-normalized kernel.
120 | """
121 | inverse_sigma = np.linalg.inv(sigma_matrix)
122 | kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
123 | return kernel
124 |
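Editor's sketch (not part of the file): typical use of the trajectory and kernel helpers above, densifying a handful of control points into a per-frame trajectory and building a normalized Gaussian blob (e.g. for splatting soft point heatmaps). The import path assumes the repository root is on PYTHONPATH.

from src.utils.utils import bivariate_Gaussian, interpolate_trajectory

points = [(10.0, 20.0), (40.0, 25.0), (90.0, 80.0)]    # sparse control points (x, y)
dense = interpolate_trajectory(points, n_points=16)    # PCHIP-resampled to 16 points
print(len(dense), dense[0], dense[-1])                 # 16 points; endpoints (10, 20) and (90, 80) preserved

kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, isotropic=True)
print(kernel.shape, round(kernel.sum(), 6))            # (99, 99) 1.0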
--------------------------------------------------------------------------------