├── LICENSE
├── README.md
├── confs
│   ├── dtu_sift_porf.conf
│   ├── dtu_sift_pose.conf
│   ├── mobilebrick_sift_porf.conf
│   └── mobilebrick_sift_pose.conf
├── export_camera_file.py
├── models
│   ├── dataset.py
│   ├── embedder.py
│   ├── fields.py
│   ├── networks.py
│   └── renderer.py
├── preprocess_data
│   └── export_colmap_matches.py
├── requirements.txt
├── scripts
│   ├── train_sift_dtu.sh
│   └── train_sift_mobilebrick.sh
├── train.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Wenjing Bian
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PoRF: Pose Residual Field for Accurate Neural Surface Reconstruction
2 | We present PoRF (pose residual field) for the joint optimisation of neural surface reconstruction and camera poses. It uses a single MLP to refine the camera poses of all images in the dataset, instead of optimising the pose parameters of each image independently. The following figure shows that, taking COLMAP poses as input, our refined poses yield 3D surface reconstructions comparable to those obtained with the GT poses; Chamfer distances (mm) are reported.
3 | 
4 | 
5 | ![alt tag](https://porf.active.vision/image/dtu_vis.png)
6 | 
7 | 
8 | 
9 | ## [Project page](https://porf.active.vision/) | [Paper](https://arxiv.org/abs/2310.07449) | [Data](https://1drv.ms/u/s!AiV6XqkxJHE2pme7CIkceyLGsng2?e=6qsnlt)
10 | This is the official repo for the implementation of **PoRF: Pose Residual Field for Accurate Neural Surface Reconstruction**.
11 | 
12 | ## Usage
13 | 
14 | #### Data Convention
15 | The data is organized as follows:
16 | 
17 | ```
18 | <case_name>
19 | |-- cameras.npz # GT camera parameters
20 | |-- cameras_colmap.npz # COLMAP camera parameters
21 | |-- image
22 |     |-- 000.png # target image for each view
23 |     |-- 001.png
24 |     ...
25 | |-- colmap_matches
26 |     |-- 000000.npz # matches exported from COLMAP
27 |     |-- 000001.npz
28 |     ...
29 | ```
30 | 
31 | Here the `cameras.npz` follows the data format in [IDR](https://github.com/lioryariv/idr/blob/main/DATA_CONVENTION.md), where `world_mat_xx` denotes the world-to-image projection matrix, and `scale_mat_xx` denotes the normalization matrix.
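The snippet below sketches how these matrices are typically consumed; it mirrors `load_K_Rt_from_P` in `models/dataset.py`, where `world_mat_xx @ scale_mat_xx` is decomposed with OpenCV into an intrinsics matrix and a camera-to-world pose. This is an illustrative sketch only, not part of the training pipeline; the file and key names follow the convention above.

```python
import numpy as np
import cv2

def decompose_world_mat(world_mat, scale_mat):
    # world-to-image projection in the normalized (unit-sphere) coordinate frame
    P = (world_mat @ scale_mat)[:3, :4]
    K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K / K[2, 2]
    pose = np.eye(4, dtype=np.float32)        # camera-to-world (c2w)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]
    return intrinsics, pose

cams = np.load('cameras_colmap.npz')          # or cameras.npz for the GT parameters
intrinsics, c2w = decompose_world_mat(cams['world_mat_0'], cams['scale_mat_0'])
```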
32 | 
33 | ### Setup
34 | 
35 | Clone this repository
36 | 
37 | ```shell
38 | git clone https://github.com/ActiveVisionLab/porf.git
39 | cd porf
40 | 
41 | conda create -n porf python=3.9
42 | conda activate porf
43 | conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia
44 | pip install -r requirements.txt
45 | ```
46 | 
47 | 
48 | ### Running
49 | 
50 | - **Example (you need to change the data path in the script)**
51 | 
52 | ```shell
53 | bash scripts/train_sift_dtu.sh
54 | ```
55 | 
56 | - **Training**
57 | 
58 | ```shell
59 | python train.py --mode train --conf confs/dtu_sift_porf.conf --case <case_name>
60 | ```
61 | After training, a mesh should be found in `exp/<case_name>/<exp_name>/meshes/<iter>.ply`. Note that this mesh is only used for debugging the first-stage pose optimisation. If you need a high-quality mesh as shown in the paper, you should export the refined camera pose and use it to train a [Voxurf](https://github.com/wutong16/Voxurf) model.
62 | 
63 | 
64 | - **Export Refined Camera Pose (change the folder paths in the script)**
65 | ```shell
66 | python export_camera_file.py
67 | ```
68 | 
69 | 
70 | ## Citation
71 | 
72 | Please cite the paper below if you find this repository helpful for your project:
73 | 
74 | ```
75 | @inproceedings{bian2024porf,
76 |   title={PoRF: Pose Residual Field for Accurate Neural Surface Reconstruction},
77 |   author={Jia-Wang Bian and Wenjing Bian and Victor Adrian Prisacariu and Philip Torr},
78 |   booktitle={ICLR},
79 |   year={2024}
80 | }
81 | ```
82 | 
83 | ## Acknowledgement
84 | 
85 | Some code snippets are borrowed from [NeuS](https://github.com/Totoro97/NeuS). Thanks for this great project.
86 | 
--------------------------------------------------------------------------------
/confs/dtu_sift_porf.conf:
--------------------------------------------------------------------------------
1 | general {
2 |     base_exp_dir = ./exp_dtu/CASE_NAME/dtu_sift_porf
3 |     recording = [
4 |         ./,
5 |         ./models
6 |     ]
7 | }
8 | 
9 | dataset {
10 |     data_dir = ./porf_data/dtu/CASE_NAME/
11 |     render_cameras_name = cameras_colmap.npz
12 |     object_cameras_name = cameras_colmap.npz
13 |     train_resolution_level = 1
14 |     match_folder = colmap_matches
15 |     mask_folder = mask
16 | }
17 | 
18 | train {
19 |     learning_rate = 5e-4
20 |     learning_rate_alpha = 0.05
21 | 
22 |     pose_learning_rate = 5e-4
23 | 
24 |     pose_end_iter = 50000
25 |     pose_val_freq = 50
26 | 
27 |     use_porf = True
28 |     scale = 1e-2
29 |     inlier_threshold = 20
30 |     num_pairs = 20
31 | 
32 |     batch_size = 512
33 |     validate_resolution_level = 4
34 |     warm_up_end = 500
35 |     anneal_end = 5000
36 |     use_white_bkgd = False
37 | 
38 |     save_freq = 50000
39 |     val_freq = 5000
40 |     val_mesh_freq = 5000
41 |     report_freq = 1000
42 | 
43 |     # loss weights
44 |     igr_weight = 0.1
45 |     color_loss_weight = 1.0
46 |     epipolar_loss_weight = 0.1
47 | }
48 | 
49 | model {
50 |     sdf_network {
51 |         d_out = 257
52 |         d_in = 3
53 |         d_hidden = 256
54 |         n_layers = 8
55 |         skip_in = [4]
56 |         multires = 6
57 |         bias = 0.5
58 |         scale = 1.0
59 |         geometric_init = True
60 |         weight_norm = True
61 |     }
62 | 
63 |     variance_network {
64 |         init_val = 0.3
65 |     }
66 | 
67 |     render_network {
68 |         d_feature = 256
69 |         mode = idr
70 |         d_in = 9
71 |         d_out = 3
72 |         d_hidden = 256
73 |         n_layers = 4
74 |         weight_norm = True
75 |         multires_view = 4
76 |         squeeze_out = True
77 |     }
78 | 
79 |     neus_renderer {
80 |         n_samples = 64
81 |         n_importance = 64
82 |         n_outside = 32
83 |         up_sample_steps = 4     # 1 for simple coarse-to-fine sampling
84 |         perturb = 1.0
85 |     }
86 | }
87 | 
--------------------------------------------------------------------------------
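The conf above (and the three that follow) uses a `CASE_NAME` placeholder that the training script fills in from the `--case` argument. `train.py` is not included in this excerpt, so the sketch below only illustrates the usual NeuS-style loading suggested by the `conf.get_string(...)` calls in `models/dataset.py` (pyhocon); the case name `scan24` is a made-up placeholder.

```python
# Assumed sketch (train.py is not shown here): substitute CASE_NAME, then parse with pyhocon.
from pyhocon import ConfigFactory

case = 'scan24'                                   # placeholder for the --case value
with open('confs/dtu_sift_porf.conf') as f:
    conf = ConfigFactory.parse_string(f.read().replace('CASE_NAME', case))

print(conf.get_string('general.base_exp_dir'))    # ./exp_dtu/scan24/dtu_sift_porf
print(conf.get_bool('train.use_porf'), conf.get_int('train.pose_end_iter'))
dataset_conf = conf.get_config('dataset')         # the sub-config consumed by models/dataset.py
```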
/confs/dtu_sift_pose.conf: -------------------------------------------------------------------------------- 1 | # full means mlp, pe, and gradient scaling 2 | general { 3 | base_exp_dir = ./exp_dtu/CASE_NAME/dtu_sift_pose 4 | recording = [ 5 | ./, 6 | ./models 7 | ] 8 | } 9 | 10 | dataset { 11 | data_dir = ./porf_data/dtu/CASE_NAME/ 12 | render_cameras_name = cameras_colmap.npz 13 | object_cameras_name = cameras_colmap.npz 14 | train_resolution_level = 1 15 | match_folder = colmap_matches 16 | mask_folder = mask 17 | } 18 | 19 | train { 20 | learning_rate = 5e-4 21 | learning_rate_alpha = 0.05 22 | 23 | pose_learning_rate = 5e-4 24 | 25 | pose_end_iter = 50000 26 | pose_val_freq = 50 27 | 28 | use_porf = False 29 | inlier_threshold = 20 30 | num_pairs = 20 31 | 32 | batch_size = 512 33 | validate_resolution_level = 4 34 | warm_up_end = 500 35 | anneal_end = 5000 36 | use_white_bkgd = False 37 | 38 | save_freq = 50000 39 | val_freq = 5000 40 | val_mesh_freq = 5000 41 | report_freq = 1000 42 | 43 | # loss weights 44 | igr_weight = 0.1 45 | color_loss_weight = 1.0 46 | epipolar_loss_weight = 0.1 47 | } 48 | 49 | model { 50 | sdf_network { 51 | d_out = 257 52 | d_in = 3 53 | d_hidden = 256 54 | n_layers = 8 55 | skip_in = [4] 56 | multires = 6 57 | bias = 0.5 58 | scale = 1.0 59 | geometric_init = True 60 | weight_norm = True 61 | } 62 | 63 | 64 | variance_network { 65 | init_val = 0.3 66 | } 67 | 68 | 69 | render_network { 70 | d_feature = 256 71 | mode = idr 72 | d_in = 9 73 | d_out = 3 74 | d_hidden = 256 75 | n_layers = 4 76 | weight_norm = True 77 | multires_view = 4 78 | squeeze_out = True 79 | } 80 | 81 | neus_renderer { 82 | n_samples = 64 83 | n_importance = 64 84 | n_outside = 32 85 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 86 | perturb = 1.0 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /confs/mobilebrick_sift_porf.conf: -------------------------------------------------------------------------------- 1 | # full means mlp, pe, and gradient scaling 2 | general { 3 | base_exp_dir = ./exp_mobilebrick/CASE_NAME/mobilebrick_sift_porf 4 | recording = [ 5 | ./, 6 | ./models 7 | ] 8 | } 9 | 10 | dataset { 11 | data_dir = ./porf_data/mobilebrick/CASE_NAME/ 12 | render_cameras_name = cameras_arkit.npz 13 | object_cameras_name = cameras_arkit.npz 14 | train_resolution_level = 1 15 | match_folder = colmap_matches 16 | mask_folder = mask 17 | } 18 | 19 | train { 20 | learning_rate = 5e-4 21 | learning_rate_alpha = 0.05 22 | 23 | pose_learning_rate = 5e-4 24 | 25 | pose_end_iter = 50000 26 | pose_val_freq = 50 27 | 28 | use_porf = True 29 | scale = 1e-2 30 | inlier_threshold = 20 31 | num_pairs = 20 32 | 33 | batch_size = 512 34 | validate_resolution_level = 4 35 | warm_up_end = 500 36 | anneal_end = 5000 37 | use_white_bkgd = False 38 | 39 | save_freq = 50000 40 | val_freq = 5000 41 | val_mesh_freq = 5000 42 | report_freq = 1000 43 | 44 | # loss weights 45 | igr_weight = 0.1 46 | color_loss_weight = 1.0 47 | epipolar_loss_weight = 0.1 48 | } 49 | 50 | model { 51 | sdf_network { 52 | d_out = 257 53 | d_in = 3 54 | d_hidden = 256 55 | n_layers = 8 56 | skip_in = [4] 57 | multires = 6 58 | bias = 0.5 59 | scale = 1.0 60 | geometric_init = True 61 | weight_norm = True 62 | } 63 | 64 | 65 | variance_network { 66 | init_val = 0.3 67 | } 68 | 69 | 70 | render_network { 71 | d_feature = 256 72 | mode = idr 73 | d_in = 9 74 | d_out = 3 75 | d_hidden = 256 76 | n_layers = 4 77 | weight_norm = True 78 | multires_view = 4 79 | 
squeeze_out = True 80 | } 81 | 82 | neus_renderer { 83 | n_samples = 64 84 | n_importance = 64 85 | n_outside = 32 86 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 87 | perturb = 1.0 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /confs/mobilebrick_sift_pose.conf: -------------------------------------------------------------------------------- 1 | # full means mlp, pe, and gradient scaling 2 | general { 3 | base_exp_dir = ./exp_mobilebrick/CASE_NAME/mobilebrick_sift_pose 4 | recording = [ 5 | ./, 6 | ./models 7 | ] 8 | } 9 | 10 | dataset { 11 | data_dir = ./porf_data/mobilebrick/CASE_NAME/ 12 | render_cameras_name = cameras_arkit.npz 13 | object_cameras_name = cameras_arkit.npz 14 | train_resolution_level = 1 15 | match_folder = colmap_matches 16 | mask_folder = mask 17 | } 18 | 19 | train { 20 | learning_rate = 5e-4 21 | learning_rate_alpha = 0.05 22 | 23 | pose_learning_rate = 5e-4 24 | 25 | pose_end_iter = 50000 26 | pose_val_freq = 50 27 | 28 | use_porf = False 29 | inlier_threshold = 20 30 | num_pairs = 20 31 | 32 | batch_size = 512 33 | validate_resolution_level = 4 34 | warm_up_end = 500 35 | anneal_end = 5000 36 | use_white_bkgd = False 37 | 38 | save_freq = 50000 39 | val_freq = 5000 40 | val_mesh_freq = 5000 41 | report_freq = 1000 42 | 43 | # loss weights 44 | igr_weight = 0.1 45 | color_loss_weight = 1.0 46 | epipolar_loss_weight = 0.1 47 | } 48 | 49 | model { 50 | sdf_network { 51 | d_out = 257 52 | d_in = 3 53 | d_hidden = 256 54 | n_layers = 8 55 | skip_in = [4] 56 | multires = 6 57 | bias = 0.5 58 | scale = 1.0 59 | geometric_init = True 60 | weight_norm = True 61 | } 62 | 63 | 64 | variance_network { 65 | init_val = 0.3 66 | } 67 | 68 | 69 | render_network { 70 | d_feature = 256 71 | mode = idr 72 | d_in = 9 73 | d_out = 3 74 | d_hidden = 256 75 | n_layers = 4 76 | weight_norm = True 77 | multires_view = 4 78 | squeeze_out = True 79 | } 80 | 81 | neus_renderer { 82 | n_samples = 64 83 | n_importance = 64 84 | n_outside = 32 85 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 86 | perturb = 1.0 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /export_camera_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from path import Path 3 | import os 4 | import cv2 5 | from scipy.spatial.transform import Rotation 6 | 7 | 8 | def load_K_Rt_from_P(filename, P=None): 9 | if P is None: 10 | lines = open(filename).read().splitlines() 11 | if len(lines) == 4: 12 | lines = lines[1:] 13 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 14 | P = np.asarray(lines).astype(np.float32).squeeze() 15 | 16 | out = cv2.decomposeProjectionMatrix(P) 17 | K = out[0] 18 | R = out[1] 19 | t = out[2] 20 | 21 | K = K / K[2, 2] 22 | intrinsics = np.eye(4) 23 | intrinsics[:3, :3] = K 24 | 25 | pose = np.eye(4, dtype=np.float32) 26 | pose[:3, :3] = R.transpose() 27 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 28 | 29 | return intrinsics, pose 30 | 31 | 32 | def umeyama_alignment(x, y, with_scale=True): 33 | """ 34 | Computes the least squares solution parameters of an Sim(m) matrix 35 | that minimizes the distance between a set of registered points. 36 | Umeyama, Shinji: Least-squares estimation of transformation parameters 37 | between two point patterns. IEEE PAMI, 1991 38 | :param x: mxn matrix of points, m = dimension, n = nr. 
of data points 39 | :param y: mxn matrix of points, m = dimension, n = nr. of data points 40 | :param with_scale: set to True to align also the scale (default: 1.0 scale) 41 | :return: r, t, c - rotation matrix, translation vector and scale factor 42 | """ 43 | if x.shape != y.shape: 44 | assert False, "x.shape not equal to y.shape" 45 | 46 | # m = dimension, n = nr. of data points 47 | m, n = x.shape 48 | 49 | # means, eq. 34 and 35 50 | mean_x = x.mean(axis=1) 51 | mean_y = y.mean(axis=1) 52 | 53 | # variance, eq. 36 54 | # "transpose" for column subtraction 55 | sigma_x = 1.0 / n * (np.linalg.norm(x - mean_x[:, np.newaxis])**2) 56 | 57 | # covariance matrix, eq. 38 58 | outer_sum = np.zeros((m, m)) 59 | for i in range(n): 60 | outer_sum += np.outer((y[:, i] - mean_y), (x[:, i] - mean_x)) 61 | cov_xy = np.multiply(1.0 / n, outer_sum) 62 | 63 | # SVD (text betw. eq. 38 and 39) 64 | u, d, v = np.linalg.svd(cov_xy) 65 | 66 | # S matrix, eq. 43 67 | s = np.eye(m) 68 | if np.linalg.det(u) * np.linalg.det(v) < 0.0: 69 | # Ensure a RHS coordinate system (Kabsch algorithm). 70 | s[m - 1, m - 1] = -1 71 | 72 | # rotation, eq. 40 73 | r = u.dot(s).dot(v) 74 | 75 | # scale & translation, eq. 42 and 41 76 | c = 1 / sigma_x * np.trace(np.diag(d).dot(s)) if with_scale else 1.0 77 | t = mean_y - np.multiply(c, r.dot(mean_x)) 78 | 79 | return r, t, c 80 | 81 | 82 | def pose_alignment(poses_pred, poses_gt): 83 | 84 | num_gt = poses_gt.shape[0] 85 | 86 | xyz_result = poses_pred[:num_gt, :3, 3].T 87 | xyz_gt = poses_gt[:, :3, 3].T 88 | 89 | r, t, scale = umeyama_alignment(xyz_result, xyz_gt, with_scale=True) 90 | 91 | align_transformation = np.eye(4) 92 | align_transformation[:3:, :3] = r 93 | align_transformation[:3, 3] = t 94 | 95 | for cnt in range(poses_pred.shape[0]): 96 | poses_pred[cnt][:3, 3] *= scale 97 | poses_pred[cnt] = align_transformation @ poses_pred[cnt] 98 | 99 | return poses_pred 100 | 101 | 102 | def rotation_error(pose_error): 103 | """Compute rotation error 104 | Args: 105 | pose_error (4x4 array): relative pose error 106 | Returns: 107 | rot_error (float): rotation error 108 | """ 109 | r_diff = Rotation.from_matrix(pose_error[:3, :3]) 110 | pose_error = r_diff.as_matrix() 111 | a = pose_error[0, 0] 112 | b = pose_error[1, 1] 113 | c = pose_error[2, 2] 114 | d = 0.5*(a+b+c-1.0) 115 | rot_error = np.arccos(max(min(d, 1.0), -1.0)) 116 | return rot_error 117 | 118 | 119 | def translation_error(pose_error): 120 | """Compute translation error 121 | Args: 122 | pose_error (4x4 array): relative pose error 123 | Returns: 124 | trans_error (float): translation error 125 | """ 126 | dx = pose_error[0, 3] 127 | dy = pose_error[1, 3] 128 | dz = pose_error[2, 3] 129 | trans_error = np.sqrt(dx**2+dy**2+dz**2) 130 | return trans_error 131 | 132 | 133 | def compute_RPE(gt, pred): 134 | trans_errors = [] 135 | rot_errors = [] 136 | for i in range(len(gt)-1): 137 | gt1 = gt[i] 138 | gt2 = gt[i+1] 139 | gt_rel = np.linalg.inv(gt1) @ gt2 140 | 141 | pred1 = pred[i] 142 | pred2 = pred[i+1] 143 | pred_rel = np.linalg.inv(pred1) @ pred2 144 | rel_err = np.linalg.inv(gt_rel) @ pred_rel 145 | 146 | trans_errors.append(translation_error(rel_err)) 147 | rot_errors.append(rotation_error(rel_err)) 148 | 149 | return np.array(rot_errors), np.array(trans_errors) 150 | 151 | 152 | def compute_ATE(gt, pred): 153 | """Compute RMSE of ATE 154 | Args: 155 | gt: ground-truth poses 156 | pred: predicted poses 157 | """ 158 | r_errs = [] 159 | t_errs = [] 160 | 161 | for i in range(len(pred)): 162 | # cur_gt = 
np.linalg.inv(gt_0) @ gt[i] 163 | cur_gt = gt[i] 164 | gt_xyz = cur_gt[:3, 3] 165 | 166 | # cur_pred = np.linalg.inv(pred_0) @ pred[i] 167 | cur_pred = pred[i] 168 | pred_xyz = cur_pred[:3, 3] 169 | 170 | align_err = gt_xyz - pred_xyz 171 | 172 | t_errs.append(np.sqrt(np.sum(align_err ** 2))) 173 | 174 | r_diff = np.linalg.inv(cur_gt[:3, :3]) @ cur_pred[:3, :3] 175 | r_errs.append(rotation_error(r_diff)) 176 | 177 | # ate = np.sqrt(np.mean(np.asarray(errors) ** 2)) 178 | return np.array(r_errs), np.array(t_errs) 179 | 180 | 181 | def generate_camera(scale_mats_np, intrinsics, poses, out_file): 182 | # write poses 183 | cameras = {} 184 | for idx in range(len(poses)): 185 | 186 | cameras["scale_mat_%d" % (idx)] = scale_mats_np[idx] 187 | 188 | K = intrinsics[idx] 189 | P = K @ np.linalg.inv(poses[idx]) 190 | cameras["world_mat_%d" % (idx)] = P 191 | 192 | np.savez(out_file, **cameras) 193 | 194 | 195 | def load_camera(cam_file, n_imgs): 196 | camera_dict = np.load(cam_file, allow_pickle=True) 197 | world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_imgs)] 198 | scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_imgs)] 199 | 200 | intrinsics = [] 201 | poses = [] 202 | for P in world_mats_np: 203 | intrinsic, pose = load_K_Rt_from_P(None, P[:3]) 204 | poses.append(pose) 205 | intrinsics.append(intrinsic) 206 | poses = np.stack(poses) 207 | intrinsics = np.stack(intrinsics) 208 | 209 | return scale_mats_np, world_mats_np, intrinsics, poses 210 | 211 | 212 | if __name__ == '__main__': 213 | 214 | root = 'exp_dtu' 215 | iters = 'poses_050000' 216 | 217 | method = 'dtu_sift_porf' 218 | out_name = 'cameras_refine_porf.npz' 219 | 220 | root_dir = Path('./porf_data/dtu/') 221 | scenes = [os.path.basename(s) for s in sorted(root_dir.dirs())] 222 | 223 | for s in scenes: 224 | scene_dir = root_dir/s 225 | 226 | pose_file = f'./{root}/{s}/{method}/{iters}/refined_pose.txt' 227 | if not os.path.exists(pose_file): 228 | continue 229 | 230 | poses_refine = np.loadtxt(pose_file).reshape(-1, 4, 4) 231 | 232 | # gt pose 233 | n_imgs = len((scene_dir/'image').files('*.png')) 234 | scale_mats_np, _, intrinsics, gt_poses = load_camera(scene_dir/'cameras.npz', n_imgs) 235 | 236 | # align pose to gt 237 | poses_refine = pose_alignment(poses_refine, gt_poses) 238 | 239 | r_err, t_err = compute_ATE(gt_poses, poses_refine) 240 | print('ate errs: ', np.mean(r_err) / 3.14 * 180, np.mean(t_err)) 241 | 242 | r_err, t_err = compute_RPE(gt_poses, poses_refine) 243 | print('rpe errs: ', np.mean(r_err) / 3.14 * 180, np.mean(t_err)) 244 | 245 | generate_camera(scale_mats_np, intrinsics, poses_refine, scene_dir/out_name) 246 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 as cv 3 | import numpy as np 4 | import os 5 | from glob import glob 6 | from scipy.spatial.transform import Rotation as Rot 7 | from scipy.spatial.transform import Slerp 8 | from path import Path 9 | 10 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 11 | 12 | 13 | def load_K_Rt_from_P(filename, P=None): 14 | if P is None: 15 | lines = open(filename).read().splitlines() 16 | if len(lines) == 4: 17 | lines = lines[1:] 18 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 19 | P = np.asarray(lines).astype(np.float32).squeeze() 20 | 21 | out = 
cv.decomposeProjectionMatrix(P) 22 | K = out[0] 23 | R = out[1] 24 | t = out[2] 25 | 26 | K = K / K[2, 2] 27 | intrinsics = np.eye(4) 28 | intrinsics[:3, :3] = K 29 | 30 | pose = np.eye(4, dtype=np.float32) 31 | pose[:3, :3] = R.transpose() 32 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 33 | 34 | return intrinsics, pose 35 | 36 | 37 | def compute_P_from_KT(K, T): 38 | 39 | P = torch.matmul(K, torch.linalg.inv(T)) 40 | 41 | return P 42 | 43 | 44 | class Dataset: 45 | def __init__(self, conf): 46 | super(Dataset, self).__init__() 47 | self.device = torch.device('cuda') 48 | 49 | self.conf = conf 50 | self.data_dir = conf.get_string('data_dir') 51 | self.render_cameras_name = conf.get_string('render_cameras_name') 52 | self.object_cameras_name = conf.get_string('object_cameras_name') 53 | self.match_folder = conf.get_string('match_folder') 54 | 55 | print(f'Load data: Begin from {self.data_dir}') 56 | 57 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png'))) 58 | if len(self.images_lis) < 1: 59 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.jpg'))) 60 | if len(self.images_lis) < 1: 61 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'rgb/*.png'))) 62 | 63 | self.n_images = len(self.images_lis) 64 | 65 | camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name)) 66 | # world_mat is a projection matrix from world to image 67 | self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 68 | 69 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 70 | self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 71 | 72 | self.intrinsics_all = [] 73 | self.intrinsics_all_inv = [] 74 | self.pose_all = [] 75 | for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np): 76 | P = world_mat @ scale_mat 77 | P = P[:3, :4] 78 | 79 | intrinsics, pose = load_K_Rt_from_P(None, P) 80 | intrinsics = torch.from_numpy(intrinsics).float() 81 | self.intrinsics_all.append(intrinsics) 82 | self.intrinsics_all_inv.append(torch.linalg.inv(intrinsics)) 83 | self.pose_all.append(torch.from_numpy(pose).float()) 84 | 85 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] 86 | self.intrinsics_all_inv = torch.stack(self.intrinsics_all_inv).to(self.device) 87 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4] 88 | 89 | # Object scale mat: region of interest to **extract mesh** 90 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0]) 91 | object_bbox_max = np.array([1.01, 1.01, 1.01, 1.0]) 92 | self.object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0'] 93 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ self.object_scale_mat @ object_bbox_min[:, None] 94 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ self.object_scale_mat @ object_bbox_max[:, None] 95 | self.object_bbox_min = object_bbox_min[:3, 0] 96 | self.object_bbox_max = object_bbox_max[:3, 0] 97 | 98 | # load images 99 | images_np = [] 100 | for im_name in self.images_lis: 101 | img = cv.imread(im_name) 102 | images_np.append(img) 103 | images_np = np.stack(images_np) / 256.0 104 | self.images = torch.from_numpy(images_np).float().permute(0, 3, 1, 2).to(self.device) # [n_images, 3, H, W] 105 | del images_np 106 | 107 | self.H, self.W = self.images.shape[2], self.images.shape[3] 108 | 109 | # load gt 
pose for validation 110 | gt_camera_dict = np.load(os.path.join(self.data_dir, 'cameras.npz')) 111 | gt_world_mats_np = [gt_camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 112 | self.gt_pose_all = [] 113 | for world_mat in gt_world_mats_np: 114 | P = world_mat @ scale_mat 115 | P = P[:3, :4] 116 | intrinsics, pose = load_K_Rt_from_P(None, P) 117 | self.gt_pose_all.append(torch.from_numpy(pose).float()) 118 | self.gt_pose_all = np.stack(self.gt_pose_all) # [n_images, 4, 4] 119 | 120 | # two_view files 121 | two_view_files = sorted((Path(self.data_dir)/self.match_folder).files('*.npz')) 122 | self.two_views_all = [] 123 | for f in two_view_files: 124 | self.two_views_all.append(np.load(f, allow_pickle=True)) 125 | 126 | print('Load data: End') 127 | 128 | @torch.no_grad() 129 | def gen_rays_at(self, img_idx, pose_net, resolution_level=1): 130 | """ 131 | Generate rays at world space from one camera. 132 | """ 133 | l = resolution_level 134 | tx = torch.linspace(0, self.W - 1, self.W // l) 135 | ty = torch.linspace(0, self.H - 1, self.H // l) 136 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 137 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 138 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 139 | 140 | pose = pose_net(img_idx) 141 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 142 | rays_v = torch.matmul(pose[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 143 | rays_o = pose[None, None, :3, 3].expand(rays_v.shape) # W, H, 3 144 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 145 | 146 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 147 | """ 148 | Interpolate pose between two cameras. 
149 | """ 150 | l = resolution_level 151 | tx = torch.linspace(0, self.W - 1, self.W // l) 152 | ty = torch.linspace(0, self.H - 1, self.H // l) 153 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 154 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 155 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 156 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 157 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 158 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 159 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 160 | pose_0 = np.linalg.inv(pose_0) 161 | pose_1 = np.linalg.inv(pose_1) 162 | rot_0 = pose_0[:3, :3] 163 | rot_1 = pose_1[:3, :3] 164 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 165 | key_times = [0, 1] 166 | slerp = Slerp(key_times, rots) 167 | rot = slerp(ratio) 168 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 169 | pose = pose.astype(np.float32) 170 | pose[:3, :3] = rot.as_matrix() 171 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 172 | pose = np.linalg.inv(pose) 173 | rot = torch.from_numpy(pose[:3, :3]).to(self.device) 174 | trans = torch.from_numpy(pose[:3, 3]).to(self.device) 175 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 176 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 177 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 178 | 179 | def near_far_from_sphere(self, rays_o, rays_d): 180 | a = torch.sum(rays_d**2, dim=-1, keepdim=True) 181 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 182 | mid = 0.5 * (-b) / a 183 | near = mid - 1.0 184 | far = mid + 1.0 185 | return near, far 186 | 187 | def image_at(self, idx, resolution_level): 188 | img = cv.imread(self.images_lis[idx]) 189 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 190 | 191 | def gen_random_rays_at(self, img_idx, batch_size, pose): 192 | """ 193 | Generate random rays at world space from one camera. 
194 | """ 195 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]).to(self.device) 196 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]).to(self.device) 197 | 198 | color = self.images[img_idx].permute(1, 2, 0)[(pixels_y, pixels_x)] # batch_size, 3 199 | 200 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3 201 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3 202 | 203 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 204 | rays_v = torch.matmul(pose[None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3 205 | rays_o = pose[None, :3, 3].expand(rays_v.shape) # batch_size, 3 206 | 207 | return torch.cat([rays_o, rays_v, color], dim=-1) # mask # batch_size, 10 208 | 209 | def get_gt_pose(self): 210 | return self.gt_pose_all 211 | 212 | def sample_matches(self, img_idx, pose_net, num_pairs=20, max_matches=5000): 213 | # ref frame 214 | pose = pose_net(img_idx) 215 | 216 | two_view = self.two_views_all[img_idx] 217 | num_src = len(two_view['src_idx']) 218 | 219 | match_list = [] 220 | intrinsic_src_list = [] 221 | pose_src_list = [] 222 | for id in torch.randperm(num_src)[:num_pairs]: 223 | src_idx = two_view['src_idx'][id] 224 | 225 | match = two_view['match'][id] 226 | 227 | # downsample matches if there are too much 228 | if match.shape[0] > max_matches: 229 | match = match[np.random.randint(match.shape[0], size=max_matches)] 230 | 231 | match = torch.from_numpy(match).float().to(self.device) 232 | 233 | pose_src = pose_net(src_idx) 234 | 235 | match_list.append(match) 236 | pose_src_list.append(pose_src) 237 | intrinsic_src_list.append(self.intrinsics_all[src_idx]) 238 | 239 | intrinsic = self.intrinsics_all[img_idx] 240 | return intrinsic, pose, intrinsic_src_list, pose_src_list, match_list 241 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': False, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | def contract_inf(pts): 9 | # pts 10 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1) 11 | outside_mask = (pts_norm >= 1.0) 12 | 13 | norm_pts = pts.clone() 14 | norm_pts[outside_mask, :] = (2 - 1.0 / pts_norm[outside_mask, None]) * \ 15 | (pts[outside_mask, :] / pts_norm[outside_mask, None]) 16 | 17 | return norm_pts 18 | 19 | 20 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 21 | class SDFNetwork(nn.Module): 22 | def __init__(self, 23 | d_in, 24 | d_out, 25 | d_hidden, 26 | n_layers, 27 | skip_in=(4,), 28 | multires=0, 29 | bias=0.5, 30 | scale=1, 31 | geometric_init=True, 32 | weight_norm=True, 33 | inside_outside=False 34 | ): 35 | super(SDFNetwork, self).__init__() 36 | 37 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 38 | 39 | self.embed_fn_fine = None 40 | if multires > 0: 41 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 42 | self.embed_fn_fine = embed_fn 43 | dims[0] = dims[0]+input_ch 44 | 45 | print('sdf network dims:', dims) 46 | 47 | self.num_layers = len(dims) 48 | self.skip_in = skip_in 49 | self.bias = bias 50 | self.scale = scale 51 | self.multires = multires 52 | self.dims = dims 53 | self.weight_norm = weight_norm 54 | self.inside_outside = inside_outside 55 | self.geometric_init = geometric_init 56 | 57 | # use Parameter so it could be checkpointed 58 | self.progress = torch.nn.Parameter(torch.tensor(0.), requires_grad=False) 59 | 60 | for l in range(0, self.num_layers - 1): 61 | if l + 1 in self.skip_in: 62 | out_dim = dims[l + 1] - dims[0] 63 | else: 64 | out_dim = dims[l + 1] 65 | 66 | lin = nn.Linear(dims[l], out_dim) 67 | 68 | if geometric_init: 69 | if l == self.num_layers - 2: 70 | if not inside_outside: 71 | torch.nn.init.normal_(lin.weight, mean=np.sqrt( 72 | np.pi) / np.sqrt(dims[l]), std=0.0001) 73 | torch.nn.init.constant_(lin.bias, -bias) 74 | else: 75 | torch.nn.init.normal_( 76 | lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 77 | torch.nn.init.constant_(lin.bias, bias) 78 | elif multires > 0 and l == 0: 79 | torch.nn.init.constant_(lin.bias, 0.0) 80 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 81 | torch.nn.init.normal_( 82 | lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 83 | elif multires > 0 and l in self.skip_in: 84 | torch.nn.init.constant_(lin.bias, 0.0) 85 
| torch.nn.init.normal_( 86 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 87 | torch.nn.init.constant_( 88 | lin.weight[:, -(dims[0] - 3):], 0.0) 89 | else: 90 | torch.nn.init.constant_(lin.bias, 0.0) 91 | torch.nn.init.normal_( 92 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 93 | 94 | if weight_norm: 95 | lin = nn.utils.weight_norm(lin) 96 | 97 | setattr(self, "lin" + str(l), lin) 98 | 99 | self.activation = nn.Softplus(beta=100) 100 | 101 | def forward(self, inputs): 102 | inputs = inputs * self.scale 103 | inputs = contract_inf(inputs) 104 | 105 | if self.embed_fn_fine is not None: 106 | embed = self.embed_fn_fine(inputs) 107 | inputs = torch.cat([inputs, embed], dim=-1) 108 | 109 | x = inputs 110 | for l in range(0, self.num_layers - 1): 111 | lin = getattr(self, "lin" + str(l)) 112 | 113 | if l in self.skip_in: 114 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 115 | 116 | x = lin(x) 117 | 118 | if l < self.num_layers - 2: 119 | x = self.activation(x) 120 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 121 | 122 | def sdf(self, x): 123 | return self.forward(x)[:, :1] 124 | 125 | def gradient(self, x): 126 | x.requires_grad_(True) 127 | y = self.sdf(x) 128 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 129 | gradients = torch.autograd.grad( 130 | outputs=y, 131 | inputs=x, 132 | grad_outputs=d_output, 133 | create_graph=True, 134 | retain_graph=True, 135 | only_inputs=True)[0] 136 | return gradients.unsqueeze(1) 137 | 138 | 139 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 140 | class RenderingNetwork(nn.Module): 141 | def __init__(self, 142 | d_feature, 143 | mode, 144 | d_in, 145 | d_out, 146 | d_hidden, 147 | n_layers, 148 | weight_norm=True, 149 | multires_view=0, 150 | squeeze_out=True): 151 | super().__init__() 152 | 153 | self.mode = mode 154 | self.squeeze_out = squeeze_out 155 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 156 | 157 | self.embedview_fn = None 158 | if multires_view > 0: 159 | embedview_fn, input_ch = get_embedder(multires_view) 160 | self.embedview_fn = embedview_fn 161 | dims[0] += (input_ch - 3) 162 | 163 | self.num_layers = len(dims) 164 | 165 | for l in range(0, self.num_layers - 1): 166 | out_dim = dims[l + 1] 167 | lin = nn.Linear(dims[l], out_dim) 168 | 169 | if weight_norm: 170 | lin = nn.utils.weight_norm(lin) 171 | 172 | setattr(self, "lin" + str(l), lin) 173 | 174 | self.relu = nn.ReLU() 175 | 176 | def forward(self, points, normals, view_dirs, feature_vectors): 177 | if self.embedview_fn is not None: 178 | view_dirs = self.embedview_fn(view_dirs) 179 | 180 | rendering_input = None 181 | 182 | if self.mode == 'idr': 183 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 184 | elif self.mode == 'no_view_dir': 185 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 186 | elif self.mode == 'no_normal': 187 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 188 | 189 | x = rendering_input 190 | 191 | for l in range(0, self.num_layers - 1): 192 | lin = getattr(self, "lin" + str(l)) 193 | 194 | x = lin(x) 195 | 196 | if l < self.num_layers - 2: 197 | x = self.relu(x) 198 | 199 | if self.squeeze_out: 200 | x = torch.sigmoid(x) 201 | return x 202 | 203 | 204 | class NeRF(nn.Module): 205 | def __init__(self, 206 | D=8, 207 | W=256, 208 | d_in=3, 209 | d_in_view=3, 210 | multires=0, 211 | multires_view=0, 212 | output_ch=4, 213 | skips=[4], 214 | use_viewdirs=False): 215 | 
super(NeRF, self).__init__() 216 | self.D = D 217 | self.W = W 218 | self.d_in = d_in 219 | self.d_in_view = d_in_view 220 | self.input_ch = 3 221 | self.input_ch_view = 3 222 | self.embed_fn = None 223 | self.embed_fn_view = None 224 | 225 | if multires > 0: 226 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 227 | self.embed_fn = embed_fn 228 | self.input_ch = input_ch 229 | 230 | if multires_view > 0: 231 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 232 | self.embed_fn_view = embed_fn_view 233 | self.input_ch_view = input_ch_view 234 | 235 | self.skips = skips 236 | self.use_viewdirs = use_viewdirs 237 | 238 | self.pts_linears = nn.ModuleList( 239 | [nn.Linear(self.input_ch, W)] + 240 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 241 | 242 | # Implementation according to the official code release 243 | # (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 244 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 245 | 246 | # Implementation according to the paper 247 | # self.views_linears = nn.ModuleList( 248 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 249 | 250 | if use_viewdirs: 251 | self.feature_linear = nn.Linear(W, W) 252 | self.alpha_linear = nn.Linear(W, 1) 253 | self.rgb_linear = nn.Linear(W // 2, 3) 254 | else: 255 | self.output_linear = nn.Linear(W, output_ch) 256 | 257 | def forward(self, input_pts, input_views): 258 | 259 | input_pts = contract_inf(input_pts) 260 | 261 | if self.embed_fn is not None: 262 | input_pts = self.embed_fn(input_pts) 263 | if self.embed_fn_view is not None: 264 | input_views = self.embed_fn_view(input_views) 265 | 266 | h = input_pts 267 | for i, l in enumerate(self.pts_linears): 268 | h = self.pts_linears[i](h) 269 | h = F.relu(h) 270 | if i in self.skips: 271 | h = torch.cat([input_pts, h], -1) 272 | 273 | if self.use_viewdirs: 274 | alpha = self.alpha_linear(h) 275 | feature = self.feature_linear(h) 276 | h = torch.cat([feature, input_views], -1) 277 | 278 | for i, l in enumerate(self.views_linears): 279 | h = self.views_linears[i](h) 280 | h = F.relu(h) 281 | 282 | rgb = self.rgb_linear(h) 283 | return alpha, rgb 284 | else: 285 | assert False 286 | 287 | 288 | class SingleVarianceNetwork(nn.Module): 289 | def __init__(self, init_val): 290 | super(SingleVarianceNetwork, self).__init__() 291 | self.register_parameter( 292 | 'variance', nn.Parameter(torch.tensor(init_val))) 293 | 294 | def forward(self, x): 295 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 296 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from kornia.geometry.conversions import axis_angle_to_rotation_matrix 4 | from kornia.geometry.conversions import rotation_matrix_to_axis_angle 5 | 6 | 7 | def make_c2w(r, t): 8 | """ 9 | :param r: (3, ) axis-angle torch tensor 10 | :param t: (3, ) translation vector torch tensor 11 | :return: (4, 4) 12 | """ 13 | c2w = torch.eye(4).type_as(r) 14 | R = axis_angle_to_rotation_matrix(r.unsqueeze(0))[0] # (3, 3) 15 | c2w[:3, :3] = R 16 | c2w[:3, 3] = t 17 | 18 | return c2w 19 | 20 | 21 | class LearnPose(nn.Module): 22 | def __init__(self, num_cams, init_c2w): 23 | """ 24 | :param num_cams: 25 | :param init_c2w: (N, 4, 4) torch tensor 26 | """ 
27 | super(LearnPose, self).__init__() 28 | 29 | self.num_cams = num_cams 30 | self.init_c2w = init_c2w.clone().detach() 31 | self.init_r = [] 32 | self.init_t = [] 33 | for idx in range(num_cams): 34 | r_init = rotation_matrix_to_axis_angle(self.init_c2w[idx][:3, :3].reshape([1, 3, 3])).reshape(-1) 35 | t_init = self.init_c2w[idx][:3, 3].reshape(-1) 36 | self.init_r.append(r_init) 37 | self.init_t.append(t_init) 38 | self.init_r = torch.stack(self.init_r) # nx3 39 | self.init_t = torch.stack(self.init_t) # nx3 40 | 41 | self.r = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=True) # (N, 3) 42 | self.t = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=True) # (N, 3) 43 | 44 | def get_init_pose(self, cam_id): 45 | return self.init_c2w[cam_id] 46 | 47 | def forward(self, cam_id): 48 | dr = self.r[cam_id] # (3, ) axis-angle 49 | dt = self.t[cam_id] # (3, ) 50 | 51 | r = dr + self.init_r[cam_id] 52 | t = dt + self.init_t[cam_id] 53 | c2w = make_c2w(r, t) # (4, 4) 54 | 55 | return c2w 56 | 57 | 58 | class PoRF(nn.Module): 59 | def __init__(self, num_cams, init_c2w=None, layers=2, mode='porf', scale=1e-6): 60 | """ 61 | :param num_cams: 62 | :param init_c2w: (N, 4, 4) torch tensor 63 | """ 64 | super(PoRF, self).__init__() 65 | self.num_cams = num_cams 66 | self.scale = scale 67 | self.mode = mode 68 | 69 | if init_c2w is not None: 70 | self.init_c2w = init_c2w.clone().detach() 71 | self.init_r = [] 72 | self.init_t = [] 73 | for idx in range(num_cams): 74 | r_init = rotation_matrix_to_axis_angle(self.init_c2w[idx][:3, :3].reshape([1, 3, 3])).reshape(-1) 75 | t_init = self.init_c2w[idx][:3, 3].reshape(-1) 76 | self.init_r.append(r_init) 77 | self.init_t.append(t_init) 78 | self.init_r = torch.stack(self.init_r) # nx3 79 | self.init_t = torch.stack(self.init_t) # nx3 80 | else: 81 | self.init_r = torch.zeros(size=(num_cams, 3), dtype=torch.float32) 82 | self.init_t = torch.zeros(size=(num_cams, 3), dtype=torch.float32) 83 | 84 | d_in = 7 # 1 cam_id + 6 pose 85 | 86 | activation_func = nn.ELU(inplace=True) 87 | 88 | self.layers = nn.Sequential(nn.Linear(d_in, 256), 89 | activation_func) 90 | for i in range(layers): 91 | self.layers.append(nn.Sequential(nn.Linear(256, 256), 92 | activation_func)) 93 | self.layers.append(nn.Linear(256, 6)) 94 | 95 | print('init_r range: ', [self.init_r.min(), self.init_r.max()]) 96 | print('init_t range: ', [self.init_t.min(), self.init_t.max()]) 97 | 98 | def get_init_pose(self, cam_id): 99 | return self.init_c2w[cam_id] 100 | 101 | def forward(self, cam_id): 102 | cam_id_tensor = torch.tensor([cam_id]).type_as(self.init_c2w) 103 | cam_id_tensor = (cam_id_tensor / self.num_cams) * 2 - 1 # range [-1, +1] 104 | 105 | init_r = self.init_r[cam_id] 106 | init_t = self.init_t[cam_id] 107 | 108 | if self.mode == 'porf': 109 | inputs = torch.cat([cam_id_tensor, init_r, init_t], dim=-1) 110 | elif self.mode == 'index_only': 111 | inputs = torch.cat([cam_id_tensor, torch.zeros_like(init_r), torch.zeros_like(init_t)], dim=-1) 112 | elif self.mode == 'pose_only': 113 | inputs = torch.cat([torch.zeros_like(cam_id_tensor), init_r, init_t], dim=-1) 114 | 115 | out = self.layers(inputs) * self.scale 116 | 117 | # cat pose 118 | r = out[:3] + self.init_r[cam_id] 119 | t = out[3:] + self.init_t[cam_id] 120 | c2w = make_c2w(r, t) # (4, 4) 121 | 122 | return c2w 123 | -------------------------------------------------------------------------------- /models/renderer.py: 
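Before the renderer, here is a minimal usage sketch for the two pose modules defined above in `models/networks.py`: `LearnPose` keeps an independent axis-angle/translation offset per image, whereas `PoRF` predicts the offsets for all images with one shared MLP conditioned on the normalised image index and the initial pose. The identity initial poses below are placeholders; in the repo they come from `Dataset.pose_all`, and `scale=1e-2` matches the value used in the conf files.

```python
import torch
from models.networks import LearnPose, PoRF

n_cams = 4
init_c2w = torch.eye(4).repeat(n_cams, 1, 1)   # placeholder initial camera-to-world poses

# Per-image pose offsets (baseline, use_porf = False in the confs)
pose_net = LearnPose(num_cams=n_cams, init_c2w=init_c2w)

# Shared pose-residual MLP (use_porf = True); scale matches `scale = 1e-2` in the confs
porf_net = PoRF(num_cams=n_cams, init_c2w=init_c2w, layers=2, mode='porf', scale=1e-2)

c2w = porf_net(2)                              # refined 4x4 camera-to-world pose for image 2
print(c2w.shape)                               # torch.Size([4, 4])
```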
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import mcubes 5 | 6 | 7 | def eff_distloss(w, t_vals, t_near, t_far): 8 | ''' 9 | Efficient O(N) realization of distortion loss. 10 | There are B rays each with N sampled points. 11 | w: Float tensor in shape [B,N]. Volume rendering weights of each point. 12 | m: Float tensor in shape [B,N]. Midpoint distance to camera of each point. 13 | interval: Scalar or float tensor in shape [B,N]. The query interval of each point. 14 | ''' 15 | # fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x) 16 | # s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)] 17 | # t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near) 18 | 19 | # normalize the distance 20 | s_vals = t_vals.clone().detach() 21 | 22 | t_near = t_near.expand_as(t_vals) 23 | t_far = t_far.expand_as(t_vals) 24 | 25 | # scale to 0 - far 26 | mask = t_vals < t_far 27 | s_vals[mask] = ((t_vals[mask] - t_near[mask]) / 28 | (t_far[mask] - t_near[mask]) * 0.5).clamp(0, 0.5) 29 | s_vals[~mask] = (1.0 - 0.5 / t_vals[~mask]).clamp(0.5, 1) 30 | s_vals = s_vals.clamp(0, 1) 31 | 32 | interval = s_vals[:, 1:] - s_vals[:, :-1] 33 | interval = torch.cat([interval, interval[:, -1, None]], dim=-1) 34 | m = s_vals + interval * 0.5 35 | 36 | loss_uni = (1/3) * (interval * w.pow(2)).sum(dim=-1).mean() 37 | wm = (w * m) 38 | w_cumsum = w.cumsum(dim=-1) 39 | wm_cumsum = wm.cumsum(dim=-1) 40 | loss_bi_0 = wm[..., 1:] * w_cumsum[..., :-1] 41 | loss_bi_1 = w[..., 1:] * wm_cumsum[..., :-1] 42 | loss_bi = 2 * (loss_bi_0 - loss_bi_1).sum(dim=-1).mean() 43 | return loss_bi + loss_uni 44 | 45 | 46 | def original_distloss(w, t_vals, t_near, t_far): 47 | ''' 48 | Original O(N^2) realization of distortion loss. 49 | There are B rays each with N sampled points. 50 | w: Float tensor in shape [B,N]. Volume rendering weights of each point. 51 | m: Float tensor in shape [B,N]. Midpoint distance to camera of each point. 52 | interval: Scalar or float tensor in shape [B,N]. The query interval of each point. 
53 | ''' 54 | 55 | # normalize the distance 56 | s_vals = t_vals.clone().detach() 57 | 58 | t_near = t_near.expand_as(t_vals) 59 | t_far = t_far.expand_as(t_vals) 60 | 61 | # scale to 0 - far 62 | mask = t_vals < t_far 63 | s_vals[mask] = ((t_vals[mask] - t_near[mask]) / 64 | (t_far[mask] - t_near[mask]) * 0.5).clamp(0, 0.5) 65 | s_vals[~mask] = (1.0 - 0.5 / t_vals[~mask]).clamp(0.5, 1) 66 | s_vals = s_vals.clamp(0, 1) 67 | 68 | interval = s_vals[:, 1:] - s_vals[:, :-1] 69 | interval = torch.cat([interval, interval[:, -1, None]], dim=-1) 70 | m = s_vals + interval * 0.5 71 | 72 | loss_uni = (1/3) * (interval * w.pow(2)).sum(-1).mean() 73 | ww = w.unsqueeze(-1) * w.unsqueeze(-2) # [B,N,N] 74 | mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() # [B,N,N] 75 | loss_bi = (ww * mm).sum((-1, -2)).mean() 76 | return loss_uni + loss_bi 77 | 78 | 79 | def extract_fields(bound_min, bound_max, resolution, query_func): 80 | N = 64 81 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 82 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 83 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 84 | 85 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 86 | with torch.no_grad(): 87 | for xi, xs in enumerate(X): 88 | for yi, ys in enumerate(Y): 89 | for zi, zs in enumerate(Z): 90 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 91 | pts = torch.cat( 92 | [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) 93 | val = query_func(pts).reshape( 94 | len(xs), len(ys), len(zs)).detach().cpu().numpy() 95 | u[xi * N: xi * N + len(xs), yi * N: yi * N + 96 | len(ys), zi * N: zi * N + len(zs)] = val 97 | return u 98 | 99 | 100 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): 101 | print('threshold: {}'.format(threshold)) 102 | u = extract_fields(bound_min, bound_max, resolution, query_func) 103 | vertices, triangles = mcubes.marching_cubes(u, threshold) 104 | b_max_np = bound_max 105 | b_min_np = bound_min 106 | 107 | vertices = vertices / (resolution - 1.0) * \ 108 | (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 109 | return vertices, triangles 110 | 111 | 112 | def sample_pdf(bins, weights, n_samples, det=False): 113 | # This implementation is from NeRF 114 | # Get pdf 115 | weights = weights + 1e-5 # prevent nans 116 | pdf = weights / torch.sum(weights, -1, keepdim=True) 117 | cdf = torch.cumsum(pdf, -1) 118 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 119 | # Take uniform samples 120 | if det: 121 | u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 122 | 0.5 / n_samples, steps=n_samples) 123 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 124 | else: 125 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 126 | 127 | # Invert CDF 128 | u = u.contiguous() 129 | inds = torch.searchsorted(cdf, u, right=True) 130 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 131 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 132 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 133 | 134 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 135 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 136 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 137 | 138 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 139 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 140 | t = (u - cdf_g[..., 0]) / denom 141 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 142 | 143 | return samples 144 | 145 | 146 | class NeuSRenderer: 147 | def __init__(self, 148 | sdf_network, 149 | deviation_network, 150 | render_network, 151 | n_samples, 152 | n_importance, 153 | n_outside, 154 | up_sample_steps, 155 | perturb): 156 | self.sdf_network = sdf_network 157 | self.deviation_network = deviation_network 158 | self.render_network = render_network 159 | self.n_samples = n_samples 160 | self.n_importance = n_importance 161 | self.n_outside = n_outside 162 | self.up_sample_steps = up_sample_steps 163 | self.perturb = perturb 164 | 165 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s): 166 | """ 167 | Up sampling give a fixed inv_s 168 | """ 169 | batch_size, n_samples = z_vals.shape 170 | pts = rays_o[:, None, :] + rays_d[:, None, :] * \ 171 | z_vals[..., :, None] # n_rays, n_samples, 3 172 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False) 173 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0) 174 | sdf = sdf.reshape(batch_size, n_samples) 175 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] 176 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] 177 | mid_sdf = (prev_sdf + next_sdf) * 0.5 178 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 179 | 180 | # ---------------------------------------------------------------------------------------------------------- 181 | # Use min value of [ cos, prev_cos ] 182 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more 183 | # robust when meeting situations like below: 184 | # 185 | # SDF 186 | # ^ 187 | # |\ -----x----... 188 | # | \ / 189 | # | x x 190 | # |---\----/-------------> 0 level 191 | # | \ / 192 | # | \/ 193 | # | 194 | # ---------------------------------------------------------------------------------------------------------- 195 | prev_cos_val = torch.cat( 196 | [torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1) 197 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) 198 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) 199 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere 200 | 201 | dist = (next_z_vals - prev_z_vals) 202 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 203 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5 204 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) 205 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s) 206 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 207 | weights = alpha * torch.cumprod( 208 | torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 209 | 210 | z_samples = sample_pdf( 211 | z_vals, weights, n_importance, det=True).detach() 212 | return z_samples 213 | 214 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 215 | batch_size, n_samples = z_vals.shape 216 | _, n_importance = new_z_vals.shape 217 | pts = rays_o[:, None, :] + \ 218 | rays_d[:, None, :] * new_z_vals[..., :, None] 219 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 220 | z_vals, index = torch.sort(z_vals, dim=-1) 221 | 222 | if not last: 223 | new_sdf = self.sdf_network.sdf( 224 | pts.reshape(-1, 3)).reshape(batch_size, n_importance) 225 | sdf = torch.cat([sdf, new_sdf], dim=-1) 226 | xx = torch.arange(batch_size)[:, None].expand( 227 | batch_size, n_samples + n_importance).reshape(-1) 228 | index = index.reshape(-1) 229 | sdf = sdf[(xx, index)].reshape( 230 | batch_size, n_samples + n_importance) 231 | 232 | return z_vals, sdf 233 | 234 | def render_core(self, 235 | rays_o, 236 | rays_d, 237 | z_vals, 238 | sample_dist, 239 | near, 240 | far, 241 | background_rgb=None, 242 | cos_anneal_ratio=0.0): 243 | batch_size, n_samples = z_vals.shape 244 | 245 | # Section length 246 | dists = z_vals[..., 1:] - z_vals[..., :-1] 247 | dists = torch.cat([dists, torch.Tensor( 248 | [sample_dist]).expand(dists[..., :1].shape)], -1) 249 | mid_z_vals = z_vals + dists * 0.5 250 | 251 | # Section midpoints 252 | pts = rays_o[:, None, :] + rays_d[:, None, :] * \ 253 | mid_z_vals[..., :, None] # n_rays, n_samples, 3 254 | dirs = rays_d[:, None, :].expand(pts.shape) 255 | 256 | pts = pts.reshape(-1, 3) 257 | dirs = dirs.reshape(-1, 3) 258 | 259 | # determin inside and outside pts 260 | pts_norm = torch.linalg.norm( 261 | pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 262 | inside_sphere = (pts_norm < 1.0).float().detach() 263 | relax_inside_sphere = (pts_norm < 2.0).float().detach() 264 | 265 | sdf_nn_output = self.sdf_network(pts) 266 | sdf = sdf_nn_output[:, :1] 267 | sdf_features = sdf_nn_output[:, 1:] 268 | 269 | gradients = self.sdf_network.gradient(pts).squeeze() 270 | 271 | inv_s = self.deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 272 | inv_s = inv_s.expand(batch_size * n_samples, 1) 273 | 274 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 275 | 276 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 277 | # the cos value "not dead" at the beginning training iterations, for better convergence. 278 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 279 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 280 | 281 | # Estimate signed distances at section points 282 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 283 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 284 | 285 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 286 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 287 | 288 | p = prev_cdf - next_cdf 289 | c = prev_cdf 290 | 291 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, 292 | n_samples).clip(0.0, 1.0) 293 | weights = alpha * \ 294 | torch.cumprod( 295 | torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 296 | weights_sum = weights.sum(dim=-1, keepdim=True) 297 | 298 | sampled_color = self.render_network(pts, 299 | gradients, 300 | dirs, 301 | sdf_features).reshape(batch_size, n_samples, 3) 302 | 303 | color = (sampled_color * weights[:, :, None]).sum(dim=1) 304 | 305 | if background_rgb is not None: # Fixed background, usually black 306 | color = color + background_rgb * (1.0 - weights_sum) 307 | 308 | # Eikonal loss 309 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, 310 | dim=-1) - 1.0) ** 2 311 | gradient_error = (relax_inside_sphere * gradient_error).sum() / \ 312 | (relax_inside_sphere.sum() + 1e-5) 313 | 314 | # dist loss 315 | dist_loss = eff_distloss(weights, mid_z_vals, near, far) 316 | 317 | return { 318 | 'color': color, 319 | 'pts': pts.reshape(batch_size, n_samples, 3), 320 | 'sdf': sdf, 321 | 'dists': dists, 322 | 'dist_loss': dist_loss, 323 | 'gradients': gradients.reshape(batch_size, n_samples, 3), 324 | 's_val': 1.0 / inv_s, 325 | 'mid_z_vals': mid_z_vals, 326 | 'weights': weights, 327 | 'cdf': c.reshape(batch_size, -1), 328 | 'gradient_error': gradient_error, 329 | 'inside_sphere': inside_sphere 330 | } 331 | 332 | def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0): 333 | batch_size = len(rays_o) 334 | # Assuming the region of interest is a unit sphere 335 | sample_dist = 2.0 / self.n_samples 336 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 337 | z_vals = near + (far - near) * z_vals[None, :] 338 | 339 | z_vals_outside = None 340 | if self.n_outside > 0: 341 | z_vals_outside = torch.linspace( 342 | 1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 343 | 344 | n_samples = self.n_samples 345 | perturb = self.perturb 346 | 347 | if perturb_overwrite >= 0: 348 | perturb = perturb_overwrite 349 | if perturb > 0: 350 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 351 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 352 | 353 | if self.n_outside > 0: 354 | mids = .5 * (z_vals_outside[..., 1:] + 355 | z_vals_outside[..., :-1]) 356 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 357 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 358 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 359 | z_vals_outside = lower[None, :] + \ 360 | (upper - lower)[None, :] * t_rand 361 | 362 | if self.n_outside > 0: 363 | z_vals_outside = far / \ 364 | torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 365 | 366 | # Up sample 367 | if self.n_importance > 0: 368 | with torch.no_grad(): 369 | pts = rays_o[:, None, :] + \ 370 | rays_d[:, None, :] * z_vals[..., :, None] 371 | sdf = self.sdf_network.sdf( 372 | pts.reshape(-1, 3)).reshape(batch_size, n_samples) 373 | 374 | for i in range(self.up_sample_steps): 375 | new_z_vals = self.up_sample(rays_o, 376 | rays_d, 377 | z_vals, 378 | sdf, 379 | self.n_importance // self.up_sample_steps, 380 | 64 * 2**i) 381 | z_vals, sdf = self.cat_z_vals(rays_o, 382 | rays_d, 383 | z_vals, 384 | new_z_vals, 385 | sdf, 386 | last=(i + 1 == self.up_sample_steps)) 387 | 388 | n_samples = n_samples + self.n_importance 389 | 390 | # Background model 391 | if self.n_outside > 0: 392 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 393 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 394 | z_vals = z_vals_feed 395 | 396 | n_samples = n_samples + self.n_outside 397 | 398 | # Render core 399 | render_core_out = self.render_core(rays_o, 400 | rays_d, 401 | z_vals, 402 | 
sample_dist, 403 | near, 404 | far, 405 | background_rgb=background_rgb, 406 | cos_anneal_ratio=cos_anneal_ratio) 407 | 408 | color = render_core_out['color'] 409 | weights = render_core_out['weights'] 410 | pts = render_core_out['pts'] 411 | weights_sum = weights.sum(dim=-1, keepdim=True) 412 | gradients = render_core_out['gradients'] 413 | s_val = render_core_out['s_val'].reshape(batch_size, n_samples) 414 | 415 | surface_pts = torch.sum(weights[:, :, None] * pts, dim=1) # B 3 416 | 417 | return { 418 | 'color': color, 419 | 'surface_pts': surface_pts, 420 | 'pts': pts, 421 | 'dist_loss': render_core_out['dist_loss'], 422 | 's_val': s_val.mean(dim=-1, keepdim=True), 423 | 'cdf': render_core_out['cdf'], 424 | 'weight_sum': weights_sum, 425 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0], 426 | 'gradients': gradients, 427 | 'weights': weights, 428 | 'gradient_error': render_core_out['gradient_error'], 429 | 'inside_sphere': render_core_out['inside_sphere'] 430 | } 431 | 432 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0): 433 | return extract_geometry(bound_min, 434 | bound_max, 435 | resolution=resolution, 436 | threshold=threshold, 437 | query_func=lambda pts: -self.sdf_network.sdf(pts)) 438 | -------------------------------------------------------------------------------- /preprocess_data/export_colmap_matches.py: -------------------------------------------------------------------------------- 1 | from path import Path 2 | import numpy as np 3 | import sqlite3 4 | import os 5 | 6 | 7 | def pair_id_to_image_ids(pair_id): 8 | image_id2 = pair_id % 2147483647 9 | image_id1 = (pair_id - image_id2) / 2147483647 10 | return image_id1, image_id2 11 | 12 | 13 | def image_ids_to_pair(image_id1, image_id2): 14 | pair_id = image_id2 + 2147483647 * image_id1 15 | return pair_id 16 | 17 | 18 | def get_keypoints(cursor, image_id): 19 | cursor.execute("SELECT * FROM keypoints WHERE image_id = ?;", (image_id,)) 20 | image_idx, n_rows, n_columns, raw_data = cursor.fetchone() 21 | kypnts = np.frombuffer(raw_data, dtype=np.float32).reshape(n_rows, n_columns).copy() 22 | kypnts = kypnts[:, 0:2] 23 | return kypnts 24 | 25 | 26 | def process_one_scene(scene_dir): 27 | 28 | filename_db = Path(scene_dir)/'database.db' 29 | outdir = scene_dir/'colmap_matches' 30 | print("Opening database: " + filename_db) 31 | 32 | if not os.path.exists(filename_db): 33 | print('Error db does not exist!') 34 | exit() 35 | 36 | if not os.path.exists(outdir): 37 | os.mkdir(outdir) 38 | 39 | print(f'Clean old matches in {outdir}') 40 | for f in Path(outdir).files('*'): 41 | os.remove(f) 42 | 43 | connection = sqlite3.connect(filename_db) 44 | cursor = connection.cursor() 45 | 46 | list_image_ids = [] 47 | img_ids_to_names_dict = {} 48 | cursor.execute('SELECT image_id, name, cameras.width, cameras.height FROM images LEFT JOIN cameras ON images.camera_id == cameras.camera_id;') 49 | for row in cursor: 50 | image_idx, name, width, height = row 51 | list_image_ids.append(image_idx) 52 | img_ids_to_names_dict[image_idx] = name 53 | 54 | num_image_ids = len(list_image_ids) 55 | 56 | # Iterate over entries in the two-view geometry table 57 | cursor.execute('SELECT pair_id, rows, cols, data FROM two_view_geometries;') 58 | all_matches = {} 59 | for row in cursor: 60 | pair_id = row[0] 61 | rows = row[1] 62 | cols = row[2] 63 | raw_data = row[3] 64 | if (rows < 5): 65 | continue 66 | 67 | matches = np.frombuffer(raw_data, dtype=np.uint32).reshape(rows, cols) 68 | 69 | if matches.shape[0] < 5: 70 | 
continue 71 | 72 | all_matches[pair_id] = matches 73 | 74 | for key in all_matches: 75 | pair_id = key 76 | matches = all_matches[key] 77 | 78 | # # skip if too few matches are given 79 | # if matches.shape[0] < 300: 80 | # continue 81 | 82 | id1, id2 = pair_id_to_image_ids(pair_id) 83 | image_name1 = img_ids_to_names_dict[id1] 84 | image_name2 = img_ids_to_names_dict[id2] 85 | 86 | keys1 = get_keypoints(cursor, id1) 87 | keys2 = get_keypoints(cursor, id2) 88 | 89 | match_positions = np.empty([matches.shape[0], 4]) 90 | for i in range(0, matches.shape[0]): 91 | match_positions[i, :] = np.array([keys1[matches[i, 0]][0], keys1[matches[i, 0]][1], keys2[matches[i, 1]][0], keys2[matches[i, 1]][1]]) 92 | 93 | # print(match_positions.shape) 94 | # outfile = os.path.join(outdir, image_name1.split("/")[0].split(".jpg")[0] + "_" + image_name2.split("/")[0].split(".jpg")[0] + ".txt") 95 | outfile = os.path.join(outdir, '{:06d}_{:06d}.txt'.format(int(id1), int(id2))) 96 | 97 | np.savetxt(outfile, match_positions, delimiter=' ') 98 | 99 | # reverse and save 100 | match_positions_reverse = np.concatenate([match_positions[:, 2:4], match_positions[:, 0:2]], axis=1) 101 | # outfile = os.path.join(outdir, image_name2.split("/")[0].split(".jpg")[0] + "_" + image_name1.split("/")[0].split(".jpg")[0] + ".txt") 102 | outfile = os.path.join(outdir, '{:06d}_{:06d}.txt'.format(int(id2), int(id1))) 103 | np.savetxt(outfile, match_positions_reverse, delimiter=' ') 104 | 105 | cursor.close() 106 | connection.close() 107 | 108 | for idx in range(num_image_ids): 109 | 110 | two_view = {} 111 | two_view["src_idx"] = [] 112 | two_view["match"] = [] 113 | 114 | files = sorted(outdir.files('{:06d}_*.txt'.format(idx+1))) 115 | for f in files: 116 | j = int(os.path.basename(f)[7:13])-1 117 | 118 | one_pair = np.loadtxt(f) 119 | 120 | two_view["src_idx"].append(j) 121 | two_view["match"].append(one_pair) 122 | 123 | np.savez(outdir/'{:06d}.npz'.format(idx), **two_view) 124 | 125 | 126 | if __name__ == "__main__": 127 | 128 | # data_dir = Path('./porf_data/dtu') 129 | data_dir = Path('./porf_data/mobilebrick_test') 130 | scene_dirs = sorted(data_dir.dirs()) 131 | 132 | for scene in scene_dirs: 133 | process_one_scene(scene_dir=scene) 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh 2 | tensorflow==2.9.1 3 | pyhocon==0.3.57 4 | kornia 5 | opencv_python 6 | tqdm 7 | scipy 8 | path 9 | matplotlib 10 | PyMCubes -------------------------------------------------------------------------------- /scripts/train_sift_dtu.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # for SCENE in 'scan37' 'scan65' 'scan69' 'scan83' 'scan97' 'scan105' 'scan110' 'scan118' 'scan106' 'scan114' 'scan122' 'scan24' 'scan40' 'scan55' 'scan63'; 4 | 5 | for SCENE in 'scan37'; 6 | do 7 | python train.py \ 8 | --mode train --conf confs/dtu_sift_porf.conf \ 9 | --case $SCENE 10 | 11 | # python train.py \ 12 | # --mode train --conf confs/dtu_sift_pose.conf \ 13 | # --case $SCENE 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_sift_mobilebrick.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | 4 | # for SCENE in 'aston' 'audi' 'beetles' 'big_ben' 'boat' 'bridge' 'cabin' 'convertible' 'ferrari' 'jeep' 
'castle' 'london_bus' 'colosseum' 'camera' 'motorcycle' 'porsche' 'satellite' 'space_shuttle'; 5 | for SCENE in 'aston'; 6 | do 7 | # python train.py \ 8 | # --mode train --conf confs/mobilebrick_sift_porf.conf \ 9 | # --case $SCENE 10 | 11 | python train.py \ 12 | --mode train --conf confs/mobilebrick_sift_pose.conf \ 13 | --case $SCENE 14 | 15 | done -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | import numpy as np 5 | import cv2 as cv 6 | import trimesh 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | from tqdm import tqdm 11 | from pyhocon import ConfigFactory 12 | from models.dataset import Dataset 13 | from models.fields import SDFNetwork, SingleVarianceNetwork, RenderingNetwork 14 | from models.renderer import NeuSRenderer 15 | from models.networks import LearnPose, PoRF 16 | import utils 17 | 18 | print(torch.__version__) 19 | 20 | # torch.autograd.set_detect_anomaly(True) 21 | 22 | 23 | class PoseRunner: 24 | def __init__(self, conf_path, mode='train', case='CASE_NAME'): 25 | self.device = torch.device('cuda') 26 | 27 | # Configuration 28 | self.conf_path = conf_path 29 | f = open(self.conf_path) 30 | conf_text = f.read() 31 | conf_text = conf_text.replace('CASE_NAME', case) 32 | f.close() 33 | 34 | self.conf = ConfigFactory.parse_string(conf_text) 35 | self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case) 36 | self.base_exp_dir = self.conf['general.base_exp_dir'] 37 | os.makedirs(self.base_exp_dir, exist_ok=True) 38 | self.dataset = Dataset(self.conf['dataset']) 39 | self.iter_step = 0 40 | 41 | # Training parameters 42 | self.end_iter = self.conf['train.pose_end_iter'] 43 | self.val_freq = self.conf.get_int('train.pose_val_freq') 44 | self.report_freq = self.conf.get_int('train.report_freq') 45 | self.batch_size = self.conf.get_int('train.batch_size') 46 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') 47 | self.learning_rate = self.conf.get_float('train.learning_rate') 48 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') 49 | self.pose_learning_rate = self.conf.get_float('train.pose_learning_rate') 50 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') 51 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) 52 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) 53 | 54 | # porf parameters 55 | self.use_porf = self.conf.get_bool('train.use_porf') 56 | self.inlier_threshold = self.conf.get_float('train.inlier_threshold') 57 | self.num_pairs = self.conf.get_int('train.num_pairs') 58 | 59 | # Weights 60 | self.color_loss_weight = self.conf.get_float('train.color_loss_weight') 61 | self.igr_weight = self.conf.get_float('train.igr_weight') 62 | self.epipolar_loss_weight = self.conf.get_float('train.epipolar_loss_weight') 63 | self.mode = mode 64 | 65 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'pose_logs')) 66 | 67 | # Networks 68 | params_to_train = [] 69 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) 70 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) 71 | self.render_network = RenderingNetwork(**self.conf['model.render_network']).to(self.device) 72 | params_to_train += 
list(self.sdf_network.parameters()) 73 | params_to_train += list(self.deviation_network.parameters()) 74 | params_to_train += list(self.render_network.parameters()) 75 | 76 | optim_params = [{'params': params_to_train, 'lr': self.learning_rate}] 77 | self.optimizer = torch.optim.Adam(optim_params) 78 | 79 | self.renderer = NeuSRenderer(self.sdf_network, 80 | self.deviation_network, 81 | self.render_network, 82 | **self.conf['model.neus_renderer']) 83 | 84 | # # pose optimization 85 | if self.use_porf: 86 | self.pose_param_net = PoRF( 87 | self.dataset.n_images, 88 | init_c2w=self.dataset.pose_all, 89 | scale=self.conf.get_float('train.scale') 90 | ).to(self.device) 91 | else: 92 | self.pose_param_net = LearnPose( 93 | self.dataset.n_images, 94 | init_c2w=self.dataset.pose_all 95 | ).to(self.device) 96 | 97 | self.optimizer_pose = torch.optim.Adam(self.pose_param_net.parameters(), 98 | lr=self.pose_learning_rate) 99 | 100 | # validate pose for initial pose err analysis 101 | if self.iter_step == 0: 102 | self.validate_pose(initial_pose=True) 103 | 104 | def train(self): 105 | self.update_learning_rate() 106 | res_step = self.end_iter - self.iter_step 107 | 108 | for iter_i in tqdm(range(res_step)): 109 | 110 | self.update_image_index() 111 | 112 | intrinsic, pose, intrinsic_src_list, pose_src_list, match_list = self.dataset.sample_matches(self.img_idx, 113 | self.pose_param_net) 114 | 115 | P_src_list = [] 116 | for cam, p in zip(intrinsic_src_list, pose_src_list): 117 | P_src_list.append(utils.compute_P_from_KT(cam, p)) 118 | 119 | # match 120 | avg_inlier_rate, epipolar_loss = utils.evaluate_pose(intrinsic, 121 | pose, 122 | P_src_list, 123 | match_list, 124 | self.num_pairs, 125 | self.inlier_threshold) 126 | 127 | # neus 128 | data = self.dataset.gen_random_rays_at(self.img_idx, 129 | self.batch_size, 130 | pose 131 | ) 132 | 133 | rays_o, rays_d = data[:, :3], data[:, 3: 6] 134 | true_rgb = data[:, 6: 9] 135 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) 136 | 137 | background_rgb = None 138 | if self.use_white_bkgd: 139 | background_rgb = torch.ones([1, 3]) 140 | 141 | render_out = self.renderer.render(rays_o, 142 | rays_d, 143 | near, 144 | far, 145 | background_rgb=background_rgb, 146 | cos_anneal_ratio=self.get_cos_anneal_ratio()) 147 | 148 | color = render_out['color'] 149 | s_val = render_out['s_val'] 150 | cdf = render_out['cdf'] 151 | gradient_error = render_out['gradient_error'] 152 | weight_max = render_out['weight_max'] 153 | dist_loss = render_out['dist_loss'] 154 | 155 | mask = torch.ones_like(color[:, :1]) 156 | mask_sum = mask.sum() 157 | 158 | color_error = (color - true_rgb) * mask 159 | color_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum 160 | psnr = 20.0 * torch.log10(1.0 / (((color - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt()) 161 | 162 | eikonal_loss = gradient_error 163 | 164 | loss = color_loss * self.color_loss_weight +\ 165 | eikonal_loss * self.igr_weight +\ 166 | dist_loss * 0.001 +\ 167 | epipolar_loss * self.epipolar_loss_weight 168 | 169 | self.optimizer.zero_grad() 170 | self.optimizer_pose.zero_grad() 171 | loss.backward() 172 | 173 | self.optimizer.step() 174 | self.optimizer_pose.step() 175 | 176 | self.iter_step += 1 177 | 178 | self.writer.add_scalar('Loss/loss', loss, self.iter_step) 179 | self.writer.add_scalar('Loss/color_loss', color_loss, self.iter_step) 180 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step) 181 | 
self.writer.add_scalar('Loss/dist_loss', dist_loss, self.iter_step) 182 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step) 183 | self.writer.add_scalar('Statistics/cdf', cdf[:, :1].mean(), self.iter_step) 184 | self.writer.add_scalar('Statistics/weight_max', weight_max.mean(), self.iter_step) 185 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step) 186 | self.writer.add_scalar('Statistics/inlier_rate', avg_inlier_rate, self.iter_step) 187 | self.writer.add_scalar('Loss/epipolar_loss', epipolar_loss, self.iter_step) 188 | 189 | # check pose grad for debug if not using porf 190 | if not self.use_porf: 191 | r_grad_norms = torch.linalg.norm(self.pose_param_net.r.grad, 192 | dim=-1, 193 | keepdim=True).expand_as(self.pose_param_net.r.grad) 194 | 195 | t_grad_norms = torch.linalg.norm(self.pose_param_net.t.grad, 196 | dim=-1, 197 | keepdim=True).expand_as(self.pose_param_net.t.grad) 198 | r_grad = r_grad_norms[r_grad_norms > 0].mean() 199 | t_grad = t_grad_norms[t_grad_norms > 0].mean() 200 | 201 | self.writer.add_scalar('Statistics/r_grad', r_grad, self.iter_step) 202 | self.writer.add_scalar('Statistics/t_grad', t_grad, self.iter_step) 203 | 204 | if self.iter_step % self.report_freq == 0: 205 | print(self.base_exp_dir) 206 | print('iter:{:8>d} loss = {} lr={}'.format(self.iter_step, loss, self.optimizer.param_groups[0]['lr'])) 207 | 208 | if self.iter_step % self.val_freq == 0: 209 | self.validate_pose() 210 | 211 | self.update_learning_rate() 212 | 213 | self.save_checkpoint() 214 | self.validate_image() 215 | self.validate_mesh() 216 | 217 | def update_image_index(self): 218 | self.img_idx = np.random.randint(self.dataset.n_images) 219 | 220 | def get_cos_anneal_ratio(self): 221 | if self.anneal_end == 0.0: 222 | return 1.0 223 | else: 224 | return np.min([1.0, self.iter_step / self.anneal_end]) 225 | 226 | def update_learning_rate(self): 227 | if self.iter_step < self.warm_up_end: 228 | learning_factor = self.iter_step / self.warm_up_end 229 | else: 230 | alpha = self.learning_rate_alpha 231 | progress = (self.iter_step - self.warm_up_end) / \ 232 | (self.end_iter - self.warm_up_end) 233 | learning_factor = (np.cos(np.pi * progress) + 234 | 1.0) * 0.5 * (1 - alpha) + alpha 235 | 236 | for g in self.optimizer.param_groups: 237 | g['lr'] = self.learning_rate * learning_factor 238 | 239 | def save_checkpoint(self): 240 | checkpoint = { 241 | 'sdf_network': self.sdf_network.state_dict(), 242 | 'variance_network': self.deviation_network.state_dict(), 243 | 'render_network': self.render_network.state_dict(), 244 | 'pose_param_net': self.pose_param_net.state_dict(), 245 | 'optimizer': self.optimizer.state_dict(), 246 | 'iter_step': self.iter_step, 247 | } 248 | 249 | out_dir = os.path.join(self.base_exp_dir, 'pose_checkpoints') 250 | os.makedirs(out_dir, exist_ok=True) 251 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 252 | 253 | def validate_image(self, idx=-1, resolution_level=-1): 254 | if idx < 0: 255 | idx = np.random.randint(self.dataset.n_images) 256 | 257 | print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx)) 258 | 259 | if resolution_level < 0: 260 | resolution_level = self.validate_resolution_level 261 | rays_o, rays_d = self.dataset.gen_rays_at(idx, 262 | self.pose_param_net, 263 | resolution_level=resolution_level) 264 | H, W, _ = rays_o.shape 265 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 266 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 267 | 268 | out_rgb = [] 
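# Descriptive note (added): the loop below renders the validation view in
# batches of self.batch_size rays. Per-ray normals are approximated by the
# volume-rendering-weighted sum of SDF gradients (masked to the unit
# sphere) and are later rotated into the camera frame of view `idx` for
# visualisation.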
269 | out_normal = [] 270 | 271 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 272 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 273 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 274 | 275 | render_out = self.renderer.render(rays_o_batch, 276 | rays_d_batch, 277 | near, 278 | far, 279 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 280 | background_rgb=background_rgb) 281 | 282 | def feasible(key): return (key in render_out) and ( 283 | render_out[key] is not None) 284 | 285 | if feasible('color'): 286 | out_rgb.append(render_out['color'].detach().cpu().numpy()[..., :3]) 287 | if feasible('gradients') and feasible('weights'): 288 | n_samples = render_out['gradients'].shape[1] 289 | normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] 290 | if feasible('inside_sphere'): 291 | normals = normals * render_out['inside_sphere'][..., None] 292 | normals = normals.sum(dim=1).detach().cpu().numpy() 293 | out_normal.append(normals) 294 | del render_out 295 | 296 | img = None 297 | if len(out_rgb) > 0: 298 | img = (np.concatenate(out_rgb, axis=0).reshape( 299 | [H, W, 3, -1]) * 256).clip(0, 255) 300 | 301 | normal_img = None 302 | if len(out_normal) > 0: 303 | normal_img = np.concatenate(out_normal, axis=0) 304 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) 305 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None] 306 | ).reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255) 307 | 308 | os.makedirs(os.path.join(self.base_exp_dir, 'validations'), exist_ok=True) 309 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True) 310 | 311 | for i in range(img.shape[-1]): 312 | if len(out_rgb) > 0: 313 | cv.imwrite(os.path.join(self.base_exp_dir, 314 | 'validations', 315 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 316 | np.concatenate([img[..., i], 317 | self.dataset.image_at(idx, resolution_level=resolution_level)])) 318 | if len(out_normal) > 0: 319 | cv.imwrite(os.path.join(self.base_exp_dir, 320 | 'normals', 321 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 322 | normal_img[..., i]) 323 | 324 | def validate_mesh(self, world_space=True, resolution=256, threshold=0.0): 325 | bound_min = self.dataset.object_bbox_min 326 | bound_max = self.dataset.object_bbox_max 327 | 328 | vertices, triangles =\ 329 | self.renderer.extract_geometry( 330 | bound_min, bound_max, resolution=resolution, threshold=threshold) 331 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) 332 | 333 | if world_space: 334 | vertices = vertices * \ 335 | self.dataset.scale_mats_np[0][0, 0] + \ 336 | self.dataset.scale_mats_np[0][:3, 3][None] 337 | 338 | mesh = trimesh.Trimesh(vertices, triangles) 339 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', 340 | '{:0>8d}.ply'.format(self.iter_step))) 341 | 342 | logging.info('End') 343 | 344 | def validate_pose(self, initial_pose=False): 345 | pose_dir = os.path.join( 346 | self.base_exp_dir, 'poses_{:06d}'.format(self.iter_step)) 347 | os.makedirs(pose_dir, exist_ok=True) 348 | 349 | scale_mat = self.dataset.object_scale_mat 350 | 351 | pred_poses = [] 352 | for idx in range(self.dataset.n_images): 353 | if initial_pose: 354 | p = self.pose_param_net.get_init_pose(idx) 355 | else: 356 | p = self.pose_param_net(idx) 357 | p = p.detach().cpu().numpy() 358 | # scale and transform 359 | t = scale_mat @ p[:, 3].T 360 | p = np.concatenate([p[:, :3], t[:, None]], axis=1) 361 | pred_poses.append(p) 362 | pred_poses 
= np.stack(pred_poses) 363 | 364 | np.savetxt(os.path.join(pose_dir, 'refined_pose.txt'), 365 | pred_poses.reshape(-1, 16), 366 | fmt='%.8f', delimiter=' ') 367 | 368 | gt_poses = self.dataset.get_gt_pose() # np, [n44] 369 | 370 | pred_poses = utils.pose_alignment(pred_poses, gt_poses) 371 | 372 | # ate 373 | ate_rots, ate_trans = utils.compute_ATE(gt_poses, pred_poses) 374 | ate_errs = np.stack([ate_rots, ate_trans], axis=-1) 375 | ate_errs = np.concatenate([ate_errs, np.mean(ate_errs, axis=0).reshape(-1, 2)], axis=0) 376 | 377 | self.writer.add_scalar('Val/ate_rot', np.mean(ate_errs, axis=0)[0] / 3.14 * 180, self.iter_step) 378 | self.writer.add_scalar('Val/ate_trans', np.mean(ate_errs, axis=0)[1], self.iter_step) 379 | 380 | 381 | if __name__ == '__main__': 382 | print('Hello Wooden') 383 | 384 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 385 | 386 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 387 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 388 | 389 | parser = argparse.ArgumentParser() 390 | parser.add_argument('--conf', type=str, default='./confs/base.conf') 391 | parser.add_argument('--mode', type=str, default='train') 392 | parser.add_argument('--mcube_threshold', type=float, default=0.0) 393 | parser.add_argument('--gpu', type=int, default=0) 394 | parser.add_argument('--case', type=str, default='') 395 | 396 | args = parser.parse_args() 397 | 398 | torch.cuda.set_device(args.gpu) 399 | runner = PoseRunner(args.conf, args.mode, args.case) 400 | 401 | if args.mode == 'train': 402 | runner.train() 403 | elif args.mode == 'validate_pose': 404 | runner.validate_pose() 405 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import kornia.geometry as KG 4 | from scipy.spatial.transform import Rotation 5 | import torch.nn.functional as F 6 | 7 | 8 | def compute_P_from_KT(K, T): 9 | 10 | P = torch.matmul(K, torch.linalg.inv(T)) 11 | return P 12 | 13 | 14 | def umeyama_alignment(x, y, with_scale=True): 15 | """ 16 | Computes the least squares solution parameters of an Sim(m) matrix 17 | that minimizes the distance between a set of registered points. 18 | Umeyama, Shinji: Least-squares estimation of transformation parameters 19 | between two point patterns. IEEE PAMI, 1991 20 | :param x: mxn matrix of points, m = dimension, n = nr. of data points 21 | :param y: mxn matrix of points, m = dimension, n = nr. of data points 22 | :param with_scale: set to True to align also the scale (default: 1.0 scale) 23 | :return: r, t, c - rotation matrix, translation vector and scale factor 24 | """ 25 | if x.shape != y.shape: 26 | assert False, "x.shape not equal to y.shape" 27 | 28 | # m = dimension, n = nr. of data points 29 | m, n = x.shape 30 | 31 | # means, eq. 34 and 35 32 | mean_x = x.mean(axis=1) 33 | mean_y = y.mean(axis=1) 34 | 35 | # variance, eq. 36 36 | # "transpose" for column subtraction 37 | sigma_x = 1.0 / n * (np.linalg.norm(x - mean_x[:, np.newaxis])**2) 38 | 39 | # covariance matrix, eq. 38 40 | outer_sum = np.zeros((m, m)) 41 | for i in range(n): 42 | outer_sum += np.outer((y[:, i] - mean_y), (x[:, i] - mean_x)) 43 | cov_xy = np.multiply(1.0 / n, outer_sum) 44 | 45 | # SVD (text betw. eq. 38 and 39) 46 | u, d, v = np.linalg.svd(cov_xy) 47 | 48 | # S matrix, eq. 
43 49 | s = np.eye(m) 50 | if np.linalg.det(u) * np.linalg.det(v) < 0.0: 51 | # Ensure a RHS coordinate system (Kabsch algorithm). 52 | s[m - 1, m - 1] = -1 53 | 54 | # rotation, eq. 40 55 | r = u.dot(s).dot(v) 56 | 57 | # scale & translation, eq. 42 and 41 58 | c = 1 / sigma_x * np.trace(np.diag(d).dot(s)) if with_scale else 1.0 59 | t = mean_y - np.multiply(c, r.dot(mean_x)) 60 | 61 | return r, t, c 62 | 63 | 64 | def pose_alignment(poses_pred, poses_gt): 65 | 66 | xyz_result = poses_pred[:, :3, 3].T 67 | xyz_gt = poses_gt[:, :3, 3].T 68 | 69 | r, t, scale = umeyama_alignment(xyz_result, xyz_gt, with_scale=True) 70 | 71 | align_transformation = np.eye(4) 72 | align_transformation[:3:, :3] = r 73 | align_transformation[:3, 3] = t 74 | 75 | for cnt in range(poses_pred.shape[0]): 76 | poses_pred[cnt][:3, 3] *= scale 77 | poses_pred[cnt] = align_transformation @ poses_pred[cnt] 78 | 79 | return poses_pred 80 | 81 | 82 | def rotation_error(pose_error): 83 | """Compute rotation error 84 | Args: 85 | pose_error (4x4 array): relative pose error 86 | Returns: 87 | rot_error (float): rotation error 88 | """ 89 | r_diff = Rotation.from_matrix(pose_error[:3, :3]) 90 | pose_error = r_diff.as_matrix() 91 | a = pose_error[0, 0] 92 | b = pose_error[1, 1] 93 | c = pose_error[2, 2] 94 | d = 0.5*(a+b+c-1.0) 95 | rot_error = np.arccos(max(min(d, 1.0), -1.0)) 96 | return rot_error 97 | 98 | 99 | def translation_error(pose_error): 100 | """Compute translation error 101 | Args: 102 | pose_error (4x4 array): relative pose error 103 | Returns: 104 | trans_error (float): translation error 105 | """ 106 | dx = pose_error[0, 3] 107 | dy = pose_error[1, 3] 108 | dz = pose_error[2, 3] 109 | trans_error = np.sqrt(dx**2+dy**2+dz**2) 110 | return trans_error 111 | 112 | 113 | def compute_rpe(gt, pred): 114 | trans_errors = [] 115 | rot_errors = [] 116 | for i in range(len(gt)-1): 117 | gt1 = gt[i] 118 | gt2 = gt[i+1] 119 | gt_rel = np.linalg.inv(gt1) @ gt2 120 | 121 | pred1 = pred[i] 122 | pred2 = pred[i+1] 123 | pred_rel = np.linalg.inv(pred1) @ pred2 124 | rel_err = np.linalg.inv(gt_rel) @ pred_rel 125 | 126 | trans_errors.append(translation_error(rel_err)) 127 | rot_errors.append(rotation_error(rel_err)) 128 | 129 | return np.array(rot_errors), np.array(trans_errors) 130 | 131 | 132 | def compute_ATE(gt, pred): 133 | """Compute RMSE of ATE 134 | Args: 135 | gt: ground-truth poses 136 | pred: predicted poses 137 | """ 138 | r_errs = [] 139 | t_errs = [] 140 | 141 | for i in range(len(pred)): 142 | # cur_gt = np.linalg.inv(gt_0) @ gt[i] 143 | cur_gt = gt[i] 144 | gt_xyz = cur_gt[:3, 3] 145 | 146 | # cur_pred = np.linalg.inv(pred_0) @ pred[i] 147 | cur_pred = pred[i] 148 | pred_xyz = cur_pred[:3, 3] 149 | 150 | align_err = gt_xyz - pred_xyz 151 | 152 | rot_err = rotation_error(np.linalg.inv(cur_gt) @ cur_pred) 153 | 154 | r_errs.append(rot_err) 155 | t_errs.append(np.sqrt(np.sum(align_err ** 2))) 156 | 157 | # ate = np.sqrt(np.mean(np.asarray(errors) ** 2)) 158 | return np.array(r_errs), np.array(t_errs) 159 | 160 | 161 | 162 | def compute_epipolar_err(ref_xy, src_xy, P1, P2): 163 | 164 | Fm = KG.epipolar.fundamental_from_projections(P1[None, :3], P2[None, :3]) 165 | 166 | err = KG.symmetrical_epipolar_distance(ref_xy[None], 167 | src_xy[None], 168 | Fm, 169 | squared=False, 170 | eps=1e-08) 171 | 172 | return err.squeeze() 173 | 174 | 175 | def evaluate_pose(intrinsic, pose, P_src_list, match_list, num_pairs, inlier_threshold): 176 | P_ref = compute_P_from_KT(intrinsic, pose) 177 | 178 | inlier_rates = [] 179 | errs = 
[] 180 | 181 | loss = 0 182 | for idx, m in enumerate(match_list): 183 | epi_err = compute_epipolar_err(m[:, 0:2], 184 | m[:, 2:4], 185 | P_ref, 186 | P_src_list[idx]) 187 | 188 | inlier_mask = epi_err < inlier_threshold 189 | inlier_rate = inlier_mask.float().mean() 190 | 191 | inlier_rates.append(inlier_rate) 192 | 193 | if inlier_rate > 0: 194 | errs.append(epi_err) 195 | 196 | weight = inlier_rate * inlier_rate 197 | loss += weight * F.huber_loss(epi_err[inlier_mask], torch.zeros_like(epi_err[inlier_mask])) 198 | 199 | if len(errs) > num_pairs: 200 | break 201 | 202 | avg_inlier_rate = torch.stack(inlier_rates).mean() 203 | 204 | loss = loss / num_pairs 205 | 206 | return avg_inlier_rate, loss 207 | 208 | --------------------------------------------------------------------------------
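
For reference, here is a minimal sketch of how the trajectory-evaluation utilities in `utils.py` (`pose_alignment`, `compute_ATE`, `compute_rpe`) can be used to score an exported camera trajectory against ground truth outside the training loop. It mirrors what `PoseRunner.validate_pose` does internally; `refined_pose.txt` is the file that `validate_pose` writes, while `gt_poses.txt` is a hypothetical file assumed to hold ground-truth camera-to-world matrices in the same flattened 4x4-per-row layout (in the training code the ground-truth poses come from `dataset.get_gt_pose()` instead).

```python
import numpy as np

from utils import pose_alignment, compute_ATE, compute_rpe

# refined_pose.txt is written by PoseRunner.validate_pose (one flattened
# 4x4 camera-to-world matrix per row). gt_poses.txt is a hypothetical file
# assumed to use the same layout.
pred_poses = np.loadtxt('refined_pose.txt').reshape(-1, 4, 4)
gt_poses = np.loadtxt('gt_poses.txt').reshape(-1, 4, 4)

# Sim(3)-align the predicted trajectory to the ground truth first, as
# validate_pose does; .copy() because pose_alignment modifies its input
# in place.
pred_aligned = pose_alignment(pred_poses.copy(), gt_poses)

ate_rot, ate_trans = compute_ATE(gt_poses, pred_aligned)  # radians, scene units
rpe_rot, rpe_trans = compute_rpe(gt_poses, pred_aligned)

print('ATE rot: {:.3f} deg, trans: {:.4f}'.format(np.degrees(ate_rot.mean()), ate_trans.mean()))
print('RPE rot: {:.3f} deg, trans: {:.4f}'.format(np.degrees(rpe_rot.mean()), rpe_trans.mean()))
```

Aligning with a Sim(3) transform (Umeyama, `with_scale=True`) before measuring errors removes the global scale, rotation, and translation ambiguity between the COLMAP and ground-truth coordinate frames, so the reported ATE reflects only the residual pose error.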