31 |
32 | *With interactive editing empowered by SC-GS, users can effortlessly edit and customize their digital assets.*
33 |
34 |
35 |
36 |
37 |
38 | *Given (a) an image sequence from a monocular dynamic video, we propose to represent the motion with a set of sparse control points, which are used to drive 3D Gaussians for high-fidelity rendering. Our approach enables both (b) dynamic view synthesis and (c) motion editing thanks to the motion representation based on sparse control points.*
39 |
40 |
41 | ## Updates
42 |
43 | ### 2025-05-21
44 |
45 | #### Editing Real-World Static Objects
46 |
47 | Fixed a reported issue where inversion of the Laplacian matrix made editing slow. Editing real-world static objects is now flexible and shows interesting results.
48 |
49 | #### 1. Masking the Object to Edit
50 | When editing, remember to mask the object you want to modify. If you are using MiVOS and image names that are not purely numeric (e.g., `frame_000.jpg`) cause errors, you can resolve this by replacing the line in [this file](https://github.com/hkchengrex/MiVOS/blob/f2600a6eea8709c7b9f1a7575adc725def680b81/interact/interactive_utils.py#L26) with the following:
51 |
52 | ```python
53 | fnames = sorted(glob.glob(os.path.join(path, '*.jpg')), key=lambda x: int(''.join(char for char in os.path.basename(x).split('.')[0] if char.isdigit())))
54 | ```
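
For reference, the replacement key sorts frames by the digits embedded in each file name rather than by raw string order. A quick sanity check on hypothetical file names:

```python
import os

# Same key as in the replacement line above: keep only the digits of the base name
# and compare the results as integers.
key = lambda x: int(''.join(ch for ch in os.path.basename(x).split('.')[0] if ch.isdigit()))

frames = ["frames/frame_10.jpg", "frames/frame_2.jpg", "frames/frame_000.jpg"]
print(sorted(frames, key=key))
# ['frames/frame_000.jpg', 'frames/frame_2.jpg', 'frames/frame_10.jpg']
```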
55 |
56 | #### 2. Training and Editing Static Scenes
57 |
58 | Run the following command to train and edit a static scene:
59 |
60 | ```bash
61 | CUDA_VISIBLE_DEVICES=0 python train_gui.py \
62 | --source_path "XXX/person-small" \
63 | --model_path "outputs/person/" \
64 | --is_scene_static \
65 | --gui \
66 | --deform_type "node" \
67 | --node_num "512" \
68 | --gt_alpha_mask_as_dynamic_mask \
69 | --gs_with_motion_mask \
70 | --W "800" \
71 | --H "800" \
72 | --white_background \
73 | --init_isotropic_gs_with_all_colmap_pcl
74 | ```
75 |
76 | #### 3. Editing Results on I-N2N Scenes
77 | By following the editing guidance, you can easily achieve satisfactory geometry-editing results on static scenes, such as those used in [Instruct-NeRF2NeRF](https://instruct-nerf2nerf.github.io/):
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 | ### 2024-03-17
92 |
93 | 1. Editing **static scenes** is now supported! Simply include the `--is_scene_static` argument and you are good to go!
94 |
95 | 2. Video rendering with interpolation of editing results is now supported. Press the `sv_kpt` button to save each edited result and press `render_traj` to render the interpolated motions as a video. Click the `spiral` button to switch the camera motion of the rendered video between a spiral trajectory and a fixed pose.
96 |
97 | 3. On self-captured real-world scenes, where the number of Gaussians can be very large, the dimension of the hyper coordinates that separate close but disconnected parts can be set to 2 to speed up rendering: `--hyper_dim 2`. Also remember to remove `--is_blender` in such cases! (A generic sketch of how such hyper coordinates can be used is given after this list.)
98 |
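For intuition, below is a minimal sketch (not the repository's implementation) of how per-Gaussian blending weights over sparse control nodes could be computed in a joint space of 3D position and learned hyper coordinates, so that spatially close but disconnected parts receive different weights; the function name, the `K=4` neighborhood, and the `beta` falloff are illustrative assumptions:

```python
import torch

def node_blend_weights(gauss_xyz, gauss_hyper, node_xyz, node_hyper, K=4, beta=10.0):
    """Illustrative sketch: blending weights from the K nearest control nodes.

    gauss_xyz: (N, 3) Gaussian centers, gauss_hyper: (N, hyper_dim)
    node_xyz:  (M, 3) control points,   node_hyper:  (M, hyper_dim)
    """
    # Concatenate position with hyper coordinates; a larger hyper_dim separates
    # close-but-disconnected parts better, at the cost of slower rendering.
    g = torch.cat([gauss_xyz, gauss_hyper], dim=-1)   # (N, 3 + hyper_dim)
    n = torch.cat([node_xyz, node_hyper], dim=-1)     # (M, 3 + hyper_dim)
    d2 = torch.cdist(g, n).pow(2)                     # (N, M) squared distances
    knn_d2, knn_idx = torch.topk(d2, K, dim=-1, largest=False)
    w = torch.exp(-beta * knn_d2)                     # Gaussian falloff
    w = w / (w.sum(dim=-1, keepdim=True) + 1e-8)      # normalize per Gaussian
    return w, knn_idx                                 # (N, K), (N, K)
```

Each Gaussian can then be driven by blending the transformations of its nearest control nodes with such weights (linear blend skinning), which is why a region with too few control points is hard to edit cleanly.
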
99 | ### 2024-03-07
100 |
101 | We offer two ARAP deformation strategies for motion editing: 1. iterative deformation and 2. deformation from Laplacian initialization (see the Editing section below for details on both modes).
102 |
103 | ### 2024-03-06
104 |
105 | To prevent initialization failure of control points, you can use the argument `--init_isotropic_gs_with_all_colmap_pcl` on self-captured datasets.
106 |
107 |
108 | ## Install
109 |
110 | ```bash
111 | git clone https://github.com/yihua7/SC-GS --recursive
112 | cd SC-GS
113 |
114 | pip install -r requirements.txt
115 |
116 | # a modified gaussian splatting (+ depth, alpha rendering)
117 | pip install ./submodules/diff-gaussian-rasterization
118 |
119 | # simple-knn
120 | pip install ./submodules/simple-knn
121 | ```
122 |
123 | ## Run
124 |
125 | ### Train with GUI
126 |
127 | * To begin training, click the 'start' button. The program first pre-trains control points in the form of Gaussians for 10,000 steps before progressing to train the dynamic Gaussians.
128 |
129 | * To view the control points, click the 'Node' button on the panel that follows 'Visualization'.
130 |
131 | ```bash
132 | # Train with GUI (for the resolution of 400*400 with best PSNR)
133 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800 --gui
134 |
135 | # Train with GUI (for the resolution of 800*800)
136 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --W 800 --H 800 --random_bg_color --white_background --gui
137 | ```
138 |
139 | ### Train with terminal
140 |
141 | * Simply remove the option `--gui`, as follows:
142 |
143 | ```bash
144 | # Train with terminal only (for the resolution of 400*400 with best PSNR)
145 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800
146 | ```
147 |
148 | ### Evaluate
149 |
150 | * Every 1,000 steps during training, the program evaluates SC-GS on the test set and prints the results **in the GUI and the terminal**, so you can monitor them easily.
151 |
152 | * You can also run evaluation by replacing `train_gui.py` with `render.py` in the training command. Results will be saved in the specified log directory `outputs/XXX`. The following is an example:
153 |
154 | ```bash
155 | # Evaluate (for the resolution of 400*400 with best PSNR)
156 | CUDA_VISIBLE_DEVICES=0 python render.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800
157 | ```
158 |
159 | ## Editing
160 |
161 | ### 2-minute editing guidance
162 |
163 | (The video was recorded prior to the addition of the editing mode selection menu in the UI. In the video, the deformation was performed using the `arap_from_init` method.)
164 |
165 | https://github.com/yihua7/SC-GS/assets/35869256/7a71d29b-975e-4870-afb1-7cdc96bb9482
166 |
167 | ### Editing Mode
168 |
169 | We offer two deformation strategies for editing: **(1)** iterative ARAP deformation and **(2)** ARAP deformation starting from the initial frozen moment. Users can select their preferred strategy from the Editing Mode drop-down menu in the UI.
170 |
171 |
172 |
173 | (1) **Iterative deformation (`arap_iterative`)**:
174 |
175 | - **Pros**: Large-scale deformations are easy to achieve without rotation artifacts.
176 |
177 | - **Cons**: Because the state is updated iteratively, it may be difficult to revert to a previous state after unintentionally producing an unwanted deformation.
178 |
179 | (2) **Deformation from the initial frozen moment (`arap_from_init`)**:
180 |
181 | - **Pros**: It ensures that the deformed state can be restored when control points return to their previous positions, making it easier to control without deviation.
182 |
183 | - **Cons**: For large-scale rotational deformations, the ARAP algorithm may fail to reach the optimum, since the initialization from the Laplacian deformation is not robust to rotation. This may leave certain areas without the corresponding large-scale rotations.
184 |
185 | **Users can try both strategies and choose the one best suited to achieving their desired editing effect.**
186 |
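For reference, the sketch below shows the generic ARAP local/global alternation that both editing modes build on: a local step fits a per-vertex rotation by SVD, and a global step solves a Laplacian linear system under the handle constraints. This is a self-contained toy on a point graph with uniform weights, not the repository's implementation; `arap_iterative` effectively warm-starts `p` from the previous edited state, whereas `arap_from_init` initializes from a Laplacian-deformation solve of the frozen initial moment.

```python
import numpy as np

def arap_deform(verts, neighbors, handles, n_iters=10):
    """Toy as-rigid-as-possible (ARAP) deformation on a point graph.

    verts:     (N, 3) rest positions; neighbors: list of neighbor-index lists;
    handles:   dict {vertex_index: target_position} of user constraints.
    Assumes a connected graph with at least one handle.
    """
    N = len(verts)

    # Graph Laplacian with uniform weights; handle rows are replaced by identity.
    L = np.zeros((N, N))
    for i, nbrs in enumerate(neighbors):
        for j in nbrs:
            L[i, i] += 1.0
            L[i, j] -= 1.0
    for h in handles:
        L[h, :] = 0.0
        L[h, h] = 1.0

    # Initialization: rest pose with handles snapped to their targets.
    p = np.array(verts, dtype=float)
    for h, target in handles.items():
        p[h] = np.asarray(target, dtype=float)

    for _ in range(n_iters):
        # Local step: per-vertex best-fit rotation from rest edges to current edges.
        R = np.zeros((N, 3, 3))
        for i, nbrs in enumerate(neighbors):
            S = np.zeros((3, 3))
            for j in nbrs:
                S += np.outer(verts[i] - verts[j], p[i] - p[j])
            U, _, Vt = np.linalg.svd(S)
            Ri = Vt.T @ U.T
            if np.linalg.det(Ri) < 0:   # guard against reflections
                Vt[-1] *= -1
                Ri = Vt.T @ U.T
            R[i] = Ri

        # Global step: solve L p = b with averaged rotated rest edges on the right.
        b = np.zeros((N, 3))
        for i, nbrs in enumerate(neighbors):
            for j in nbrs:
                b[i] += 0.5 * (R[i] + R[j]) @ (verts[i] - verts[j])
        for h, target in handles.items():
            b[h] = target
        p = np.linalg.solve(L, b)
    return p

# Hypothetical usage: bend a three-point chain by moving its last point.
verts = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
neighbors = [[1], [0, 2], [1]]
handles = {0: [0., 0., 0.], 2: [1., 1., 0.]}
print(arap_deform(verts, neighbors, handles))
```

The large-rotation failure mode described in the tips below corresponds to the local step receiving a poor initial `p`: the fitted rotations are then wrong and the global step cannot recover them, which is why covering the whole rigid part with control points (a better initialization) helps.
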
187 | ### Tips on Editing with the deformation from the initial frozen moment (`arap_from_init`)
188 |
189 | 1. **When and why will artifacts appear when using `arap_from_init`?** Most editing artifacts are caused by inaccurate initialization of the ARAP deformation, which is an iterative optimization over positions and rotations. To reach a good optimum for both, ARAP needs a good initialization. The mode `arap_from_init` uses Laplacian deformation for initialization, which only minimizes the error of the Laplacian coordinates, and these change under rotation. Hence Laplacian deformation is not robust to rotation, resulting in inaccurate initialization in the presence of large rotations. As a result, some areas fail to achieve the correct rotations in the subsequent ARAP deformation.
190 |
191 | 2. **How to deal with artifacts?** To address this issue, the following steps are recommended; the core idea is to **include as many control points as possible** for large-scale deformation: (1) If you treat a big region as a rigid part and want to apply a large deformation, use more control points to cover the whole part and manipulate all of them together. This yields a better Laplacian deformation and therefore a better initialization of the ARAP deformation. (2) Edit hierarchically. If you need to apply deformations at different levels, first add control points at the finest part and deform it. After that, include more control points, treat them as a rigid body, and perform deformations at coarser levels.
192 |
193 | 3. More tips: (1) To add handle points more efficiently, set the parameter `n_rings` to 3 or 4 in the GUI. (2) Press the `Node` button to visualize the control points and check whether any points in the region of interest are missing. Press `RGB` to switch back to the Gaussian rendering.
194 |
195 | 4. The above are operational tricks for editing with `arap_from_init`; they require some understanding of ARAP deformation or a bit of practice. With experience, it becomes clearer how to operate the tool and achieve the desired deformation results.
196 |
197 | ## SOTA Performance
198 |
199 | Quantitative comparison on D-NeRF datasets. We present the average PSNR/SSIM/LPIPS (VGG) values for novel view synthesis on dynamic scenes from D-NeRF, with each cell colored to indicate the best, second best, and third best.
200 |
201 |
202 |
203 |
204 | ## Dataset
205 |
206 | Our data-reader script automatically recognizes and reads the following dataset formats:
207 |
208 | * [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html): dynamic scenes of synthetic objects ([download](https://www.dropbox.com/s/0bf6fl0ye2vz3vr/data.zip?e=1&dl=0))
209 |
210 | * [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/): dynamic scenes of specular objects ([download](https://github.com/JokerYan/NeRF-DS/releases/tag/v0.1-pre-release))
211 |
212 | * Self-captured videos: 1. Install [MiVOS](https://github.com/hkchengrex/MiVOS) and place [interactive_invoke.py](data_tools/interactive_invoke.py) under the installed path. 2. Set the video path in [phone_catch.py](data_tools/phone_catch.py) and run ```python ./data_tools/phone_catch.py``` to perform frame extraction, video segmentation, and COLMAP pose estimation in sequence. Please refer to [NeRF-Texture](https://github.com/yihua7/NeRF-Texture) for detailed tutorials.
213 |
214 | * Self-captured static scenes: editing is now also supported! Simply include the `--is_scene_static` argument and you are good to go!
215 |
216 | **Important Note for Using Self-captured Videos**:
217 |
218 | * Please remember to remove the `--is_blender` option from your command; this option makes the control points initialize from random point clouds instead of COLMAP point clouds.
219 | * Additionally, you can remove `--gt_alpha_mask_as_scene_mask` and add `--gt_alpha_mask_as_dynamic_mask --gs_with_motion_mask` if you want to model both the dynamic foreground masked by MiVOS and the static background simultaneously.
220 | * If control-point initialization still fails after removing `--is_blender`, use the option `--init_isotropic_gs_with_all_colmap_pcl`. This initializes the isotropic Gaussians with all COLMAP point clouds, which helps avoid the risk of control points dying out.
221 | * The dimension of the hyper coordinates that separate close but disconnected parts can be set to 2 to avoid slow rendering: `--hyper_dim 2`.
222 |
223 |
224 | ## Acknowledgement
225 |
226 | * This framework has been adapted from the notable [Deformable 3D Gaussians](https://github.com/ingra14m/Deformable-3D-Gaussians), an excellent and pioneering work by [Ziyi Yang](https://github.com/ingra14m).
227 | ```
228 | @article{yang2023deformable3dgs,
229 | title={Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction},
230 | author={Yang, Ziyi and Gao, Xinyu and Zhou, Wen and Jiao, Shaohui and Zhang, Yuqing and Jin, Xiaogang},
231 | journal={arXiv preprint arXiv:2309.13101},
232 | year={2023}
233 | }
234 | ```
235 |
236 | * Credits to authors of [3D Gaussians](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) for their excellent code.
237 | ```
238 | @Article{kerbl3Dgaussians,
239 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
240 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
241 | journal = {ACM Transactions on Graphics},
242 | number = {4},
243 | volume = {42},
244 | month = {July},
245 | year = {2023},
246 | url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
247 | }
248 | ```
249 |
250 | ## Citing
251 | If you find our work useful, please consider citing:
252 | ```BibTeX
253 | @article{huang2023sc,
254 | title={SC-GS: Sparse-Controlled Gaussian Splatting for Editable Dynamic Scenes},
255 | author={Huang, Yi-Hua and Sun, Yang-Tian and Yang, Ziyi and Lyu, Xiaoyang and Cao, Yan-Pei and Qi, Xiaojuan},
256 | journal={arXiv preprint arXiv:2312.14937},
257 | year={2023}
258 | }
259 | ```
260 |
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene, DeformModel
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | from utils.general_utils import safe_state
20 | from utils.pose_utils import pose_spherical
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
23 | from gaussian_renderer import GaussianModel
24 | import imageio
25 | import numpy as np
26 | from pytorch_msssim import ms_ssim
27 | from piq import LPIPS
28 | lpips = LPIPS()
29 | from utils.image_utils import ssim as ssim_func
30 | from utils.image_utils import psnr, lpips, alex_lpips
31 |
32 |
33 | def render_set(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform):
34 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
35 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
36 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth")
37 |
38 | makedirs(render_path, exist_ok=True)
39 | makedirs(gts_path, exist_ok=True)
40 | makedirs(depth_path, exist_ok=True)
41 |
42 | # Measurement
43 | psnr_list, ssim_list, lpips_list = [], [], []
44 | ms_ssim_list, alex_lpips_list = [], []
45 |
46 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
47 | renderings = []
48 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
49 | if load2gpt_on_the_fly:
50 | view.load2device()
51 | fid = view.fid
52 | xyz = gaussians.get_xyz
53 | if deform.name == 'mlp':
54 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
55 | elif deform.name == 'node':
56 | time_input = deform.deform.expand_time(fid)
57 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask)
58 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
59 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res)
60 | alpha = results["alpha"]
61 | rendering = torch.clamp(torch.cat([results["render"], alpha]), 0.0, 1.0)
62 |
63 | # Measurement
64 | image = rendering[:3]
65 | gt_image = torch.clamp(view.original_image.to("cuda"), 0.0, 1.0)
66 | psnr_list.append(psnr(image[None], gt_image[None]).mean())
67 | ssim_list.append(ssim_func(image[None], gt_image[None], data_range=1.).mean())
68 | lpips_list.append(lpips(image[None], gt_image[None]).mean())
69 | ms_ssim_list.append(ms_ssim(image[None], gt_image[None], data_range=1.).mean())
70 | alex_lpips_list.append(alex_lpips(image[None], gt_image[None]).mean())
71 |
72 | renderings.append(to8b(rendering.cpu().numpy()))
73 | depth = results["depth"]
74 | depth = depth / (depth.max() + 1e-5)
75 |
76 | gt = view.original_image[0:4, :, :]
77 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
78 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
79 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png"))
80 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
81 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
82 |
83 | # Measurement
84 | psnr_test = torch.stack(psnr_list).mean()
85 | ssim_test = torch.stack(ssim_list).mean()
86 | lpips_test = torch.stack(lpips_list).mean()
87 | ms_ssim_test = torch.stack(ms_ssim_list).mean()
88 | alex_lpips_test = torch.stack(alex_lpips_list).mean()
89 |     print("\n[ITER {}] Evaluating {}: PSNR {} SSIM {} LPIPS {} MS SSIM {} ALEX_LPIPS {}".format(iteration, name, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test))
90 |
91 |
92 | def interpolate_time(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform):
93 | render_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "renders")
94 | depth_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "depth")
95 |
96 | makedirs(render_path, exist_ok=True)
97 | makedirs(depth_path, exist_ok=True)
98 |
99 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
100 |
101 | frame = 150
102 | idx = torch.randint(0, len(views), (1,)).item()
103 | view = views[idx]
104 | renderings = []
105 | for t in tqdm(range(0, frame, 1), desc="Rendering progress"):
106 | fid = torch.Tensor([t / (frame - 1)]).cuda()
107 | xyz = gaussians.get_xyz
108 |         if deform.name == 'mlp':
109 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
110 | elif deform.name == 'node':
111 | time_input = deform.deform.expand_time(fid)
112 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask)
113 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
114 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res)
115 | rendering = results["render"]
116 | renderings.append(to8b(rendering.cpu().numpy()))
117 | depth = results["depth"]
118 | depth = depth / (depth.max() + 1e-5)
119 |
120 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(t) + ".png"))
121 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(t) + ".png"))
122 |
123 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
124 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
125 |
126 |
127 | def interpolate_all(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform):
128 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders")
129 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth")
130 |
131 | makedirs(render_path, exist_ok=True)
132 | makedirs(depth_path, exist_ok=True)
133 |
134 | frame = 150
135 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 0)
136 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
137 |
138 | idx = torch.randint(0, len(views), (1,)).item()
139 | view = views[idx] # Choose a specific time for rendering
140 |
141 | renderings = []
142 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")):
143 | fid = torch.Tensor([i / (frame - 1)]).cuda()
144 |
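        # Convert the spherical pose into the (R, T) extrinsics expected by reset_extrinsic (note the axis-convention sign flips).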
145 | matrix = np.linalg.inv(np.array(pose))
146 | R = -np.transpose(matrix[:3, :3])
147 | R[:, 0] = -R[:, 0]
148 | T = -matrix[:3, 3]
149 |
150 | view.reset_extrinsic(R, T)
151 |
152 | xyz = gaussians.get_xyz
153 | if deform.name == 'mlp':
154 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
155 | elif deform.name == 'node':
156 | time_input = deform.deform.expand_time(fid)
157 |
158 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask)
159 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
160 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res)
161 | rendering = torch.clamp(results["render"], 0.0, 1.0)
162 | renderings.append(to8b(rendering.cpu().numpy()))
163 | depth = results["depth"]
164 | depth = depth / (depth.max() + 1e-5)
165 |
166 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png"))
167 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png"))
168 |
169 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1)
170 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8)
171 |
172 |
173 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool, mode: str, load2device_on_the_fly=False):
174 | with torch.no_grad():
175 |
176 | deform = DeformModel(K=dataset.K, deform_type=dataset.deform_type, is_blender=dataset.is_blender, skinning=dataset.skinning, hyper_dim=dataset.hyper_dim, node_num=dataset.node_num, pred_opacity=dataset.pred_opacity, pred_color=dataset.pred_color, use_hash=dataset.use_hash, hash_time=dataset.hash_time, d_rot_as_res=dataset.d_rot_as_res, local_frame=dataset.local_frame, progressive_brand_time=dataset.progressive_brand_time, max_d_scale=dataset.max_d_scale)
177 | deform.load_weights(dataset.model_path, iteration=iteration)
178 |
179 | gs_fea_dim = deform.deform.node_num if dataset.skinning and deform.name == 'node' else dataset.hyper_dim
180 | gaussians = GaussianModel(dataset.sh_degree, fea_dim=gs_fea_dim, with_motion_mask=dataset.gs_with_motion_mask)
181 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
182 |
183 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
184 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
185 |
186 | if mode == "render":
187 | render_func = render_set
188 | elif mode == "time":
189 | render_func = interpolate_time
190 | else:
191 | render_func = interpolate_all
192 |
193 | if not skip_train:
194 | render_func(dataset.model_path, load2device_on_the_fly, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, deform)
195 |
196 | if not skip_test:
197 | render_func(dataset.model_path, load2device_on_the_fly, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, deform)
198 |
199 |
200 | if __name__ == "__main__":
201 | # Set up command line argument parser
202 | parser = ArgumentParser(description="Testing script parameters")
203 | model = ModelParams(parser, sentinel=True)
204 | pipeline = PipelineParams(parser)
205 | op = OptimizationParams(parser)
206 | parser.add_argument("--iteration", default=-1, type=int)
207 | parser.add_argument("--skip_train", action="store_true")
208 | parser.add_argument("--skip_test", action="store_true")
209 | parser.add_argument("--quiet", action="store_true")
210 | parser.add_argument("--mode", default='render', choices=['render', 'time', 'view', 'all', 'pose', 'original'])
211 |
212 | parser.add_argument('--gui', action='store_true', help="start a GUI")
213 | parser.add_argument('--W', type=int, default=800, help="GUI width")
214 | parser.add_argument('--H', type=int, default=800, help="GUI height")
215 | parser.add_argument('--elevation', type=float, default=0, help="default GUI camera elevation")
216 | parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center")
217 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
218 |
219 | parser.add_argument('--ip', type=str, default="127.0.0.1")
220 | parser.add_argument('--port', type=int, default=6009)
221 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
222 | parser.add_argument("--test_iterations", nargs="+", type=int,
223 | default=[5000, 6000, 7_000] + list(range(10000, 80_0001, 1000)))
224 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000])
225 | # parser.add_argument("--quiet", action="store_true")
226 | parser.add_argument("--deform-type", type=str, default='mlp')
227 |
228 | args = get_combined_args(parser)
229 | if not args.model_path.endswith(args.deform_type):
230 | args.model_path = os.path.join(os.path.dirname(os.path.normpath(args.model_path)), os.path.basename(os.path.normpath(args.model_path)) + f'_{args.deform_type}')
231 | print("Rendering " + args.model_path)
232 |
233 | # Initialize system state (RNG)
234 | safe_state(args.quiet)
235 |
236 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode, load2device_on_the_fly=args.load2gpu_on_the_fly)
237 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.0
2 | opencv-python==4.5.5.62
3 | Pillow==7.0.0
4 | PyYAML==6.0
5 | scipy==1.10.1
6 | tensorboard==2.14.0
7 | torch==1.12.1+cu113 # or any later versions
8 | tqdm==4.66.1
9 | imageio
10 | plyfile
11 | piq
12 | dearpygui
13 | lpips
14 | pytorch_msssim
15 | matplotlib
16 | scikit-image
17 | git+https://github.com/Po-Hsun-Su/pytorch-ssim.git
18 | git+https://github.com/facebookresearch/pytorch3d.git
19 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from scene.deform_model import DeformModel
19 | from arguments import ModelParams
20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
21 |
22 |
23 | class Scene:
24 | gaussians: GaussianModel
25 |
26 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
27 |         """
28 | :param path: Path to colmap scene main folder.
29 | """
30 | self.model_path = args.model_path
31 | self.loaded_iter = None
32 | self.gaussians = gaussians
33 |
34 | if load_iteration:
35 | if load_iteration == -1:
36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
37 | else:
38 | self.loaded_iter = load_iteration
39 | print("Loading trained model at iteration {}".format(self.loaded_iter))
40 |
41 | self.train_cameras = {}
42 | self.test_cameras = {}
43 |
44 | if os.path.exists(os.path.join(args.source_path, "sparse")) or os.path.exists(os.path.join(args.source_path, "colmap_sparse")):
45 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
47 | print("Found transforms_train.json file, assuming Blender data set!")
48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
49 | elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")):
50 | print("Found cameras_sphere.npz file, assuming DTU data set!")
51 | scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz")
52 | elif os.path.exists(os.path.join(args.source_path, "dataset.json")):
53 | print("Found dataset.json file, assuming Nerfies data set!")
54 | scene_info = sceneLoadTypeCallbacks["nerfies"](args.source_path, args.eval)
55 | elif os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")):
56 |             print("Found poses_bounds.npy file, assuming Neu3D data set!")
57 | scene_info = sceneLoadTypeCallbacks["plenopticVideo"](args.source_path, args.eval, 24)
58 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")):
59 |             print("Found transforms.json file, assuming Dynamic-360 data set!")
60 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.images, args.eval)
61 | elif os.path.exists(os.path.join(args.source_path, "train_meta.json")):
62 | print("Found train_meta.json, assuming CMU data set!")
63 | scene_info = sceneLoadTypeCallbacks["CMU"](args.source_path)
64 | else:
65 | assert False, "Could not recognize scene type!"
66 |
67 | if not self.loaded_iter:
68 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
69 | dest_file.write(src_file.read())
70 | json_cams = []
71 | camlist = []
72 | if scene_info.test_cameras:
73 | camlist.extend(scene_info.test_cameras)
74 | if scene_info.train_cameras:
75 | camlist.extend(scene_info.train_cameras)
76 | for id, cam in enumerate(camlist):
77 | json_cams.append(camera_to_JSON(id, cam))
78 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
79 | json.dump(json_cams, file)
80 |
81 | # Read flow data
82 | self.flow_dir = os.path.join(args.source_path, "raft_neighbouring")
83 | flow_list = os.listdir(self.flow_dir) if os.path.exists(self.flow_dir) else []
84 | flow_dirs_list = []
85 | for cam in scene_info.train_cameras:
86 | flow_dirs_list.append([os.path.join(self.flow_dir, flow_dir) for flow_dir in flow_list if flow_dir.startswith(cam.image_name+'.')])
87 |
88 | # if shuffle:
89 | # random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
90 | # random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
91 |
92 | self.cameras_extent = scene_info.nerf_normalization["radius"]
93 |
94 | for resolution_scale in resolution_scales:
95 | print("Loading Training Cameras")
96 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, flow_dirs_list)
97 | print("Loading Test Cameras")
98 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
99 |
100 | if self.loaded_iter:
101 | self.gaussians.load_ply(os.path.join(self.model_path,
102 | "point_cloud",
103 | "iteration_" + str(self.loaded_iter),
104 | "point_cloud.ply"),
105 | og_number_points=len(scene_info.point_cloud.points))
106 | else:
107 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
108 |
109 | def save(self, iteration):
110 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
111 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
112 |
113 | def getTrainCameras(self, scale=1.0):
114 | return self.train_cameras[scale]
115 |
116 | def getTestCameras(self, scale=1.0):
117 | return self.test_cameras[scale]
118 |
--------------------------------------------------------------------------------
/scene/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/cameras.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/cameras.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/colmap_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/colmap_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/dataset_readers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/dataset_readers.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/deform_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/deform_model.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/__pycache__/gaussian_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/gaussian_model.cpython-38.pyc
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 |
18 | class Camera(nn.Module):
19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", fid=None, depth=None, flow_dirs=[]):
20 | super(Camera, self).__init__()
21 |
22 | self.uid = uid
23 | self.colmap_id = colmap_id
24 | self.R = R
25 | self.T = T
26 | self.FoVx = FoVx
27 | self.FoVy = FoVy
28 | self.image_name = image_name
29 | self.flow_dirs = flow_dirs
30 |
31 | try:
32 | self.data_device = torch.device(data_device)
33 | except Exception as e:
34 | print(e)
35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
36 | self.data_device = torch.device("cuda")
37 |
38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
39 | self.fid = torch.Tensor(np.array([fid])).to(self.data_device)
40 | self.image_width = self.original_image.shape[2]
41 | self.image_height = self.original_image.shape[1]
42 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None
43 | self.gt_alpha_mask = gt_alpha_mask
44 |
45 | if gt_alpha_mask is not None:
46 | self.gt_alpha_mask = self.gt_alpha_mask.to(self.data_device)
47 | # self.original_image *= gt_alpha_mask.to(self.data_device)
48 |
49 | self.zfar = 100.0
50 | self.znear = 0.01
51 |
52 | self.trans = trans
53 | self.scale = scale
54 |
55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to(self.data_device)
56 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).to(self.data_device)
57 | self.full_proj_transform = (
58 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
59 | self.camera_center = self.world_view_transform.inverse()[3, :3]
60 |
61 | def reset_extrinsic(self, R, T):
62 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).cuda()
63 | self.full_proj_transform = (
64 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
65 | self.camera_center = self.world_view_transform.inverse()[3, :3]
66 |
67 | def load2device(self, data_device='cuda'):
68 | self.original_image = self.original_image.to(data_device)
69 | self.world_view_transform = self.world_view_transform.to(data_device)
70 | self.projection_matrix = self.projection_matrix.to(data_device)
71 | self.full_proj_transform = self.full_proj_transform.to(data_device)
72 | self.camera_center = self.camera_center.to(data_device)
73 | self.fid = self.fid.to(data_device)
74 |
75 |
76 | class MiniCam:
77 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
78 | self.image_width = width
79 | self.image_height = height
80 | self.FoVy = fovy
81 | self.FoVx = fovx
82 | self.znear = znear
83 | self.zfar = zfar
84 | self.world_view_transform = world_view_transform
85 | self.full_proj_transform = full_proj_transform
86 | view_inv = torch.inverse(self.world_view_transform)
87 | self.camera_center = view_inv[3][:3]
88 |
89 | def reset_extrinsic(self, R, T):
90 | self.world_view_transform = torch.tensor(getWorld2View2(R, T)).transpose(0, 1).cuda()
91 | self.full_proj_transform = (
92 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
93 | self.camera_center = self.world_view_transform.inverse()[3, :3]
94 |
--------------------------------------------------------------------------------
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]])
54 |
55 |
56 | def rotmat2qvec(R):
57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
58 | K = np.array([
59 | [Rxx - Ryy - Rzz, 0, 0, 0],
60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
63 | eigvals, eigvecs = np.linalg.eigh(K)
64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
65 | if qvec[0] < 0:
66 | qvec *= -1
67 | return qvec
68 |
69 |
70 | class Image(BaseImage):
71 | def qvec2rotmat(self):
72 | return qvec2rotmat(self.qvec)
73 |
74 |
75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
76 | """Read and unpack the next bytes from a binary file.
77 | :param fid:
78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
80 | :param endian_character: Any of {@, =, <, >, !}
81 | :return: Tuple of read and unpacked values.
82 | """
83 | data = fid.read(num_bytes)
84 | return struct.unpack(endian_character + format_char_sequence, data)
85 |
86 |
87 | def read_points3D_text(path):
88 | """
89 | see: src/base/reconstruction.cc
90 | void Reconstruction::ReadPoints3DText(const std::string& path)
91 | void Reconstruction::WritePoints3DText(const std::string& path)
92 | """
93 | xyzs = None
94 | rgbs = None
95 | errors = None
96 | with open(path, "r") as fid:
97 | while True:
98 | line = fid.readline()
99 | if not line:
100 | break
101 | line = line.strip()
102 | if len(line) > 0 and line[0] != "#":
103 | elems = line.split()
104 | xyz = np.array(tuple(map(float, elems[1:4])))
105 | rgb = np.array(tuple(map(int, elems[4:7])))
106 | error = np.array(float(elems[7]))
107 | if xyzs is None:
108 | xyzs = xyz[None, ...]
109 | rgbs = rgb[None, ...]
110 | errors = error[None, ...]
111 | else:
112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0)
113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0)
114 | errors = np.append(errors, error[None, ...], axis=0)
115 | return xyzs, rgbs, errors
116 |
117 |
118 | def read_points3D_binary(path_to_model_file):
119 | """
120 | see: src/base/reconstruction.cc
121 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
122 | void Reconstruction::WritePoints3DBinary(const std::string& path)
123 | """
124 |
125 | with open(path_to_model_file, "rb") as fid:
126 | num_points = read_next_bytes(fid, 8, "Q")[0]
127 |
128 | xyzs = np.empty((num_points, 3))
129 | rgbs = np.empty((num_points, 3))
130 | errors = np.empty((num_points, 1))
131 |
132 | for p_id in range(num_points):
133 | binary_point_line_properties = read_next_bytes(
134 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
135 | xyz = np.array(binary_point_line_properties[1:4])
136 | rgb = np.array(binary_point_line_properties[4:7])
137 | error = np.array(binary_point_line_properties[7])
138 | track_length = read_next_bytes(
139 | fid, num_bytes=8, format_char_sequence="Q")[0]
140 | track_elems = read_next_bytes(
141 | fid, num_bytes=8 * track_length,
142 | format_char_sequence="ii" * track_length)
143 | xyzs[p_id] = xyz
144 | rgbs[p_id] = rgb
145 | errors[p_id] = error
146 | return xyzs, rgbs, errors
147 |
148 |
149 | def read_intrinsics_text(path):
150 | """
151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
152 | """
153 | cameras = {}
154 | with open(path, "r") as fid:
155 | while True:
156 | line = fid.readline()
157 | if not line:
158 | break
159 | line = line.strip()
160 | if len(line) > 0 and line[0] != "#":
161 | elems = line.split()
162 | camera_id = int(elems[0])
163 | model = elems[1]
164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
165 | width = int(elems[2])
166 | height = int(elems[3])
167 | params = np.array(tuple(map(float, elems[4:])))
168 | cameras[camera_id] = Camera(id=camera_id, model=model,
169 | width=width, height=height,
170 | params=params)
171 | return cameras
172 |
173 |
174 | def read_extrinsics_binary(path_to_model_file):
175 | """
176 | see: src/base/reconstruction.cc
177 | void Reconstruction::ReadImagesBinary(const std::string& path)
178 | void Reconstruction::WriteImagesBinary(const std::string& path)
179 | """
180 | images = {}
181 | with open(path_to_model_file, "rb") as fid:
182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
183 | for _ in range(num_reg_images):
184 | binary_image_properties = read_next_bytes(
185 | fid, num_bytes=64, format_char_sequence="idddddddi")
186 | image_id = binary_image_properties[0]
187 | qvec = np.array(binary_image_properties[1:5])
188 | tvec = np.array(binary_image_properties[5:8])
189 | camera_id = binary_image_properties[8]
190 | image_name = ""
191 | current_char = read_next_bytes(fid, 1, "c")[0]
192 | while current_char != b"\x00": # look for the ASCII 0 entry
193 | image_name += current_char.decode("utf-8")
194 | current_char = read_next_bytes(fid, 1, "c")[0]
195 | num_points2D = read_next_bytes(fid, num_bytes=8,
196 | format_char_sequence="Q")[0]
197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
198 | format_char_sequence="ddq" * num_points2D)
199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
200 | tuple(map(float, x_y_id_s[1::3]))])
201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
202 | images[image_id] = Image(
203 | id=image_id, qvec=qvec, tvec=tvec,
204 | camera_id=camera_id, name=image_name,
205 | xys=xys, point3D_ids=point3D_ids)
206 | return images
207 |
208 |
209 | def read_intrinsics_binary(path_to_model_file):
210 | """
211 | see: src/base/reconstruction.cc
212 | void Reconstruction::WriteCamerasBinary(const std::string& path)
213 | void Reconstruction::ReadCamerasBinary(const std::string& path)
214 | """
215 | cameras = {}
216 | with open(path_to_model_file, "rb") as fid:
217 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
218 | for _ in range(num_cameras):
219 | camera_properties = read_next_bytes(
220 | fid, num_bytes=24, format_char_sequence="iiQQ")
221 | camera_id = camera_properties[0]
222 | model_id = camera_properties[1]
223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
224 | width = camera_properties[2]
225 | height = camera_properties[3]
226 | num_params = CAMERA_MODEL_IDS[model_id].num_params
227 | params = read_next_bytes(fid, num_bytes=8 * num_params,
228 | format_char_sequence="d" * num_params)
229 | cameras[camera_id] = Camera(id=camera_id,
230 | model=model_name,
231 | width=width,
232 | height=height,
233 | params=np.array(params))
234 | assert len(cameras) == num_cameras
235 | return cameras
236 |
237 |
238 | def read_extrinsics_text(path):
239 | """
240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
241 | """
242 | images = {}
243 | with open(path, "r") as fid:
244 | while True:
245 | line = fid.readline()
246 | if not line:
247 | break
248 | line = line.strip()
249 | if len(line) > 0 and line[0] != "#":
250 | elems = line.split()
251 | image_id = int(elems[0])
252 | qvec = np.array(tuple(map(float, elems[1:5])))
253 | tvec = np.array(tuple(map(float, elems[5:8])))
254 | camera_id = int(elems[8])
255 | image_name = elems[9]
256 | elems = fid.readline().split()
257 | xys = np.column_stack([tuple(map(float, elems[0::3])),
258 | tuple(map(float, elems[1::3]))])
259 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
260 | images[image_id] = Image(
261 | id=image_id, qvec=qvec, tvec=tvec,
262 | camera_id=camera_id, name=image_name,
263 | xys=xys, point3D_ids=point3D_ids)
264 | return images
265 |
266 |
267 | def read_colmap_bin_array(path):
268 | """
269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
270 |
271 | :param path: path to the colmap binary file.
272 | :return: nd array with the floating point values in the value
273 | """
274 | with open(path, "rb") as fid:
275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
276 | usecols=(0, 1, 2), dtype=int)
277 | fid.seek(0)
278 | num_delimiter = 0
279 | byte = fid.read(1)
280 | while True:
281 | if byte == b"&":
282 | num_delimiter += 1
283 | if num_delimiter >= 3:
284 | break
285 | byte = fid.read(1)
286 | array = np.fromfile(fid, np.float32)
287 | array = array.reshape((width, height, channels), order="F")
288 | return np.transpose(array, (1, 0, 2)).squeeze()
289 |
--------------------------------------------------------------------------------
/scene/deform_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.time_utils import DeformNetwork, ControlNodeWarp, StaticNetwork
5 | import os
6 | from utils.system_utils import searchForMaxIteration
7 | from utils.general_utils import get_expon_lr_func
8 |
9 |
10 | model_dict = {'mlp': DeformNetwork, 'node': ControlNodeWarp, 'static': StaticNetwork}
11 |
12 |
13 | class DeformModel:
14 | def __init__(self, deform_type='node', is_blender=False, d_rot_as_res=True, **kwargs):
15 | self.deform = model_dict[deform_type](is_blender=is_blender, d_rot_as_res=d_rot_as_res, **kwargs).cuda()
16 | self.name = self.deform.name
17 | self.optimizer = None
18 | self.spatial_lr_scale = 5
19 | self.d_rot_as_res = d_rot_as_res
20 |
21 | @property
22 | def reg_loss(self):
23 | return self.deform.reg_loss
24 |
25 | def step(self, xyz, time_emb, iteration=0, **kwargs):
26 | return self.deform(xyz, time_emb, iteration=iteration, **kwargs)
27 |
28 | def train_setting(self, training_args):
29 | l = [
30 | {'params': group['params'],
31 | 'lr': training_args.position_lr_init * self.spatial_lr_scale * training_args.deform_lr_scale,
32 | "name": group['name']}
33 | for group in self.deform.trainable_parameters()
34 | ]
35 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
36 |
37 | self.deform_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale * training_args.deform_lr_scale, lr_final=training_args.position_lr_final * training_args.deform_lr_scale, lr_delay_mult=training_args.position_lr_delay_mult, max_steps=training_args.deform_lr_max_steps)
38 | if self.name == 'node':
39 | self.deform.as_gaussians.training_setup(training_args)
40 |
41 | def save_weights(self, model_path, iteration):
42 | out_weights_path = os.path.join(model_path, "deform/iteration_{}".format(iteration))
43 | os.makedirs(out_weights_path, exist_ok=True)
44 | torch.save(self.deform.state_dict(), os.path.join(out_weights_path, 'deform.pth'))
45 |
46 | def load_weights(self, model_path, iteration=-1):
47 | if iteration == -1:
48 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "deform"))
49 | else:
50 | loaded_iter = iteration
51 | weights_path = os.path.join(model_path, "deform/iteration_{}/deform.pth".format(loaded_iter))
52 | if os.path.exists(weights_path):
53 | self.deform.load_state_dict(torch.load(weights_path))
54 | return True
55 | else:
56 | return False
57 |
58 | def update_learning_rate(self, iteration):
59 | for param_group in self.optimizer.param_groups:
60 | if param_group["name"] == "deform" or param_group["name"] == "mlp" or 'node' in param_group['name']:
61 | lr = self.deform_scheduler_args(iteration)
62 | param_group['lr'] = lr
63 | return lr
64 |
65 | def densify(self, max_grad, x, x_grad, **kwargs):
66 | if self.name == 'node':
67 | self.deform.densify(max_grad=max_grad, optimizer=self.optimizer, x=x, x_grad=x_grad, **kwargs)
68 | else:
69 | return
70 |
71 | def update(self, iteration):
72 | self.deform.update(iteration)
73 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from scene import Scene
15 | import uuid
16 | from utils.image_utils import psnr, lpips, alex_lpips
17 | from utils.image_utils import ssim as ssim_func
18 | from piq import LPIPS
19 | lpips = LPIPS()
20 | from argparse import Namespace
21 | from pytorch_msssim import ms_ssim
22 |
23 | try:
24 | from torch.utils.tensorboard import SummaryWriter
25 |
26 | TENSORBOARD_FOUND = True
27 | except ImportError:
28 | TENSORBOARD_FOUND = False
29 |
30 |
31 | def prepare_output_and_logger(args):
32 | if not args.model_path:
33 | if os.getenv('OAR_JOB_ID'):
34 | unique_str = os.getenv('OAR_JOB_ID')
35 | else:
36 | unique_str = str(uuid.uuid4())
37 | args.model_path = os.path.join("./output/", unique_str[0:10])
38 |
39 | # Set up output folder
40 | print("Output folder: {}".format(args.model_path))
41 | os.makedirs(args.model_path, exist_ok=True)
42 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
43 | cfg_log_f.write(str(Namespace(**vars(args))))
44 |
45 | # Create Tensorboard writer
46 | tb_writer = None
47 | if TENSORBOARD_FOUND:
48 | tb_writer = SummaryWriter(args.model_path)
49 | else:
50 | print("Tensorboard not available: not logging progress")
51 | return tb_writer
52 |
53 |
54 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, renderArgs, deform, load2gpu_on_the_fly, progress_bar=None):
55 | if tb_writer:
56 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
57 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
58 | tb_writer.add_scalar('iter_time', elapsed, iteration)
59 |
60 | test_psnr = 0.0
61 | test_ssim = 0.0
62 | test_lpips = 1e10
63 | test_ms_ssim = 0.0
64 | test_alex_lpips = 1e10
65 | # Report test and samples of training set
66 | if iteration in testing_iterations:
67 | torch.cuda.empty_cache()
68 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
69 | {'name': 'train',
70 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
71 | for config in validation_configs:
72 | if config['cameras'] and len(config['cameras']) > 0:
73 | # images = torch.tensor([], device="cuda")
74 | # gts = torch.tensor([], device="cuda")
75 | psnr_list, ssim_list, lpips_list, l1_list = [], [], [], []
76 | ms_ssim_list, alex_lpips_list = [], []
77 | for idx, viewpoint in enumerate(config['cameras']):
78 | if load2gpu_on_the_fly:
79 | viewpoint.load2device()
80 | fid = viewpoint.fid
81 | xyz = scene.gaussians.get_xyz
82 |
83 | if deform.name == 'mlp':
84 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
85 | elif deform.name == 'node':
86 | time_input = deform.deform.expand_time(fid)
87 | else:
88 | time_input = 0
89 |
90 | d_values = deform.step(xyz.detach(), time_input, feature=scene.gaussians.feature, is_training=False, motion_mask=scene.gaussians.motion_mask, camera_center=viewpoint.camera_center)
91 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
92 |
93 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res)["render"], 0.0, 1.0)
94 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
95 |
96 | l1_list.append(l1_loss(image[None], gt_image[None]).mean())
97 | psnr_list.append(psnr(image[None], gt_image[None]).mean())
98 | ssim_list.append(ssim_func(image[None], gt_image[None], data_range=1.).mean())
99 | lpips_list.append(lpips(image[None], gt_image[None]).mean())
100 | ms_ssim_list.append(ms_ssim(image[None], gt_image[None], data_range=1.).mean())
101 | alex_lpips_list.append(alex_lpips(image[None], gt_image[None]).mean())
102 |
103 | # images = torch.cat((images, image.unsqueeze(0)), dim=0)
104 | # gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
105 |
106 | if load2gpu_on_the_fly:
107 | viewpoint.load2device('cpu')
108 | if tb_writer and (idx < 5):
109 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
110 | if iteration == testing_iterations[0]:
111 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
112 |
113 | l1_test = torch.stack(l1_list).mean()
114 | psnr_test = torch.stack(psnr_list).mean()
115 | ssim_test = torch.stack(ssim_list).mean()
116 | lpips_test = torch.stack(lpips_list).mean()
117 | ms_ssim_test = torch.stack(ms_ssim_list).mean()
118 | alex_lpips_test = torch.stack(alex_lpips_list).mean()
119 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0:
120 | test_psnr = psnr_test
121 | test_ssim = ssim_test
122 | test_lpips = lpips_test
123 | test_ms_ssim = ms_ssim_test
124 | test_alex_lpips = alex_lpips_test
125 | if progress_bar is None:
126 |                     print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {} MS SSIM {} ALEX_LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test))
127 | else:
128 | progress_bar.set_description("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {} MS SSIM {} ALEX_LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test))
129 | if tb_writer:
130 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
131 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
132 |                 tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration)
133 |                 tb_writer.add_scalar(config['name'] + '/loss_viewpoint - lpips', lpips_test, iteration)
134 |                 tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ms-ssim', ms_ssim_test, iteration)
135 |                 tb_writer.add_scalar(config['name'] + '/loss_viewpoint - alex-lpips', alex_lpips_test, iteration)
136 |
137 | if tb_writer:
138 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
139 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
140 | torch.cuda.empty_cache()
141 |
142 | return test_psnr, test_ssim, test_lpips, test_ms_ssim, test_alex_lpips
143 |
144 |
--------------------------------------------------------------------------------
/train_gui.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --is_blender --eval --gui --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800
--------------------------------------------------------------------------------
/train_gui_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class DeformKeypoints:
6 | def __init__(self) -> None:
7 | self.keypoints3d_list = [] # list of keypoints group
8 | self.keypoints_idx_list = [] # keypoints index
9 | self.keypoints3d_delta_list = []
10 | self.selective_keypoints_idx_list = [] # keypoints index
11 | self.idx2group = {}
12 |
13 | self.selective_rotation_keypoints_idx_list = []
14 | # self.rotation_idx2group = {}
15 |
16 | def get_kpt_idx(self,):
17 | return self.keypoints_idx_list
18 |
19 | def get_kpt(self,):
20 | return self.keypoints3d_list
21 |
22 | def get_kpt_delta(self,):
23 | return self.keypoints3d_delta_list
24 |
25 | def get_deformed_kpt_np(self, rate=1.):
26 | return np.array(self.keypoints3d_list) + np.array(self.keypoints3d_delta_list) * rate
27 |
28 | def add_kpts(self, keypoints_coord, keypoints_idx, expand=False):
29 | # keypoints3d: [N, 3], keypoints_idx: [N,], torch.tensor
30 | # self.selective_keypoints_idx_list.clear()
31 | selective_keypoints_idx_list = [] if not expand else self.selective_keypoints_idx_list
32 | for idx in range(len(keypoints_idx)):
33 | if not self.contain_kpt(keypoints_idx[idx].item()):
34 | selective_keypoints_idx_list.append(len(self.keypoints_idx_list))
35 | self.keypoints_idx_list.append(keypoints_idx[idx].item())
36 | self.keypoints3d_list.append(keypoints_coord[idx].cpu().numpy())
37 | self.keypoints3d_delta_list.append(np.zeros_like(self.keypoints3d_list[-1]))
38 |
39 | for kpt_idx in keypoints_idx:
40 | self.idx2group[kpt_idx.item()] = selective_keypoints_idx_list
41 |
42 | self.selective_keypoints_idx_list = selective_keypoints_idx_list
43 |
44 | def contain_kpt(self, idx):
45 | # idx: int
46 | if idx in self.keypoints_idx_list:
47 | return True
48 | else:
49 | return False
50 |
51 | def select_kpt(self, idx):
52 | # idx: int
53 | # output: idx list of this group
54 | if idx in self.keypoints_idx_list:
55 | self.selective_keypoints_idx_list = self.idx2group[idx]
56 |
57 | def select_rotation_kpt(self, idx):
58 | if idx in self.keypoints_idx_list:
59 | self.selective_rotation_keypoints_idx_list = self.idx2group[idx]
60 |
61 | def get_rotation_center(self,):
62 | selected_rotation_points = self.get_deformed_kpt_np()[self.selective_rotation_keypoints_idx_list]
63 | return selected_rotation_points.mean(axis=0)
64 |
65 | def get_selective_center(self,):
66 | selected_points = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list]
67 | return selected_points.mean(axis=0)
68 |
69 |     def delete_kpt(self, idx):
70 |         # Remove every keypoint of the current selection group, popping from the back so list positions stay valid
71 |         for list_idx in sorted(self.selective_keypoints_idx_list, reverse=True):
72 |             self.idx2group.pop(self.keypoints_idx_list.pop(list_idx), None)
73 |             self.keypoints3d_delta_list.pop(list_idx)
74 |             self.keypoints3d_list.pop(list_idx)
75 |
76 | def delete_batch_ktps(self, batch_idx):
77 | pass
78 |
79 | def update_delta(self, delta):
80 | # delta: [3,], np.array
81 | for idx in self.selective_keypoints_idx_list:
82 | self.keypoints3d_delta_list[idx] += delta
83 |
84 | def set_delta(self, delta):
85 | # delta: [N, 3], np.array
86 | for id, idx in enumerate(self.selective_keypoints_idx_list):
87 | self.keypoints3d_delta_list[idx] = delta[id]
88 |
89 |
90 | def set_rotation_delta(self, rot_mat):
91 | kpts3d = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list]
92 | kpts3d_mean = kpts3d.mean(axis=0)
93 | kpts3d = (kpts3d - kpts3d_mean) @ rot_mat.T + kpts3d_mean
94 | delta = kpts3d - np.array(self.keypoints3d_list)[self.selective_keypoints_idx_list]
95 | for id, idx in enumerate(self.selective_keypoints_idx_list):
96 | self.keypoints3d_delta_list[idx] = delta[id]
97 |
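
`DeformKeypoints` is the GUI-side bookkeeping for dragged control points: each `add_kpts` call registers a selection group, and translation/rotation edits only accumulate deltas, which are re-applied on top of the original positions when `get_deformed_kpt_np` is queried. A minimal usage sketch (the coordinates, node indices, and the 0.1 drag are made up for illustration):

```python
import numpy as np
import torch
from train_gui_utils import DeformKeypoints

dk = DeformKeypoints()
coords = torch.rand(3, 3)                    # 3 picked control points (world coordinates)
node_idx = torch.tensor([10, 42, 7])         # their control-node indices
dk.add_kpts(coords, node_idx)                # register them as one selection group

dk.update_delta(np.array([0.0, 0.1, 0.0]))   # drag the whole group by +0.1 along y
handles = dk.get_deformed_kpt_np()           # original positions + accumulated deltas
print(handles.shape)                         # (3, 3) array of handle positions
```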
--------------------------------------------------------------------------------
/utils/arap_deform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from utils.deform_utils import cal_laplacian, cal_connectivity_from_points,\
4 | produce_edge_matrix_nfmt, lstsq_with_handles, cal_verts_deg, rigid_align
5 | from utils.other_utils import matrix_to_quaternion
6 | from pytorch3d.ops import ball_query  # needed by cal_L_from_points below
7 |
8 | def cal_L_from_points(points, return_nn_idx=False):
9 | # points: (N, 3)
10 | Nv = len(points)
11 | L = torch.eye(Nv).cuda()
12 |
13 | radius = 0.3 #
14 | K = 10
15 | knn_res = ball_query(points[None], points[None], K=K, radius=radius, return_nn=False)
16 | nn_dist, nn_idx = knn_res.dists[0], knn_res.idx[0] # [Nv, K], [Nv, K]
17 |
18 | for idx, cur_nn_idx in enumerate(nn_idx):
19 | real_cur_nn_idx = cur_nn_idx[cur_nn_idx != -1]
20 | real_cur_nn_idx = real_cur_nn_idx[real_cur_nn_idx != idx]
21 | L[idx, idx] = len(real_cur_nn_idx)
22 | L[idx][real_cur_nn_idx] = -1
23 |
24 | if return_nn_idx:
25 | return L, nn_idx
26 | else:
27 | return L
28 |
29 |
30 | def mask_softmax(x, mask, dim=1):
31 | # x: (N, K), mask: (N, K) 0/1
32 | x = torch.exp(x)
33 | x = x * mask
34 | x = x / x.sum(dim=dim, keepdim=True)
35 | return x
36 |
37 |
38 | class ARAPDeformer:
39 | def __init__(self, verts, K=10, radius=0.3, point_mask=None, trajectory=None, node_radius=None) -> None:
40 | # verts: (N, 3), one_ring_idx: (N, K)
41 | self.device = verts.device
42 | self.verts = verts
43 | self.verts_copy = verts.clone()
44 | self.radius = radius
45 | self.K = K
46 | self.N = len(verts)
47 |
48 | self.ii, self.jj, self.nn, weight = cal_connectivity_from_points(self.verts, self.radius, self.K, trajectory=trajectory, node_radius=node_radius)
49 | self.L = cal_laplacian(Nv=self.N, ii=self.ii, jj=self.jj, nn=self.nn)
50 | # self.L = cal_L_from_points(points=self.verts)
51 |
52 | ##### add learnable deformation weights #####
53 | self.vert_deg = cal_verts_deg(self.N, self.ii)
54 | # weight = torch.ones(self.N, K).float().cuda() # [Nv, K]
55 | # weight[self.ii, self.nn] = -1 / self.vert_deg[self.ii]
56 | self.weight = torch.nn.Parameter(weight, requires_grad=True) # [Nv, K]
57 | self.weight_mask = torch.zeros(self.N, K).float().cuda() # [Nv, K]
58 | self.weight_mask[self.ii, self.nn] = 1
59 |
60 | self.L_opt = torch.eye(self.N).cuda() # replace all the self.L with self.L_opt! s.t. weight is in [0,1], easy to optimize.
61 | self.L_is_degenerate = False
62 | self.cal_L_opt()
63 | self.b = torch.mm(self.L_opt, self.verts) # [Nv, 3]
64 |
65 | self.point_mask = point_mask # [N,]
66 |
67 | def cal_L_opt(self):
68 | self.normalized_weight = self.weight
69 | self.L_opt[self.ii, self.jj] = - self.normalized_weight[self.ii, self.nn] # [Nv, Nv]
70 | self.L_is_degenerate = (torch.linalg.matrix_rank(self.L_opt) < self.L_opt.shape[0])
71 | if self.L_is_degenerate:
72 |             print("L_opt is not invertible; using pseudo-inverse instead")
73 |
74 | def reset(self):
75 | self.verts = self.verts_copy.clone()
76 |
77 | def world_2_local_index(self, handle_idx):
78 | # handle_idx: [m,]
79 | # point mask [N,]
80 | # idx_offset = torch.cat([torch.zeros_like(self.point_mask[:1]), torch.cumsum(self.point_mask, dim=0)])
81 | idx_offset = torch.cumsum(~self.point_mask, dim=0)
82 | handle_idx_offset = idx_offset[handle_idx]
83 | return handle_idx - handle_idx_offset
84 |
85 |
86 | def deform(self, handle_idx, handle_pos, init_verts=None, return_R=False):
87 | # handle_idx: (M, ), handle_pos: (M, 3)
88 |
89 | if self.point_mask is not None:
90 | handle_idx = self.world_2_local_index(handle_idx)
91 |
92 | ##### calculate b #####
93 | ### b_fixed
94 | unknown_verts = [n for n in range(self.N) if n not in handle_idx.tolist()] # all unknown verts
95 | b_fixed = torch.zeros((self.N, 3), device=self.device) # factor to be subtracted from b, due to constraints
96 | for k, pos in zip(handle_idx, handle_pos):
97 | # b_fixed += torch.einsum("i,j->ij", self.L[:, k], pos) # [Nv,3]
98 | b_fixed += torch.einsum("i,j->ij", self.L_opt[:, k], pos) # [Nv,3]
99 |
100 | ### prepare for b_all
101 | P = produce_edge_matrix_nfmt(self.verts, (self.N, self.K, 3), self.ii, self.jj, self.nn, device=self.device) # [Nv, K, 3]
102 | if init_verts is None:
103 | p_prime = lstsq_with_handles(self.L_opt, self.L_opt@self.verts, handle_idx, handle_pos, A_is_degenarate=self.L_is_degenerate)
104 | else:
105 | p_prime = init_verts
106 |
107 | p_prime_seq = [p_prime]
108 | R = torch.eye(3)[None].repeat(self.N, 1,1).cuda() # compute rotations
109 |
110 | NUM_ITER = 3
111 | D = torch.diag_embed(self.normalized_weight, dim1=1, dim2=2) # [Nv, K, K]
112 | for _ in range(NUM_ITER):
113 | P_prime = produce_edge_matrix_nfmt(p_prime, (self.N, self.K, 3), self.ii, self.jj, self.nn, device=self.device) # [Nv, K, 3]
114 | ### Calculate covariance matrix in bulk
115 | S = torch.bmm(P.permute(0, 2, 1), torch.bmm(D, P_prime)) # [Nv, 3, 3]
116 |
117 | ## in the case of no deflection, set S = 0, such that R = I. This is to avoid numerical errors
118 | unchanged_verts = torch.unique(torch.where((P == P_prime).all(dim=1))[0]) # any verts which are undeformed
119 | S[unchanged_verts] = 0
120 |
121 | U, sig, W = torch.svd(S)
122 | R = torch.bmm(W, U.permute(0, 2, 1)) # compute rotations
123 |
124 | # Need to flip the column of U corresponding to smallest singular value
125 | # for any det(Ri) <= 0
126 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0
127 | if len(entries_to_flip) > 0:
128 | Umod = U.clone()
129 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each entry
130 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols
131 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1))
132 |
133 | ### RHS of minimum energy equation
134 | Rsum_shape = (self.N, self.K, 3, 3)
135 | Rsum = torch.zeros(Rsum_shape).to(self.device) # Ri + Rj, as in eq (8)
136 | Rsum[self.ii, self.nn] = R[self.ii] + R[self.jj]
137 |
138 | ### Rsum has shape (V, max_neighbours, 3, 3). P has shape (V, max_neighbours, 3)
139 | ### To batch multiply, collapse first 2 dims into a single batch dim
140 | Rsum_batch, P_batch = Rsum.view(-1, 3, 3), P.view(-1, 3).unsqueeze(-1)
141 |
142 | # RHS of minimum energy equation
143 | b = 0.5 * (torch.bmm(Rsum_batch, P_batch).squeeze(-1).reshape(self.N, self.K, 3) * self.normalized_weight[...,None]).sum(dim=1)
144 |
145 | ### calculate p_prime
146 | p_prime = lstsq_with_handles(self.L_opt, b, handle_idx, handle_pos, A_is_degenarate=self.L_is_degenerate)
147 |
148 | p_prime_seq.append(p_prime)
149 | d_scaling = None
150 |
151 | if return_R:
152 | quat = matrix_to_quaternion(R)
153 | return p_prime, quat, d_scaling
154 | else:
155 | # return p_prime, p_prime_seq
156 | return p_prime
157 |
158 |
159 |
160 | if __name__ == "__main__":
161 | from pytorch3d.io import load_ply
162 | from pytorch3d.ops import ball_query
163 | import pickle
164 | with open("./control_kpt.pkl", "rb") as f:
165 | data = pickle.load(f)
166 |
167 | points = data["pts"]
168 | handle_idx = data["handle_idx"]
169 | handle_pos = data["handle_pos"]
170 |
171 | import trimesh
172 | trimesh.Trimesh(vertices=points).export('deformation_before.ply')
173 |
174 | #### prepare data
175 | points = torch.from_numpy(points).float().cuda()
176 | handle_idx = torch.tensor(handle_idx).long().cuda()
177 | handle_pos = torch.from_numpy(handle_pos).float().cuda()
178 |
179 | deformer = ARAPDeformer(points)
180 |
181 | with torch.no_grad():
182 | points_prime, p_prime_seq = deformer.deform(handle_idx, handle_pos)
183 |
184 | trimesh.Trimesh(vertices=points_prime.cpu().numpy()).export('deformation_after.ply')
185 |
186 | from utils.deform_utils import cal_arap_error
187 | for p_prime in p_prime_seq:
188 | nodes_sequence = torch.cat([points[None], p_prime[None]], dim=0)
189 | arap_error = cal_arap_error(nodes_sequence, deformer.ii, deformer.jj, deformer.nn, K=deformer.K, weight=deformer.normalized_weight)
190 | print(arap_error)
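
For reference, the local–global loop in `ARAPDeformer.deform` (per-node SVD for the rotations, then a `lstsq_with_handles` solve with the handle rows pinned) minimizes the standard as-rigid-as-possible energy. In the usual notation, and assuming the `eq (8)` comments refer to Sorkine and Alexa's formulation, with rest positions p, deformed positions p', per-node rotations R_i, and the normalized edge weights w_ij built by `cal_connectivity_from_points`:

```latex
E(\mathbf{p}') = \sum_{i=1}^{N_v} \sum_{j \in \mathcal{N}(i)} w_{ij}
\left\| (\mathbf{p}'_i - \mathbf{p}'_j) - \mathbf{R}_i (\mathbf{p}_i - \mathbf{p}_j) \right\|^2
```

The local step fits each R_i from the SVD of the weighted covariance S built from the rest and deformed edge matrices; the global step solves the linear system L_opt p' = b in the least-squares sense with the handle positions fixed.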
--------------------------------------------------------------------------------
/utils/bezier.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class BezierCurve:
5 | def __init__(self, points: np.ndarray) -> None:
6 | if points.ndim == 2:
7 | points = points[None]
8 | self.points = points # N, T, D
9 | self.T = points.shape[1]
10 |
11 | def __call__(self, t: float):
12 | assert 0 <= t <= 1, f't: {t} out of range [0, 1]!'
13 | return self.interpolate(t, self.points)
14 |
15 | def interpolate(self, t, points):
16 | if points.shape[1] < 2:
17 | raise ValueError(f"points shape error: {points.shape}")
18 | elif points.shape[1] == 2:
19 | point0, point1 = points[:, 0], points[:, 1]
20 | else:
21 | point0 = self.interpolate(t, points[:, :-1])
22 | point1 = self.interpolate(t, points[:, 1:])
23 | return (1 - t) * point0 + t * point1
24 |
25 |
26 | class PieceWiseLinear:
27 | def __init__(self, points: np.ndarray) -> None:
28 | if points.ndim == 2:
29 | points = points[None]
30 | self.points = points # N, T, D
31 | self.T = points.shape[1]
32 |
33 | def __call__(self, t: float):
34 | assert 0 <= t <= 1, f't: {t} out of range [0, 1]!'
35 | return self.interpolate(t, self.points)
36 |
37 | def interpolate(self, t, points):
38 | if points.shape[1] < 2:
39 | raise ValueError(f"points shape error: {points.shape}")
40 | else:
41 | t_scaled = t * (self.T - 1)
42 | t_floor = min(self.T - 2, max(0, int(np.floor(t_scaled))))
43 | t_ceil = t_floor + 1
44 | point0, point1 = points[:, t_floor], points[:, t_ceil]
45 | return (t_ceil - t_scaled) * point0 + (t_scaled - t_floor) * point1
46 |
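
Both classes accept control points of shape `(N, T, D)` (or `(T, D)` for a single trajectory) and are evaluated at a normalized time `t ∈ [0, 1]`: `BezierCurve` applies the recursive De Casteljau scheme, while `PieceWiseLinear` interpolates linearly between consecutive keyframes (presumably what the trajectory rendering uses to interpolate saved keypoint edits). A small sketch with made-up keyframes:

```python
import numpy as np
from utils.bezier import BezierCurve, PieceWiseLinear

ctrl = np.array([[0.0, 0.0, 0.0],
                 [1.0, 2.0, 0.0],
                 [2.0, 0.0, 0.0]])   # T=3 keyframes in 3D, one trajectory

curve = BezierCurve(ctrl)
print(curve(0.5))                    # De Casteljau evaluation -> [[1. 1. 0.]]

linear = PieceWiseLinear(ctrl)
print(linear(0.25))                  # halfway between keyframes 0 and 1 -> [[0.5 1.  0. ]]
```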
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.general_utils import PILtoTorch, ArrayToTorch
15 | from utils.graphics_utils import fov2focal
16 | import json
17 |
18 | WARNED = False
19 |
20 |
21 | def loadCam(args, id, cam_info, resolution_scale, flow_dirs):
22 | orig_w, orig_h = cam_info.image.size
23 |
24 | if args.resolution in [1, 2, 4, 8]:
25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round(
26 | orig_h / (resolution_scale * args.resolution))
27 | else: # should be a type that converts to float
28 | if args.resolution == -1:
29 | if orig_w > 1600:
30 | global WARNED
31 | if not WARNED:
32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
34 | WARNED = True
35 | global_down = orig_w / 1600
36 | else:
37 | global_down = 1
38 | else:
39 | global_down = orig_w / args.resolution
40 |
41 | scale = float(global_down) * float(resolution_scale)
42 | resolution = (int(orig_w / scale), int(orig_h / scale))
43 |
44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
45 |
46 | gt_image = resized_image_rgb[:3, ...]
47 | loaded_mask = None
48 |
49 | if resized_image_rgb.shape[0] == 4:
50 | loaded_mask = resized_image_rgb[3:4, ...]
51 |
52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
54 | image=gt_image, gt_alpha_mask=loaded_mask,
55 | image_name=cam_info.image_name, uid=id,
56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', fid=cam_info.fid,
57 | depth=cam_info.depth, flow_dirs=flow_dirs)
58 |
59 |
60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args, flow_dirs_list=None):
61 | camera_list = []
62 |
63 | for id, c in enumerate(cam_infos):
64 | camera_list.append(loadCam(args, id, c, resolution_scale, [] if flow_dirs_list is None else flow_dirs_list[id]))
65 |
66 | return camera_list
67 |
68 |
69 | def camera_to_JSON(id, camera: Camera):
70 | Rt = np.zeros((4, 4))
71 | Rt[:3, :3] = camera.R.transpose()
72 | Rt[:3, 3] = camera.T
73 | Rt[3, 3] = 1.0
74 |
75 | W2C = np.linalg.inv(Rt)
76 | pos = W2C[:3, 3]
77 | rot = W2C[:3, :3]
78 | serializable_array_2d = [x.tolist() for x in rot]
79 | camera_entry = {
80 | 'id': id,
81 | 'img_name': camera.image_name,
82 | 'width': camera.width,
83 | 'height': camera.height,
84 | 'position': pos.tolist(),
85 | 'rotation': serializable_array_2d,
86 | 'fy': fov2focal(camera.FovY, camera.height),
87 | 'fx': fov2focal(camera.FovX, camera.width)
88 | }
89 | return camera_entry
90 |
91 |
92 | def camera_nerfies_from_JSON(path, scale):
93 | """Loads a JSON camera into memory."""
94 | with open(path, 'r') as fp:
95 | camera_json = json.load(fp)
96 |
97 | # Fix old camera JSON.
98 | if 'tangential' in camera_json:
99 | camera_json['tangential_distortion'] = camera_json['tangential']
100 |
101 | return dict(
102 | orientation=np.array(camera_json['orientation']),
103 | position=np.array(camera_json['position']),
104 | focal_length=camera_json['focal_length'] * scale,
105 | principal_point=np.array(camera_json['principal_point']) * scale,
106 | skew=camera_json['skew'],
107 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'],
108 | radial_distortion=np.array(camera_json['radial_distortion']),
109 | tangential_distortion=np.array(camera_json['tangential_distortion']),
110 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)),
111 | int(round(camera_json['image_size'][1] * scale)))),
112 | )
113 |
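
`camera_to_JSON` recovers the world-space camera center by stacking `[R^T | T]` into a 4x4 matrix and inverting it. A tiny numpy sketch of the same computation (identity rotation and a made-up translation, just to make the arithmetic visible):

```python
import numpy as np

R = np.eye(3)                    # rotation as stored on the Camera
T = np.array([0.0, 0.0, 4.0])    # translation as stored on the Camera

Rt = np.eye(4)
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = T
pos = np.linalg.inv(Rt)[:3, 3]   # world-space camera center -> [ 0.  0. -4.]
print(pos)
```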
--------------------------------------------------------------------------------
/utils/deform_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch3d.loss.mesh_laplacian_smoothing import cot_laplacian
4 | from pytorch3d.ops import ball_query
5 | from pytorch3d.io import load_ply
6 | # try:
7 | # print('Using speed up torch_batch_svd!')
8 | # from torch_batch_svd import svd
9 | # except:
10 | # print('Use original torch svd!')
11 | svd = torch.svd
12 | import pytorch3d.ops
13 |
14 |
15 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
16 | r, i, j, k = torch.unbind(quaternions, -1)
17 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
18 | o = torch.stack(
19 | (
20 | 1 - two_s * (j * j + k * k),
21 | two_s * (i * j - k * r),
22 | two_s * (i * k + j * r),
23 | two_s * (i * j + k * r),
24 | 1 - two_s * (i * i + k * k),
25 | two_s * (j * k - i * r),
26 | two_s * (i * k - j * r),
27 | two_s * (j * k + i * r),
28 | 1 - two_s * (i * i + j * j),
29 | ),
30 | -1,
31 | )
32 | return o.reshape(quaternions.shape[:-1] + (3, 3))
33 |
34 |
35 | def produce_edge_matrix_nfmt(verts: torch.Tensor, edge_shape, ii, jj, nn, device="cuda") -> torch.Tensor:
36 |     """Given a tensor of vertex positions, p (V x 3), produce a tensor E, where, for neighbour list J,
37 | E_in = p_i - p_(J[n])"""
38 |
39 | E = torch.zeros(edge_shape).to(device)
40 | E[ii, nn] = verts[ii] - verts[jj]
41 |
42 | return E
43 |
44 |
45 | ####################### utils for arap #######################
46 |
47 | def geodesic_distance_floyd(cur_node, K=8):
48 | node_num = cur_node.shape[0]
49 | nn_dist, nn_idx, _ = pytorch3d.ops.knn_points(cur_node[None], cur_node[None], None, None, K=K+1)
50 | nn_dist, nn_idx = nn_dist[0]**.5, nn_idx[0]
51 | dist_mat = torch.inf * torch.ones([node_num, node_num], dtype=torch.float32, device=cur_node.device)
52 | dist_mat.scatter_(dim=1, index=nn_idx, src=nn_dist)
53 | dist_mat = torch.minimum(dist_mat, dist_mat.T)
54 | for i in range(nn_dist.shape[0]):
55 | dist_mat = torch.minimum((dist_mat[:, i, None] + dist_mat[None, i, :]), dist_mat)
56 | return dist_mat
57 |
58 | def cal_connectivity_from_points(points=None, radius=0.1, K=10, trajectory=None, least_edge_num=3, node_radius=None, mode='nn', GraphK=4, adaptive_weighting=True):
59 | # input: [Nv,3]
60 | # output: information of edges
61 |     # ii: [Ne,] index of the i-th vert of each edge
62 |     # jj: [Ne,] the j-th vert is connected to the i-th vert
63 |     # nn: [Ne,] the j-th vert is the n-th neighbour of the i-th vert
64 | Nv = points.shape[0] if points is not None else trajectory.shape[0]
65 | if trajectory is None:
66 | if mode == 'floyd':
67 | dist_mat = geodesic_distance_floyd(points, K=GraphK)
68 | dist_mat = dist_mat ** 2
69 | mask = torch.eye(Nv).bool()
70 | dist_mat[mask] = torch.inf
71 | nn_dist, nn_idx = dist_mat.sort(dim=1)
72 | nn_dist, nn_idx = nn_dist[:, :K], nn_idx[:, :K]
73 | else:
74 | knn_res = pytorch3d.ops.knn_points(points[None], points[None], None, None, K=K+1)
75 | # Remove themselves
76 | nn_dist, nn_idx = knn_res.dists[0, :, 1:], knn_res.idx[0, :, 1:] # [Nv, K], [Nv, K]
77 | else:
78 | trajectory = trajectory.reshape([Nv, -1]) / trajectory.shape[1] # Average distance of trajectory
79 | if mode == 'floyd':
80 | dist_mat = geodesic_distance_floyd(trajectory, K=GraphK)
81 | dist_mat = dist_mat ** 2
82 | mask = torch.eye(Nv).bool()
83 | dist_mat[mask] = torch.inf
84 | nn_dist, nn_idx = dist_mat.sort(dim=1)
85 | nn_dist, nn_idx = nn_dist[:, :K], nn_idx[:, :K]
86 | else:
87 | knn_res = pytorch3d.ops.knn_points(trajectory[None], trajectory[None], None, None, K=K+1)
88 | # Remove themselves
89 | nn_dist, nn_idx = knn_res.dists[0, :, 1:], knn_res.idx[0, :, 1:] # [Nv, K], [Nv, K]
90 |
91 | # Make sure ranges are within the radius
92 | nn_idx[:, least_edge_num:] = torch.where(nn_dist[:, least_edge_num:] < radius ** 2, nn_idx[:, least_edge_num:], - torch.ones_like(nn_idx[:, least_edge_num:]))
93 |
94 | nn_dist[:, least_edge_num:] = torch.where(nn_dist[:, least_edge_num:] < radius ** 2, nn_dist[:, least_edge_num:], torch.ones_like(nn_dist[:, least_edge_num:]) * torch.inf)
95 | if adaptive_weighting:
96 | nn_dist_1d = nn_dist.reshape(-1)
97 | weight = torch.exp(-nn_dist / nn_dist_1d[~torch.isnan(nn_dist_1d) & ~torch.isinf(nn_dist_1d)].mean())
98 | elif node_radius is None:
99 | weight = torch.exp(-nn_dist)
100 | else:
101 | nn_radius = node_radius[nn_idx]
102 | weight = torch.exp(-nn_dist / (2 * nn_radius ** 2))
103 | weight = weight / weight.sum(dim=-1, keepdim=True)
104 |
105 | ii = torch.arange(Nv)[:, None].cuda().long().expand(Nv, K).reshape([-1])
106 | jj = nn_idx.reshape([-1])
107 | nn = torch.arange(K)[None].cuda().long().expand(Nv, K).reshape([-1])
108 | mask = jj != -1
109 | ii, jj, nn = ii[mask], jj[mask], nn[mask]
110 |
111 | return ii, jj, nn, weight
112 |
113 |
114 | def cal_laplacian(Nv, ii, jj, nn):
115 | # input: Nv: int; ii, jj, nn: [Ne,]
116 | # output: laplacian_mat: [Nv, Nv]
117 | laplacian_mat = torch.zeros(Nv, Nv).cuda()
118 | laplacian_mat[ii, jj] = -1
119 | for idx in ii:
120 | laplacian_mat[idx, idx] += 1 # TODO test whether it is correct
121 | return laplacian_mat
122 |
123 | def cal_verts_deg(Nv, ii):
124 | # input: Nv: int; ii, jj, nn: [Ne,]
125 | # output: verts_deg: [Nv,]
126 | verts_deg = torch.zeros(Nv).cuda()
127 | for idx in ii:
128 | verts_deg[idx] += 1
129 | return verts_deg
130 |
131 | def estimate_rotation(source, target, ii, jj, nn, K=10, weight=None, sample_idx=None):
132 | # input: source, target: [Nv, 3]; ii, jj, nn: [Ne,], weight: [Nv, K]
133 | # output: rotation: [Nv, 3, 3]
134 | Nv = len(source)
135 | source_edge_mat = produce_edge_matrix_nfmt(source, (Nv, K, 3), ii, jj, nn) # [Nv, K, 3]
136 | target_edge_mat = produce_edge_matrix_nfmt(target, (Nv, K, 3), ii, jj, nn) # [Nv, K, 3]
137 | if weight is None:
138 | weight = torch.zeros(Nv, K).cuda()
139 | weight[ii, nn] = 1
140 | print("!!! Edge weight is None !!!")
141 | if sample_idx is not None:
142 | source_edge_mat = source_edge_mat[sample_idx]
143 | target_edge_mat = target_edge_mat[sample_idx]
144 | ### Calculate covariance matrix in bulk
145 | D = torch.diag_embed(weight, dim1=1, dim2=2) # [Nv, K, K]
146 | # S = torch.bmm(source_edge_mat.permute(0, 2, 1), target_edge_mat) # [Nv, 3, 3]
147 | S = torch.bmm(source_edge_mat.permute(0, 2, 1), torch.bmm(D, target_edge_mat)) # [Nv, 3, 3]
148 | ## in the case of no deflection, set S = 0, such that R = I. This is to avoid numerical errors
149 | unchanged_verts = torch.unique(torch.where((source_edge_mat == target_edge_mat).all(dim=1))[0]) # any verts which are undeformed
150 | S[unchanged_verts] = 0
151 |
152 | # t2 = time.time()
153 | U, sig, W = svd(S)
154 | R = torch.bmm(W, U.permute(0, 2, 1)) # compute rotations
155 | # t3 = time.time()
156 |
157 | # Need to flip the column of U corresponding to smallest singular value
158 | # for any det(Ri) <= 0
159 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0
160 | if len(entries_to_flip) > 0:
161 | Umod = U.clone()
162 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each entry
163 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols
164 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1))
165 | # t4 = time.time()
166 | # print(f'0-1: {t1-t0}, 1-2: {t2-t1}, 2-3: {t3-t2}, 3-4: {t4-t3}')
167 | return R
168 |
169 | import time
170 | def cal_arap_error(nodes_sequence, ii, jj, nn, K=10, weight=None, sample_num=512):
171 | # input: nodes_sequence: [Nt, Nv, 3]; ii, jj, nn: [Ne,], weight: [Nv, K]
172 | # output: arap error: float
173 | Nt, Nv, _ = nodes_sequence.shape
174 | arap_error = 0
175 | if weight is None:
176 | weight = torch.zeros(Nv, K).cuda()
177 | weight[ii, nn] = 1
178 | source_edge_mat = produce_edge_matrix_nfmt(nodes_sequence[0], (Nv, K, 3), ii, jj, nn) # [Nv, K, 3]
179 | sample_idx = torch.arange(Nv).cuda()
180 | if Nv > sample_num:
181 | sample_idx = torch.from_numpy(np.random.choice(Nv, sample_num)).long().cuda()
182 | else:
183 | source_edge_mat = source_edge_mat[sample_idx]
184 | weight = weight[sample_idx]
185 | for idx in range(1, Nt):
186 | # t1 = time.time()
187 | with torch.no_grad():
188 | rotation = estimate_rotation(nodes_sequence[0], nodes_sequence[idx], ii, jj, nn, K=K, weight=weight, sample_idx=sample_idx) # [Nv, 3, 3]
189 | # Compute energy
190 | target_edge_mat = produce_edge_matrix_nfmt(nodes_sequence[idx], (Nv, K, 3), ii, jj, nn) # [Nv, K, 3]
191 | target_edge_mat = target_edge_mat[sample_idx]
192 | rot_rigid = torch.bmm(rotation, source_edge_mat[sample_idx].permute(0, 2, 1)).permute(0, 2, 1) # [Nv, K, 3]
193 | stretch_vec = target_edge_mat - rot_rigid # stretch vector
194 | stretch_norm = (torch.norm(stretch_vec, dim=2) ** 2) # norm over (x,y,z) space
195 | arap_error += (weight * stretch_norm).sum()
196 | return arap_error
197 |
198 | def cal_L_from_points(points, return_nn_idx=False):
199 | # points: (N, 3)
200 | Nv = len(points)
201 | L = torch.eye(Nv).cuda()
202 | radius = 0.1 #
203 | K = 20
204 | knn_res = ball_query(points[None], points[None], K=K, radius=radius, return_nn=False)
205 | nn_dist, nn_idx = knn_res.dists[0], knn_res.idx[0] # [Nv, K], [Nv, K]
206 | for idx, cur_nn_idx in enumerate(nn_idx):
207 | real_cur_nn_idx = cur_nn_idx[cur_nn_idx != -1]
208 | real_cur_nn_idx = real_cur_nn_idx[real_cur_nn_idx != idx]
209 | L[idx, idx] = len(real_cur_nn_idx)
210 | L[idx][real_cur_nn_idx] = -1
211 | if return_nn_idx:
212 | return L, nn_idx
213 | else:
214 | return L
215 |
216 | def lstsq_with_handles(A, b, handle_idx, handle_pos, A_is_degenarate=False):
217 | b = b - A[:, handle_idx] @ handle_pos
218 | handle_mask = torch.zeros_like(A[:, 0], dtype=bool)
219 | handle_mask[handle_idx] = 1
220 | L = A[:, handle_mask.logical_not()]
221 | if not A_is_degenarate:
222 | x = torch.linalg.lstsq(L, b)[0]
223 | else:
224 | x = torch.linalg.pinv(L) @ b
225 | x_out = torch.zeros_like(b)
226 | x_out[handle_idx] = handle_pos
227 | x_out[handle_mask.logical_not()] = x
228 | return x_out
229 |
230 | def rigid_align(x, y):
231 | x_bar, y_bar = x.mean(0), y.mean(0)
232 | x, y = x - x_bar, y - y_bar
233 | S = x.permute(1, 0) @ y # 3 * 3
234 | U, _, W = svd(S)
235 | R = W @ U.permute(1, 0)
236 | t = y_bar - R @ x_bar
237 | x2y = x @ R.T + t
238 | return x2y, R, t
239 |
240 | def arap_deformation_loss(trajectory, node_radius=None, trajectory_rot=None, K=50, with_rot=True):
241 | init_pcl = trajectory[:, 0]
242 | radius = torch.linalg.norm(init_pcl.max(dim=0).values - init_pcl.min(dim=0).values) / 8
243 | fid = torch.randint(1, trajectory.shape[1], [])
244 | tar_pcl = trajectory[:, fid]
245 |
246 | N = init_pcl.shape[0]
247 | with torch.no_grad():
248 | radius = torch.linalg.norm(init_pcl.max(dim=0).values - init_pcl.min(dim=0).values) / 8
249 | device = init_pcl.device
250 | ii, jj, nn, weight = cal_connectivity_from_points(init_pcl, radius, K, trajectory=trajectory.detach(), node_radius=node_radius, mode='nn')
251 | L_opt = torch.eye(N).cuda()
252 | L_opt[ii, jj] = - weight[ii, nn]
253 |
254 | P = produce_edge_matrix_nfmt(init_pcl, (N, K, 3), ii, jj, nn, device=device)
255 | P_prime = produce_edge_matrix_nfmt(tar_pcl, (N, K, 3), ii, jj, nn, device=device)
256 |
257 | with torch.no_grad():
258 | D = torch.diag_embed(weight, dim1=1, dim2=2)
259 | S = torch.bmm(P.permute(0, 2, 1), torch.bmm(D, P_prime))
260 | U, sig, W = torch.svd(S)
261 | R = torch.bmm(W, U.permute(0, 2, 1))
262 | with torch.no_grad():
263 | # Need to flip the column of U corresponding to smallest singular value
264 | # for any det(Ri) <= 0
265 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0
266 | if len(entries_to_flip) > 0:
267 | Umod = U.clone()
268 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each entry
269 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols
270 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1))
271 | arap_error = (weight[..., None] * (P_prime - torch.einsum('bxy,bky->bkx', R, P))).square().mean(dim=0).sum()
272 |
273 | if with_rot:
274 | init_rot = quaternion_to_matrix(trajectory_rot[:, 0])
275 | tar_rot = quaternion_to_matrix(trajectory_rot[:, fid])
276 | R_rot = torch.bmm(R, init_rot)
277 | rot_error = (R_rot - tar_rot).square().mean(dim=0).sum()
278 | else:
279 | rot_error = 0.
280 |
281 | return arap_error, rot_error * 1e2
282 |
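
Among the helpers above, `lstsq_with_handles` is the constrained solver used by the ARAP deformer: the columns belonging to handles are moved to the right-hand side, only the free unknowns are solved (via `torch.linalg.lstsq`, or the pseudo-inverse when the reduced system is rank-deficient), and the handle rows are written back verbatim. A minimal CPU sketch with made-up values:

```python
import torch
from utils.deform_utils import lstsq_with_handles

A = torch.tensor([[ 2., -1.,  0.,  0.],
                  [-1.,  2., -1.,  0.],
                  [ 0., -1.,  2., -1.],
                  [ 0.,  0., -1.,  2.]])   # small stand-in for the Laplacian L_opt
b = torch.zeros(4, 3)

handle_idx = torch.tensor([0, 3])
handle_pos = torch.tensor([[0., 0., 0.],
                           [3., 0., 0.]])  # pin the first and last node

x = lstsq_with_handles(A, b, handle_idx, handle_pos)
print(x)   # rows 0 and 3 equal the handle positions; the free rows come from the least-squares solve
```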
--------------------------------------------------------------------------------
/utils/dual_quaternion.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
5 | """
6 | Returns torch.sqrt(torch.max(0, x))
7 | but with a zero subgradient where x is 0.
8 | """
9 | ret = torch.zeros_like(x)
10 | positive_mask = x > 0
11 | ret[positive_mask] = torch.sqrt(x[positive_mask])
12 | return ret
13 |
14 |
15 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
16 | """
17 | Convert rotations given as rotation matrices to quaternions.
18 |
19 | Args:
20 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
21 |
22 | Returns:
23 | quaternions with real part first, as tensor of shape (..., 4).
24 | """
25 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
26 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
27 |
28 | batch_dim = matrix.shape[:-2]
29 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
30 | matrix.reshape(batch_dim + (9,)), dim=-1
31 | )
32 |
33 | q_abs = _sqrt_positive_part(
34 | torch.stack(
35 | [
36 | 1.0 + m00 + m11 + m22,
37 | 1.0 + m00 - m11 - m22,
38 | 1.0 - m00 + m11 - m22,
39 | 1.0 - m00 - m11 + m22,
40 | ],
41 | dim=-1,
42 | )
43 | )
44 |
45 | # we produce the desired quaternion multiplied by each of r, i, j, k
46 | quat_by_rijk = torch.stack(
47 | [
48 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
49 | # `int`.
50 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
51 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
52 | # `int`.
53 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
54 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
55 | # `int`.
56 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
57 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
58 | # `int`.
59 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
60 | ],
61 | dim=-2,
62 | )
63 |
64 | # We floor here at 0.1 but the exact level is not important; if q_abs is small,
65 | # the candidate won't be picked.
66 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
67 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
68 |
69 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
70 | # forall i; we pick the best-conditioned one (with the largest denominator)
71 |
72 | return quat_candidates[
73 | torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
74 | ].reshape(batch_dim + (4,))
75 |
76 |
77 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
78 | r, i, j, k = torch.unbind(quaternions, -1)
79 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
80 | o = torch.stack(
81 | (
82 | 1 - two_s * (j * j + k * k),
83 | two_s * (i * j - k * r),
84 | two_s * (i * k + j * r),
85 | two_s * (i * j + k * r),
86 | 1 - two_s * (i * i + k * k),
87 | two_s * (j * k - i * r),
88 | two_s * (i * k - j * r),
89 | two_s * (j * k + i * r),
90 | 1 - two_s * (i * i + j * j),
91 | ),
92 | -1,
93 | )
94 | return o.reshape(quaternions.shape[:-1] + (3, 3))
95 |
96 |
97 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
98 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
99 |
100 |
101 | def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
102 | aw, ax, ay, az = torch.unbind(a, -1)
103 | bw, bx, by, bz = torch.unbind(b, -1)
104 | ow = aw * bw - ax * bx - ay * by - az * bz
105 | ox = aw * bx + ax * bw + ay * bz - az * by
106 | oy = aw * by - ax * bz + ay * bw + az * bx
107 | oz = aw * bz + ax * by - ay * bx + az * bw
108 | return torch.stack((ow, ox, oy, oz), -1)
109 |
110 |
111 | def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
112 | ab = quaternion_raw_multiply(a, b)
113 | return standardize_quaternion(ab)
114 |
115 |
116 | def dualquaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
117 | a_real, b_real = a[..., :4], b[..., :4]
118 | a_imag, b_imag = a[..., 4:], b[..., 4:]
119 | o_real = quaternion_multiply(a_real, b_real)
120 | o_imag = quaternion_multiply(a_imag, b_real) + quaternion_multiply(a_real, b_imag)
121 | o = torch.cat([o_real, o_imag], dim=-1)
122 | return o
123 |
124 |
125 | def conjugation(q):
126 | if q.shape[-1] == 4:
127 | q = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
128 | elif q.shape[-1] == 8:
129 | q = torch.cat([q[..., :1], -q[..., 1:4], q[..., 4:5], -q[..., 5:]], dim=-1)
130 | else:
131 | raise TypeError(f'q should be of [..., 4] or [..., 8] but got {q.shape}!')
132 | return q
133 |
134 |
135 | def QT2DQ(q, t, rot_as_q=True):
136 | if not rot_as_q:
137 | q = matrix_to_quaternion(q)
138 | q = torch.nn.functional.normalize(q)
139 | real = q
140 | t = torch.cat([torch.zeros_like(t[..., :1]), t], dim=-1)
141 | image = quaternion_multiply(t, q) / 2
142 | dq = torch.cat([real, image], dim=-1)
143 | return dq
144 |
145 |
146 | def DQ2QT(dq, rot_as_q=False):
147 | real = dq[..., :4]
148 | imag = dq[..., 4:]
149 | real_norm = real.norm(dim=-1, keepdim=True).clamp(min=1e-8)
150 | real, imag = real / real_norm, imag / real_norm
151 |
152 | w0, x0, y0, z0 = torch.unbind(real, -1)
153 | w1, x1, y1, z1 = torch.unbind(imag, -1)
154 |
155 | t = 2* torch.stack([- w1*x0 + x1*w0 - y1*z0 + z1*y0,
156 | - w1*y0 + x1*z0 + y1*w0 - z1*x0,
157 | - w1*z0 - x1*y0 + y1*x0 + z1*w0], dim=-1)
158 | R = torch.stack([1-2*y0**2-2*z0**2, 2*x0*y0-2*w0*z0, 2*x0*z0+2*w0*y0,
159 | 2*x0*y0+2*w0*z0, 1-2*x0**2-2*z0**2, 2*y0*z0-2*w0*x0,
160 | 2*x0*z0-2*w0*y0, 2*y0*z0+2*w0*x0, 1-2*x0**2-2*y0**2], dim=-1).reshape([*w0.shape, 3, 3])
161 | if rot_as_q:
162 | q = matrix_to_quaternion(R)
163 | return q, t
164 | else:
165 | return R, t
166 |
167 |
168 | def DQBlending(q, t, weights, rot_as_q=True):
169 | '''
170 | Input:
171 | q: [..., k, 4]; t: [..., k, 3]; weights: [..., k]
172 | Output:
173 | q_: [..., 4]; t_: [..., 3]
174 | '''
175 | dq = QT2DQ(q=q, t=t)
176 | dq_avg = (dq * weights[..., None]).sum(dim=-2)
177 | q_, t_ = DQ2QT(dq_avg, rot_as_q=rot_as_q)
178 | return q_, t_
179 |
180 |
181 | def interpolate(q0, t0, q1, t1, weight, rot_as_q=True):
182 | dq0 = QT2DQ(q=q0, t=t0)
183 | dq1 = QT2DQ(q=q1, t=t1)
184 | dq_avg = dq0 * weight + dq1 * (1 - weight)
185 | q, t = DQ2QT(dq=dq_avg, rot_as_q=rot_as_q)
186 | return q, t
187 |
188 |
189 | def transformation_blending(transformations, weights):
190 | Rs, Ts = transformations[:, :3, :3], transformations[:, :3, 3]
191 | qs = matrix_to_quaternion(Rs)
192 | q, T = DQBlending(qs[None], Ts[None], weights)
193 | R = quaternion_to_matrix(q)
194 | transformation = torch.eye(4).to(transformations.device)[None].expand(weights.shape[0], 4, 4).clone()
195 | transformation[:, :3, :3] = R
196 | transformation[:, :3, 3] = T
197 | return transformation
198 |
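
`DQBlending` performs dual-quaternion blending: each neighbouring node's rigid transform (q, t) is lifted to a dual quaternion by `QT2DQ`, the dual quaternions are averaged with the skinning weights, and `DQ2QT` normalizes the result back into a rotation and translation. A self-contained sketch on random transforms (purely illustrative):

```python
import torch
from utils.dual_quaternion import DQBlending

K = 4                                                          # neighbouring control nodes
q = torch.nn.functional.normalize(torch.randn(K, 4), dim=-1)   # unit quaternions, real part first
t = 0.1 * torch.randn(K, 3)                                    # translations
w = torch.softmax(torch.randn(K), dim=0)                       # blend weights summing to 1

q_blend, t_blend = DQBlending(q, t, w)
print(q_blend.shape, t_blend.shape)                            # torch.Size([4]) torch.Size([3])
```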
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 | from PIL import Image
18 |
19 | def inverse_sigmoid(x):
20 | return torch.log(x / (1 - x))
21 |
22 |
23 | def PILtoTorch(pil_image, resolution):
24 | if np.asarray(pil_image).shape[-1] == 4:
25 |         # Process RGB and alpha separately to avoid masking RGB with alpha
26 | rgb = Image.fromarray(np.asarray(pil_image)[..., :3])
27 | a = Image.fromarray(np.asarray(pil_image)[..., 3])
28 | rgb, a = np.asarray(rgb.resize(resolution)), np.asarray(a.resize(resolution))
29 | resized_image = torch.from_numpy(np.concatenate([rgb, a[..., None]], axis=-1)) / 255.0
30 | else:
31 | resized_image_PIL = pil_image.resize(resolution)
32 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
33 | if len(resized_image.shape) == 3:
34 | return resized_image.permute(2, 0, 1)
35 | else:
36 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
37 |
38 |
39 | def ArrayToTorch(array, resolution):
40 | # resized_image = np.resize(array, resolution)
41 | resized_image_torch = torch.from_numpy(array)
42 |
43 | if len(resized_image_torch.shape) == 3:
44 | return resized_image_torch.permute(2, 0, 1)
45 | else:
46 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1)
47 |
48 |
49 | def get_expon_lr_func(
50 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
51 | ):
52 | """
53 | Copied from Plenoxels
54 |
55 | Continuous learning rate decay function. Adapted from JaxNeRF
56 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
57 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
58 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
59 | function of lr_delay_mult, such that the initial learning rate is
60 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
61 | to the normal learning rate when steps>lr_delay_steps.
62 | :param conf: config subtree 'lr' or similar
63 | :param max_steps: int, the number of steps during optimization.
64 | :return HoF which takes step as input
65 | """
66 |
67 | def helper(step):
68 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
69 | # Disable this parameter
70 | return 0.0
71 | if lr_delay_steps > 0:
72 | # A kind of reverse cosine decay.
73 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
74 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
75 | )
76 | else:
77 | delay_rate = 1.0
78 | t = np.clip(step / max_steps, 0, 1)
79 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
80 | return delay_rate * log_lerp
81 |
82 | return helper
83 |
84 |
85 | def get_linear_noise_func(
86 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
87 | ):
88 | """
89 | Copied from Plenoxels
90 |
91 | Continuous learning rate decay function. Adapted from JaxNeRF
92 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
93 |     is linearly interpolated elsewhere.
94 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
95 | function of lr_delay_mult, such that the initial learning rate is
96 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
97 | to the normal learning rate when steps>lr_delay_steps.
98 | :param conf: config subtree 'lr' or similar
99 | :param max_steps: int, the number of steps during optimization.
100 | :return HoF which takes step as input
101 | """
102 |
103 | def helper(step):
104 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
105 | # Disable this parameter
106 | return 0.0
107 | if lr_delay_steps > 0:
108 | # A kind of reverse cosine decay.
109 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
110 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
111 | )
112 | else:
113 | delay_rate = 1.0
114 | t = np.clip(step / max_steps, 0, 1)
115 | log_lerp = lr_init * (1 - t) + lr_final * t
116 | return delay_rate * log_lerp
117 |
118 | return helper
119 |
120 |
121 | def strip_lowerdiag(L):
122 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
123 |
124 | uncertainty[:, 0] = L[:, 0, 0]
125 | uncertainty[:, 1] = L[:, 0, 1]
126 | uncertainty[:, 2] = L[:, 0, 2]
127 | uncertainty[:, 3] = L[:, 1, 1]
128 | uncertainty[:, 4] = L[:, 1, 2]
129 | uncertainty[:, 5] = L[:, 2, 2]
130 | return uncertainty
131 |
132 |
133 | def strip_symmetric(sym):
134 | return strip_lowerdiag(sym)
135 |
136 |
137 | def build_rotation(r):
138 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
139 |
140 | q = r / norm[:, None]
141 |
142 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
143 |
144 | r = q[:, 0]
145 | x = q[:, 1]
146 | y = q[:, 2]
147 | z = q[:, 3]
148 |
149 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
150 | R[:, 0, 1] = 2 * (x * y - r * z)
151 | R[:, 0, 2] = 2 * (x * z + r * y)
152 | R[:, 1, 0] = 2 * (x * y + r * z)
153 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
154 | R[:, 1, 2] = 2 * (y * z - r * x)
155 | R[:, 2, 0] = 2 * (x * z - r * y)
156 | R[:, 2, 1] = 2 * (y * z + r * x)
157 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
158 | return R
159 |
160 |
161 | def build_scaling_rotation(s, r):
162 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
163 | R = build_rotation(r)
164 |
165 | L[:, 0, 0] = s[:, 0]
166 | L[:, 1, 1] = s[:, 1]
167 | L[:, 2, 2] = s[:, 2]
168 |
169 | L = R @ L
170 | return L
171 |
172 |
173 | def build_scaling_rotation_inverse(s, r):
174 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
175 | R = build_rotation(r)
176 |
177 | L[:, 0, 0] = 1 / s[:, 0]
178 | L[:, 1, 1] = 1 / s[:, 1]
179 | L[:, 2, 2] = 1 / s[:, 2]
180 |
181 | L = R.permute(0, 2, 1) @ L
182 | return L
183 |
184 |
185 | def safe_state(silent):
186 | old_f = sys.stdout
187 |
188 | class F:
189 | def __init__(self, silent):
190 | self.silent = silent
191 |
192 | def write(self, x):
193 | if not self.silent:
194 | if x.endswith("\n"):
195 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
196 | else:
197 | old_f.write(x)
198 |
199 | def flush(self):
200 | old_f.flush()
201 |
202 | sys.stdout = F(silent)
203 |
204 | random.seed(0)
205 | np.random.seed(0)
206 | torch.manual_seed(0)
207 | torch.cuda.set_device(torch.device("cuda:0"))
208 |
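
`get_expon_lr_func` returns a closure mapping the iteration count to a learning rate that decays log-linearly (i.e. exponentially) from `lr_init` to `lr_final` over `max_steps`, with an optional warm-up controlled by `lr_delay_steps`/`lr_delay_mult`; `get_linear_noise_func` is the same scaffolding with a plain linear ramp. A quick sketch (the hyper-parameters are made up, not this repo's defaults):

```python
from utils.general_utils import get_expon_lr_func

lr_fn = get_expon_lr_func(lr_init=1e-3, lr_final=1e-5, max_steps=30_000)

print(lr_fn(0))        # 1e-3
print(lr_fn(15_000))   # 1e-4, the geometric midpoint on the log scale
print(lr_fn(30_000))   # 1e-5
```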
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 |
18 | class BasicPointCloud(NamedTuple):
19 | points: np.array
20 | colors: np.array
21 | normals: np.array
22 |
23 |
24 | def geom_transform_points(points, transf_matrix):
25 | P, _ = points.shape
26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
27 | points_hom = torch.cat([points, ones], dim=1)
28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
29 |
30 | denom = points_out[..., 3:] + 0.0000001
31 | return (points_out[..., :3] / denom).squeeze(dim=0)
32 |
33 |
34 | def getWorld2View(R, t):
35 | Rt = np.zeros((4, 4))
36 | Rt[:3, :3] = R.transpose()
37 | Rt[:3, 3] = t
38 | Rt[3, 3] = 1.0
39 | return np.float32(Rt)
40 |
41 |
42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
43 | Rt = np.zeros((4, 4))
44 | Rt[:3, :3] = R.transpose()
45 | Rt[:3, 3] = t
46 | Rt[3, 3] = 1.0
47 |
48 | C2W = np.linalg.inv(Rt)
49 | cam_center = C2W[:3, 3]
50 | cam_center = (cam_center + translate) * scale
51 | C2W[:3, 3] = cam_center
52 | Rt = np.linalg.inv(C2W)
53 | return np.float32(Rt)
54 |
55 |
56 | def getProjectionMatrix(znear, zfar, fovX, fovY):
57 | tanHalfFovY = math.tan((fovY / 2))
58 | tanHalfFovX = math.tan((fovX / 2))
59 |
60 | top = tanHalfFovY * znear
61 | bottom = -top
62 | right = tanHalfFovX * znear
63 | left = -right
64 |
65 | P = torch.zeros(4, 4)
66 |
67 | z_sign = 1.0
68 |
69 | P[0, 0] = 2.0 * znear / (right - left)
70 | P[1, 1] = 2.0 * znear / (top - bottom)
71 | P[0, 2] = (right + left) / (right - left)
72 | P[1, 2] = (top + bottom) / (top - bottom)
73 | P[3, 2] = z_sign
74 | P[2, 2] = z_sign * zfar / (zfar - znear)
75 | P[2, 3] = -(zfar * znear) / (zfar - znear)
76 | return P
77 |
78 |
79 | def fov2focal(fov, pixels):
80 | return pixels / (2 * math.tan(fov / 2))
81 |
82 |
83 | def focal2fov(focal, pixels):
84 | return 2 * math.atan(pixels / (2 * focal))
85 |
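
`fov2focal` and `focal2fov` convert between a focal length in pixels and the full field of view in radians, and are exact inverses of each other. A quick sanity-check sketch:

```python
import math
from utils.graphics_utils import fov2focal, focal2fov

fov = math.radians(60.0)            # 60 degree field of view
f = fov2focal(fov, pixels=800)      # ~692.8 px focal length for an 800 px wide image
assert abs(focal2fov(f, 800) - fov) < 1e-9
```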
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 | # NeRF-DS Alex LPIPS
15 | import lpips as lpips_lib
16 | loss_fn_alex = lpips_lib.LPIPS(net='alex')
17 | loss_fn_alex.net.cuda()
18 | loss_fn_alex.scaling_layer.cuda()
19 | loss_fn_alex.lins.cuda()
20 | def alex_lpips(image1, image2):
21 | image1 = image1 * 2 - 1
22 | image2 = image2 * 2 - 1
23 | lpips = loss_fn_alex(image1, image2)
24 | return lpips
25 |
26 |
27 | def mse(img1, img2):
28 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
29 |
30 |
31 | def psnr(img1, img2):
32 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
33 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
34 |
35 |
36 | from piq import ssim, LPIPS
37 | lpips = LPIPS()
38 |
--------------------------------------------------------------------------------
/utils/interactive_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class DeformKeypoints:
6 | def __init__(self) -> None:
7 | self.keypoints3d_list = [] # list of keypoints group
8 | self.keypoints_idx_list = [] # keypoints index
9 | self.keypoints3d_delta_list = []
10 | self.selective_keypoints_idx_list = [] # keypoints index
11 | self.idx2group = {}
12 |
13 | self.selective_rotation_keypoints_idx_list = []
14 | # self.rotation_idx2group = {}
15 |
16 | def get_kpt_idx(self,):
17 | return self.keypoints_idx_list
18 |
19 | def get_kpt(self,):
20 | return self.keypoints3d_list
21 |
22 | def get_kpt_delta(self,):
23 | return self.keypoints3d_delta_list
24 |
25 | def get_deformed_kpt_np(self, rate=1.):
26 | return np.array(self.keypoints3d_list) + np.array(self.keypoints3d_delta_list) * rate
27 |
28 | def add_kpts(self, keypoints_coord, keypoints_idx, expand=False):
29 | # keypoints3d: [N, 3], keypoints_idx: [N,], torch.tensor
30 | # self.selective_keypoints_idx_list.clear()
31 | selective_keypoints_idx_list = [] if not expand else self.selective_keypoints_idx_list
32 | for idx in range(len(keypoints_idx)):
33 | if not self.contain_kpt(keypoints_idx[idx].item()):
34 | selective_keypoints_idx_list.append(len(self.keypoints_idx_list))
35 | self.keypoints_idx_list.append(keypoints_idx[idx].item())
36 | self.keypoints3d_list.append(keypoints_coord[idx].cpu().numpy())
37 | self.keypoints3d_delta_list.append(np.zeros_like(self.keypoints3d_list[-1]))
38 |
39 | for kpt_idx in keypoints_idx:
40 | self.idx2group[kpt_idx.item()] = selective_keypoints_idx_list
41 |
42 | self.selective_keypoints_idx_list = selective_keypoints_idx_list
43 |
44 | def contain_kpt(self, idx):
45 | # idx: int
46 | if idx in self.keypoints_idx_list:
47 | return True
48 | else:
49 | return False
50 |
51 | def select_kpt(self, idx):
52 | # idx: int
53 | # output: idx list of this group
54 | if idx in self.keypoints_idx_list:
55 | self.selective_keypoints_idx_list = self.idx2group[idx]
56 |
57 | def select_rotation_kpt(self, idx):
58 | if idx in self.keypoints_idx_list:
59 | self.selective_rotation_keypoints_idx_list = self.idx2group[idx]
60 |
61 | def get_rotation_center(self,):
62 | selected_rotation_points = self.get_deformed_kpt_np()[self.selective_rotation_keypoints_idx_list]
63 | return selected_rotation_points.mean(axis=0)
64 |
65 | def get_selective_center(self,):
66 | selected_points = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list]
67 | return selected_points.mean(axis=0)
68 |
69 | def delete_kpt(self, idx):
70 | pass
71 |
72 | def delete_batch_ktps(self, batch_idx):
73 | pass
74 |
75 | def update_delta(self, delta):
76 | # delta: [3,], np.array
77 | for idx in self.selective_keypoints_idx_list:
78 | self.keypoints3d_delta_list[idx] += delta
79 |
80 | def set_delta(self, delta):
81 | # delta: [N, 3], np.array
82 | for id, idx in enumerate(self.selective_keypoints_idx_list):
83 | self.keypoints3d_delta_list[idx] = delta[id]
84 |
85 |
86 | def set_rotation_delta(self, rot_mat):
87 | kpts3d = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list]
88 | kpts3d_mean = self.get_rotation_center()
89 | kpts3d = (kpts3d - kpts3d_mean) @ rot_mat.T + kpts3d_mean
90 | delta = kpts3d - np.array(self.keypoints3d_list)[self.selective_keypoints_idx_list]
91 | for id, idx in enumerate(self.selective_keypoints_idx_list):
92 | self.keypoints3d_delta_list[idx] = delta[id]
93 |
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 |
18 | def l1_loss(network_output, gt):
19 | return torch.abs((network_output - gt)).mean()
20 |
21 |
22 | def kl_divergence(rho, rho_hat):
23 | rho_hat = torch.mean(torch.sigmoid(rho_hat), 0)
24 | rho = torch.tensor([rho] * len(rho_hat)).cuda()
25 | return torch.mean(
26 | rho * torch.log(rho / (rho_hat + 1e-5)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-5)))
27 |
28 |
29 | def l2_loss(network_output, gt):
30 | return ((network_output - gt) ** 2).mean()
31 |
32 |
33 | def gaussian(window_size, sigma):
34 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
35 | return gauss / gauss.sum()
36 |
37 |
38 | def create_window(window_size, channel):
39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
41 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
42 | return window
43 |
44 |
45 | def ssim(img1, img2, window_size=11, size_average=True):
46 | channel = img1.size(-3)
47 | window = create_window(window_size, channel)
48 |
49 | if img1.is_cuda:
50 | window = window.cuda(img1.get_device())
51 | window = window.type_as(img1)
52 |
53 | return _ssim(img1, img2, window, window_size, channel, size_average)
54 |
55 |
56 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
59 |
60 | mu1_sq = mu1.pow(2)
61 | mu2_sq = mu2.pow(2)
62 | mu1_mu2 = mu1 * mu2
63 |
64 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
65 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
66 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
67 |
68 | C1 = 0.01 ** 2
69 | C2 = 0.03 ** 2
70 |
71 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
72 |
73 | if size_average:
74 | return ssim_map.mean()
75 | else:
76 | return ssim_map.mean(1).mean(1).mean(1)
77 |
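As a usage note, here is a minimal sketch of how `l1_loss` and `ssim` are typically combined into a photometric training loss. The weight `lambda_dssim` and the tensor shapes are illustrative assumptions, not values taken from this repository's training configuration:

```python
import torch
from utils.loss_utils import l1_loss, ssim

image = torch.rand(1, 3, 128, 128)   # rendered image batch, values in [0, 1]
gt = torch.rand(1, 3, 128, 128)      # ground-truth frames

lambda_dssim = 0.2                   # assumed weighting; set from the optimization options in practice
loss = (1.0 - lambda_dssim) * l1_loss(image, gt) + lambda_dssim * (1.0 - ssim(image, gt))
```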
--------------------------------------------------------------------------------
/utils/other_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
6 | """
7 | Returns torch.sqrt(torch.max(0, x))
8 | but with a zero subgradient where x is 0.
9 | """
10 | ret = torch.zeros_like(x)
11 | positive_mask = x > 0
12 | ret[positive_mask] = torch.sqrt(x[positive_mask])
13 | return ret
14 |
15 |
16 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
17 | """
18 | Convert rotations given as rotation matrices to quaternions.
19 |
20 | Args:
21 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
22 |
23 | Returns:
24 | quaternions with real part first, as tensor of shape (..., 4).
25 | """
26 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
27 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
28 |
29 | batch_dim = matrix.shape[:-2]
30 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
31 | matrix.reshape(batch_dim + (9,)), dim=-1
32 | )
33 |
34 | q_abs = _sqrt_positive_part(
35 | torch.stack(
36 | [
37 | 1.0 + m00 + m11 + m22,
38 | 1.0 + m00 - m11 - m22,
39 | 1.0 - m00 + m11 - m22,
40 | 1.0 - m00 - m11 + m22,
41 | ],
42 | dim=-1,
43 | )
44 | )
45 |
46 | # we produce the desired quaternion multiplied by each of r, i, j, k
47 | quat_by_rijk = torch.stack(
48 | [
49 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
50 | # `int`.
51 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
52 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
53 | # `int`.
54 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
55 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
56 | # `int`.
57 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
58 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
59 | # `int`.
60 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
61 | ],
62 | dim=-2,
63 | )
64 |
65 | # We floor here at 0.1 but the exact level is not important; if q_abs is small,
66 | # the candidate won't be picked.
67 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
68 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
69 |
70 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
71 | # forall i; we pick the best-conditioned one (with the largest denominator)
72 |
73 | return quat_candidates[
74 | torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
75 | ].reshape(batch_dim + (4,))
76 |
77 |
78 | def depth2normal(depth:torch.Tensor, focal:float=None):
79 | if depth.dim() == 2:
80 | depth = depth[None, None]
81 | elif depth.dim() == 3:
82 | depth = depth.squeeze()[None, None]
83 | if focal is None:
84 | focal = depth.shape[-1] / 2 / np.tan(torch.pi/6)
85 | depth = torch.cat([depth[:, :, :1], depth, depth[:, :, -1:]], dim=2)
86 | depth = torch.cat([depth[..., :1], depth, depth[..., -1:]], dim=3)
87 | kernel = torch.tensor([[[ 0, 0, 0],
88 | [-.5, 0, .5],
89 | [ 0, 0, 0]],
90 | [[ 0, -.5, 0],
91 | [ 0, 0, 0],
92 | [ 0, .5, 0]]], device=depth.device, dtype=depth.dtype)[:, None]
93 | normal = torch.nn.functional.conv2d(depth, kernel, padding='valid')[0].permute(1, 2, 0)
94 | normal = normal / (depth[0, 0, 1:-1, 1:-1, None] + 1e-10) * focal
95 | normal = torch.cat([normal, torch.ones_like(normal[..., :1])], dim=-1)
96 | normal = normal / normal.norm(dim=-1, keepdim=True)
97 | return normal.permute(2, 0, 1)
98 |
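A quick sanity check of `matrix_to_quaternion` (real part first), shown as a minimal sketch:

```python
import torch
from utils.other_utils import matrix_to_quaternion

R = torch.eye(3).unsqueeze(0)   # (1, 3, 3) identity rotation
q = matrix_to_quaternion(R)     # (1, 4) -> tensor([[1., 0., 0., 0.]]), i.e. the identity quaternion
```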
--------------------------------------------------------------------------------
/utils/pickle_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 |
4 | def save_obj(path, obj):
5 |     # Context manager ensures the file is closed even if pickling fails.
6 |     with open(path, 'wb') as file:
7 |         obj_str = pickle.dumps(obj)
8 |         file.write(obj_str)
9 |
10 |
11 | def load_obj(path):
12 |     # Context manager ensures the file is closed even if unpickling fails.
13 |     with open(path, 'rb') as file:
14 |         obj = pickle.loads(file.read())
15 |     return obj
16 |
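Typical round-trip usage; the path and payload below are placeholders:

```python
from utils.pickle_utils import save_obj, load_obj

save_obj('/tmp/edit_state.pkl', {'keypoints': [0, 1, 2]})  # any picklable object
restored = load_obj('/tmp/edit_state.pkl')
assert restored['keypoints'] == [0, 1, 2]
```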
--------------------------------------------------------------------------------
/utils/pose_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.graphics_utils import fov2focal
4 |
5 | trans_t = lambda t: torch.Tensor([
6 | [1, 0, 0, 0],
7 | [0, 1, 0, 0],
8 | [0, 0, 1, t],
9 | [0, 0, 0, 1]]).float()
10 |
11 | rot_phi = lambda phi: torch.Tensor([
12 | [1, 0, 0, 0],
13 | [0, np.cos(phi), -np.sin(phi), 0],
14 | [0, np.sin(phi), np.cos(phi), 0],
15 | [0, 0, 0, 1]]).float()
16 |
17 | rot_theta = lambda th: torch.Tensor([
18 | [np.cos(th), 0, -np.sin(th), 0],
19 | [0, 1, 0, 0],
20 | [np.sin(th), 0, np.cos(th), 0],
21 | [0, 0, 0, 1]]).float()
22 |
23 |
24 | def rodrigues_mat_to_rot(R):
25 | eps = 1e-16
26 | trc = np.trace(R)
27 | trc2 = (trc - 1.) / 2.
28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2)
29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
30 | if (1 - trc2 * trc2) >= eps:
31 | tHeta = np.arccos(trc2)
32 | tHetaf = tHeta / (2 * (np.sin(tHeta)))
33 | else:
34 | tHeta = np.real(np.arccos(trc2))
35 | tHetaf = 0.5 / (1 - tHeta / 6)
36 | omega = tHetaf * s
37 | return omega
38 |
39 |
40 | def rodrigues_rot_to_mat(r):
41 | wx, wy, wz = r
42 |     theta = max(np.sqrt(wx * wx + wy * wy + wz * wz), 1e-16)  # guard against division by zero for a zero rotation
43 | a = np.cos(theta)
44 | b = (1 - np.cos(theta)) / (theta * theta)
45 | c = np.sin(theta) / theta
46 | R = np.zeros([3, 3])
47 | R[0, 0] = a + b * (wx * wx)
48 | R[0, 1] = b * wx * wy - c * wz
49 | R[0, 2] = b * wx * wz + c * wy
50 | R[1, 0] = b * wx * wy + c * wz
51 | R[1, 1] = a + b * (wy * wy)
52 | R[1, 2] = b * wy * wz - c * wx
53 | R[2, 0] = b * wx * wz - c * wy
54 | R[2, 1] = b * wz * wy + c * wx
55 | R[2, 2] = a + b * (wz * wz)
56 | return R
57 |
58 |
59 | def normalize(x):
60 | return x / np.linalg.norm(x)
61 |
62 |
63 | def pose_spherical(theta, phi, radius):
64 | c2w = trans_t(radius)
65 | c2w = rot_phi(phi / 180. * np.pi) @ c2w
66 | c2w = rot_theta(theta / 180. * np.pi) @ c2w
67 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
68 | return c2w
69 |
70 | def viewmatrix(z, up, pos):
71 | vec2 = normalize(z)
72 | vec1_avg = up
73 | vec0 = normalize(np.cross(vec1_avg, vec2))
74 | vec1 = normalize(np.cross(vec2, vec0))
75 | m = np.stack([vec0, vec1, vec2, pos], 1)
76 | return m
77 |
78 | def poses_avg(poses):
79 | center = poses[:, :3, 3].mean(0)
80 | vec2 = normalize(poses[:, :3, 2].sum(0))
81 | up = poses[:, :3, 1].sum(0)
82 | c2w = viewmatrix(vec2, up, center)
83 | return c2w
84 |
85 | def render_path_spiral(c2ws, focal, zrate=.1, rots=3, N=300):
86 | c2w = poses_avg(c2ws)
87 | up = normalize(c2ws[:, :3, 1].sum(0))
88 | tt = c2ws[:,:3,3]
89 | rads = np.percentile(np.abs(tt), 90, 0)
90 | rads[:] = rads.max() * .05
91 |
92 | render_poses = []
93 | rads = np.array(list(rads) + [1.])
94 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
95 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
96 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
97 | # c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
98 | # z = normalize(c2w[:3, 2])
99 | render_poses.append(viewmatrix(z, up, c))
100 | render_poses = np.stack(render_poses, axis=0)
101 | render_poses = np.concatenate([render_poses, np.zeros_like(render_poses[..., :1, :])], axis=1)
102 | render_poses[..., 3, 3] = 1
103 | render_poses = np.array(render_poses, dtype=np.float32)
104 | return render_poses
105 |
106 | def render_wander_path(view):
107 | focal_length = fov2focal(view.FoVy, view.image_height)
108 | R = view.R
109 | R[:, 1] = -R[:, 1]
110 | R[:, 2] = -R[:, 2]
111 | T = -view.T.reshape(-1, 1)
112 | pose = np.concatenate([R, T], -1)
113 |
114 | num_frames = 60
115 | max_disp = 5000.0 # 64 , 48
116 |
117 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter
118 | output_poses = []
119 |
120 | for i in range(num_frames):
121 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
122 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0
123 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0
124 |
125 | i_pose = np.concatenate([
126 | np.concatenate(
127 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
128 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
129 | ], axis=0) # [np.newaxis, :, :]
130 |
131 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float()
132 |
133 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
134 |
135 | render_pose = np.dot(ref_pose, i_pose)
136 | output_poses.append(torch.Tensor(render_pose))
137 |
138 | return output_poses
139 |
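For example, `pose_spherical` can generate a ring of camera-to-world poses around the scene (Blender-style convention); the elevation and radius below are illustrative:

```python
import numpy as np
from utils.pose_utils import pose_spherical

# 40 evenly spaced azimuth angles at -30 degrees elevation and radius 4.
poses = [pose_spherical(theta, -30.0, 4.0) for theta in np.linspace(-180, 180, 40, endpoint=False)]
```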
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | # @title Configure dataset directories
2 | import os
3 | from pathlib import Path
4 |
5 | # @markdown The base directory for all captures. This can be anything if you're running this notebook on your own Jupyter runtime.
6 | save_dir = '/data00/yzy/Git_Project/data/dynamic/mine/' # @param {type: 'string'}
7 | capture_name = 'lemon' # @param {type: 'string'}
8 | # The root directory for this capture.
9 | root_dir = Path(save_dir, capture_name)
10 | # Where to save RGB images.
11 | rgb_dir = root_dir / 'rgb'
12 | rgb_raw_dir = root_dir / 'rgb-raw'
13 | # Where to save the COLMAP outputs.
14 | colmap_dir = root_dir / 'colmap'
15 | colmap_db_path = colmap_dir / 'database.db'
16 | colmap_out_path = colmap_dir / 'sparse'
17 |
18 | colmap_out_path.mkdir(exist_ok=True, parents=True)
19 | rgb_raw_dir.mkdir(exist_ok=True, parents=True)
20 |
21 | print(f"""Directories configured:
22 | root_dir = {root_dir}
23 | rgb_raw_dir = {rgb_raw_dir}
24 | rgb_dir = {rgb_dir}
25 | colmap_dir = {colmap_dir}
26 | """)
27 |
28 | # ==================== colmap =========================
29 | # @title Extract features.
30 | # @markdown Computes SIFT features and saves them to the COLMAP DB.
31 | share_intrinsics = True # @param {type: 'boolean'}
32 | assume_upright_cameras = True # @param {type: 'boolean'}
33 |
34 | # @markdown This sets the scale at which we will run COLMAP. A scale of 1 will be more accurate but will be slow.
35 | colmap_image_scale = 4 # @param {type: 'number'}
36 | colmap_rgb_dir = rgb_dir / f'{colmap_image_scale}x'
37 |
38 | # @markdown Check this if you want to re-process SfM.
39 | overwrite = False # @param {type: 'boolean'}
40 |
41 | if overwrite and colmap_db_path.exists():
42 | colmap_db_path.unlink()
43 |
44 | os.system(f'colmap feature_extractor \
45 | --SiftExtraction.use_gpu 0 \
46 | --SiftExtraction.upright {int(assume_upright_cameras)} \
47 | --ImageReader.camera_model OPENCV \
48 | --ImageReader.single_camera {int(share_intrinsics)} \
49 | --database_path "{str(colmap_db_path)}" \
50 | --image_path "{str(colmap_rgb_dir)}"')
51 |
52 | # @title Match features.
53 | # @markdown Match the SIFT features between images. Use `exhaustive` if you only have a few images and use `vocab_tree` if you have a lot of images.
54 |
55 | match_method = 'exhaustive' # @param ["exhaustive", "vocab_tree"]
56 |
57 | if match_method == 'exhaustive':
58 |     os.system(f'colmap exhaustive_matcher \
59 | --SiftMatching.use_gpu 0 \
60 | --database_path "{str(colmap_db_path)}"')
61 |
62 | # @title Reconstruction.
63 | # @markdown Run structure-from-motion to compute camera parameters.
64 |
65 | refine_principal_point = True # @param {type:"boolean"}
66 | min_num_matches = 32 # @param {type: 'number'}
67 | filter_max_reproj_error = 2 # @param {type: 'number'}
68 | tri_complete_max_reproj_error = 2 # @param {type: 'number'}
69 |
70 | os.system(f'colmap mapper \
71 |     --Mapper.ba_refine_principal_point {int(refine_principal_point)} \
72 |     --Mapper.filter_max_reproj_error {filter_max_reproj_error} \
73 |     --Mapper.tri_complete_max_reproj_error {tri_complete_max_reproj_error} \
74 |     --Mapper.min_num_matches {min_num_matches} \
75 | --database_path "{str(colmap_db_path)}" \
76 | --image_path "{str(colmap_rgb_dir)}" \
77 | --export_path "{str(colmap_out_path)}"')
78 |
79 | print("debug")
80 |
--------------------------------------------------------------------------------
/utils/rigid_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def skew(w: torch.Tensor) -> torch.Tensor:
5 | """Build a skew matrix ("cross product matrix") for vector w.
6 |
7 | Modern Robotics Eqn 3.30.
8 |
9 | Args:
10 | w: (N, 3) A 3-vector
11 |
12 | Returns:
13 | W: (N, 3, 3) A skew matrix such that W @ v == w x v
14 | """
15 | zeros = torch.zeros(w.shape[0], device=w.device)
16 | w_skew_list = [zeros, -w[:, 2], w[:, 1],
17 | w[:, 2], zeros, -w[:, 0],
18 | -w[:, 1], w[:, 0], zeros]
19 | w_skew = torch.stack(w_skew_list, dim=-1).reshape(-1, 3, 3)
20 | return w_skew
21 |
22 |
23 | def rp_to_se3(R: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
24 | """Rotation and translation to homogeneous transform.
25 |
26 | Args:
27 |         R: (N, 3, 3) Orthonormal rotation matrices.
28 |         p: (N, 3, 1) Translation vectors.
29 |
30 |     Returns:
31 |         X: (N, 4, 4) The homogeneous transformation matrices described by rotating
32 |             by R and translating by p.
33 | """
34 | bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=R.device).repeat(R.shape[0], 1, 1)
35 | transform = torch.cat([torch.cat([R, p], dim=-1), bottom_row], dim=1)
36 |
37 | return transform
38 |
39 |
40 | def exp_so3(w: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
41 |     """Exponential map from Lie algebra so(3) to Lie group SO(3).
42 |
43 |     Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula.
44 |
45 |     Args:
46 |         w: (N, 3) Axes of rotation.
47 |         theta: (N, 1) Angles of rotation.
48 |
49 |     Returns:
50 |         R: (N, 3, 3) Orthonormal rotation matrices representing rotations of
51 |             magnitude theta about axis w.
52 | """
53 | W = skew(w)
54 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device)
55 | W_sqr = torch.bmm(W, W) # batch matrix multiplication
56 | R = identity + torch.sin(theta.unsqueeze(-1)) * W + (1.0 - torch.cos(theta.unsqueeze(-1))) * W_sqr
57 | return R
58 |
59 |
60 | def exp_se3(S: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
61 |     """Exponential map from Lie algebra se(3) to Lie group SE(3).
62 |
63 |     Modern Robotics Eqn 3.88.
64 |
65 |     Args:
66 |         S: (N, 6) Screw axes of motion.
67 |         theta: (N, 1) Magnitudes of motion.
68 |
69 |     Returns:
70 |         a_X_b: (N, 4, 4) The homogeneous transformation matrices attained by
71 |             integrating motion of magnitude theta about S for one second.
72 | """
73 | w, v = torch.split(S, 3, dim=-1)
74 | W = skew(w)
75 | R = exp_so3(w, theta)
76 |
77 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device)
78 | W_sqr = torch.bmm(W, W)
79 | theta = theta.view(-1, 1, 1)
80 |
81 | p = torch.bmm((theta * identity + (1.0 - torch.cos(theta)) * W + (theta - torch.sin(theta)) * W_sqr),
82 | v.unsqueeze(-1))
83 | return rp_to_se3(R, p)
84 |
85 |
86 | def to_homogenous(v: torch.Tensor) -> torch.Tensor:
87 | """Converts a vector to a homogeneous coordinate vector by appending a 1.
88 |
89 | Args:
90 | v: A tensor representing a vector or batch of vectors.
91 |
92 | Returns:
93 | A tensor with an additional dimension set to 1.
94 | """
95 | return torch.cat([v, torch.ones_like(v[..., :1])], dim=-1)
96 |
97 |
98 | def from_homogenous(v: torch.Tensor) -> torch.Tensor:
99 | """Converts a homogeneous coordinate vector to a standard vector by dividing by the last element.
100 |
101 | Args:
102 | v: A tensor representing a homogeneous coordinate vector or batch of homogeneous coordinate vectors.
103 |
104 | Returns:
105 | A tensor with the last dimension removed.
106 | """
107 | return v[..., :3] / v[..., -1:]
108 |
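A minimal sketch of applying a screw motion to a batch of points with the helpers above; the screw axis and magnitude are made-up values:

```python
import torch
from utils.rigid_utils import exp_se3, to_homogenous, from_homogenous

S = torch.tensor([[0., 0., 1., 0.1, 0., 0.]])   # (1, 6) screw axis: rotate about z, translate along x
theta = torch.tensor([[0.5]])                   # (1, 1) magnitude of motion

T = exp_se3(S, theta)                           # (1, 4, 4) homogeneous transform
pts = torch.tensor([[1., 0., 0.]])              # (1, 3) point to transform
moved = from_homogenous((T @ to_homogenous(pts).unsqueeze(-1)).squeeze(-1))  # (1, 3)
```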
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 |
115 | def RGB2SH(rgb):
116 | return (rgb - 0.5) / C0
117 |
118 |
119 | def SH2RGB(sh):
120 | return sh * C0 + 0.5
121 |
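At degree 0 the evaluation reduces to the DC term, so `RGB2SH` and `eval_sh` round-trip a base color up to the +0.5 offset; a minimal sketch:

```python
import torch
from utils.sh_utils import eval_sh, RGB2SH

rgb = torch.tensor([[0.2, 0.5, 0.8]])   # (1, 3) base color
sh = RGB2SH(rgb).unsqueeze(-1)          # (1, 3, 1): one DC coefficient per channel
dirs = torch.tensor([[0.0, 0.0, 1.0]])  # viewing direction (ignored at degree 0)
color = eval_sh(0, sh, dirs) + 0.5      # recovers the original rgb
```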
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 |
17 | def mkdir_p(folder_path):
18 | # Creates a directory. equivalent to using mkdir -p on the command line
19 | try:
20 | makedirs(folder_path)
21 | except OSError as exc: # Python >2.5
22 | if exc.errno == EEXIST and path.isdir(folder_path):
23 | pass
24 | else:
25 | raise
26 |
27 |
28 | def searchForMaxIteration(folder):
29 | if not os.path.exists(folder):
30 | return None
31 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder) if "_" in fname]
32 | return max(saved_iters) if saved_iters != [] else None
33 |
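For instance, `searchForMaxIteration` resolves the latest checkpoint folder when loading a trained model; the path below is only an example layout:

```python
from utils.system_utils import searchForMaxIteration

# e.g. outputs/person/point_cloud/{iteration_7000, iteration_30000} -> 30000
last_iter = searchForMaxIteration("outputs/person/point_cloud")
```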
--------------------------------------------------------------------------------
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | from gaussian_renderer import render
2 |
3 |
4 | def render_cur_cam(self, cur_cam):
5 | fid = cur_cam.fid
6 | if self.deform.name == 'node':
7 | if 'Node' in self.visualization_mode:
8 | gaussians = self.deform.deform.as_gaussians # if self.iteration_node_rendering < self.opt.iterations_node_rendering else self.deform.deform.as_gaussians_visualization
9 | time_input = fid.unsqueeze(0).expand(gaussians.get_xyz.shape[0], -1)
10 | d_values = self.deform.deform.query_network(x=gaussians.get_xyz.detach(), t=time_input)
11 | if self.motion_animation_d_values is not None:
12 | for key in self.motion_animation_d_values:
13 | d_values[key] = self.motion_animation_d_values[key]
14 | d_xyz, d_opacity, d_color = d_values['d_xyz'] * gaussians.motion_mask, d_values['d_opacity'] * gaussians.motion_mask if d_values['d_opacity'] is not None else None, d_values['d_color'] * gaussians.motion_mask if d_values['d_color'] is not None else None
15 | d_rotation, d_scaling = 0., 0.
16 | if self.animation_trans_bias is not None:
17 | d_xyz = d_xyz + self.animation_trans_bias
18 | gs_rot_bias = None
19 | vis_scale_const = self.vis_scale_const
20 | else:
21 | time_input = self.deform.deform.expand_time(fid)
22 | d_values = self.deform.step(self.gaussians.get_xyz.detach(), time_input, feature=self.gaussians.feature, is_training=False, node_trans_bias=self.animation_trans_bias, node_rot_bias=self.animation_rot_bias, motion_mask=self.gaussians.motion_mask, camera_center=cur_cam.camera_center, animation_d_values=self.motion_animation_d_values)
23 | gaussians = self.gaussians
24 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
25 | gs_rot_bias = d_values['gs_rot_bias'] # GS rotation bias
26 | vis_scale_const = None
27 | else:
28 | vis_scale_const = None
29 | if self.iteration < self.opt.warm_up:
30 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = 0.0, 0.0, 0.0, 0.0, 0.0
31 | gaussians = self.gaussians
32 | else:
33 | N = self.gaussians.get_xyz.shape[0]
34 | time_input = fid.unsqueeze(0).expand(N, -1)
35 | gaussians = self.gaussians
36 | d_values = self.deform.step(self.gaussians.get_xyz.detach(), time_input, feature=self.gaussians.feature, camera_center=cur_cam.camera_center)
37 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color']
38 | gs_rot_bias = None
39 |
40 | render_motion = "Motion" in self.visualization_mode
41 | if render_motion:
42 | vis_scale_const = self.vis_scale_const
43 | if type(d_rotation) is not float and gaussians._rotation.shape[0] != d_rotation.shape[0]:
44 | d_xyz, d_rotation, d_scaling = 0, 0, 0
45 | print('Async in Gaussian Switching')
46 | out = render(viewpoint_camera=cur_cam, pc=gaussians, pipe=self.pipe, bg_color=self.background, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, render_motion=render_motion, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=self.deform.d_rot_as_res, gs_rot_bias=gs_rot_bias, scale_const=vis_scale_const)
47 |
48 | buffer_image = out[self.mode] # [3, H, W]
49 | return buffer_image
50 |
--------------------------------------------------------------------------------