├── .gitignore ├── ATE ├── align_trajectory.py ├── align_utils.py ├── compute_trajectory_errors.py ├── results_writer.py ├── trajectory_utils.py └── transformations.py ├── DPT └── dpt │ ├── __init__.py │ ├── base_model.py │ ├── blocks.py │ ├── midas_net.py │ ├── models.py │ ├── transforms.py │ └── vit.py ├── LICENSE ├── README.md ├── configs ├── LLFF │ └── fern.yaml ├── Tanks │ ├── Ballroom.yaml │ ├── Barn.yaml │ ├── Church.yaml │ ├── Family.yaml │ ├── Francis.yaml │ ├── Horse.yaml │ ├── Ignatius.yaml │ └── Museum.yaml ├── Test │ ├── images.yaml │ └── nerf.yaml ├── default.yaml └── preprocess.yaml ├── dataloading ├── __init__.py ├── common.py ├── configloading.py ├── dataloading.py └── dataset.py ├── environment.yaml ├── evaluation ├── eval.py └── eval_poses.py ├── model ├── __init__.py ├── checkpoints.py ├── common.py ├── config.py ├── distortions.py ├── eval_images.py ├── eval_pose_one_epoch.py ├── extracting_images.py ├── intrinsics.py ├── losses.py ├── network.py ├── official_nerf.py ├── poses.py ├── rendering.py └── training.py ├── preprocess └── dpt_depth.py ├── third_party └── pytorch_ssim │ └── __init__.py ├── train.py ├── utils_poses ├── align_traj.py ├── comp_ate.py ├── lie_group_helper.py └── vis_cam_traj.py └── vis ├── render.py └── vis_poses.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/**.pt 3 | data/ 4 | .DS_Store 5 | temp/ 6 | run.sh 7 | third_party/ 8 | out/ 9 | arc.sh -------------------------------------------------------------------------------- /ATE/align_trajectory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import ATE.transformations as tfs 6 | 7 | 8 | def get_best_yaw(C): 9 | ''' 10 | maximize trace(Rz(theta) * C) 11 | ''' 12 | assert C.shape == (3, 3) 13 | 14 | A = C[0, 1] - C[1, 0] 15 | B = C[0, 0] + C[1, 1] 16 | theta = np.pi / 2 - np.arctan2(B, A) 17 | 18 | return theta 19 | 20 | 21 | def rot_z(theta): 22 | R = tfs.rotation_matrix(theta, [0, 0, 1]) 23 | R = R[0:3, 0:3] 24 | 25 | return R 26 | 27 | 28 | def align_umeyama(model, data, known_scale=False, yaw_only=False): 29 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation 30 | of Transformation Parameters Between Two Point Patterns, 31 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991. 
32 | 33 | model = s * R * data + t 34 | 35 | Input: 36 | model -- first trajectory (nx3), numpy array type 37 | data -- second trajectory (nx3), numpy array type 38 | 39 | Output: 40 | s -- scale factor (scalar) 41 | R -- rotation matrix (3x3) 42 | t -- translation vector (3x1) 43 | t_error -- translational error per point (1xn) 44 | 45 | """ 46 | 47 | # substract mean 48 | mu_M = model.mean(0) 49 | mu_D = data.mean(0) 50 | model_zerocentered = model - mu_M 51 | data_zerocentered = data - mu_D 52 | n = np.shape(model)[0] 53 | 54 | # correlation 55 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered) 56 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum() 57 | U_svd, D_svd, V_svd = np.linalg.linalg.svd(C) 58 | 59 | D_svd = np.diag(D_svd) 60 | V_svd = np.transpose(V_svd) 61 | 62 | S = np.eye(3) 63 | if(np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0): 64 | S[2, 2] = -1 65 | 66 | if yaw_only: 67 | rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered) 68 | theta = get_best_yaw(rot_C) 69 | R = rot_z(theta) 70 | else: 71 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd))) 72 | 73 | if known_scale: 74 | s = 1 75 | else: 76 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S)) 77 | 78 | t = mu_M-s*np.dot(R, mu_D) 79 | 80 | return s, R, t 81 | -------------------------------------------------------------------------------- /ATE/align_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | 6 | import ATE.transformations as tfs 7 | import ATE.align_trajectory as align 8 | 9 | 10 | def _getIndices(n_aligned, total_n): 11 | if n_aligned == -1: 12 | idxs = np.arange(0, total_n) 13 | else: 14 | assert n_aligned <= total_n and n_aligned >= 1 15 | idxs = np.arange(0, n_aligned) 16 | return idxs 17 | 18 | 19 | def alignPositionYawSingle(p_es, p_gt, q_es, q_gt): 20 | ''' 21 | calcualte the 4DOF transformation: yaw R and translation t so that: 22 | gt = R * est + t 23 | ''' 24 | 25 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :] 26 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :] 27 | g_rot = tfs.quaternion_matrix(q_gt_0) 28 | g_rot = g_rot[0:3, 0:3] 29 | est_rot = tfs.quaternion_matrix(q_es_0) 30 | est_rot = est_rot[0:3, 0:3] 31 | 32 | C_R = np.dot(est_rot, g_rot.transpose()) 33 | theta = align.get_best_yaw(C_R) 34 | R = align.rot_z(theta) 35 | t = p_gt_0 - np.dot(R, p_es_0) 36 | 37 | return R, t 38 | 39 | 40 | def alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned=1): 41 | if n_aligned == 1: 42 | R, t = alignPositionYawSingle(p_es, p_gt, q_es, q_gt) 43 | return R, t 44 | else: 45 | idxs = _getIndices(n_aligned, p_es.shape[0]) 46 | est_pos = p_es[idxs, 0:3] 47 | gt_pos = p_gt[idxs, 0:3] 48 | _, R, t = align.align_umeyama(gt_pos, est_pos, known_scale=True, 49 | yaw_only=True) # note the order 50 | t = np.array(t) 51 | t = t.reshape((3, )) 52 | R = np.array(R) 53 | return R, t 54 | 55 | 56 | # align by a SE3 transformation 57 | def alignSE3Single(p_es, p_gt, q_es, q_gt): 58 | ''' 59 | Calculate SE3 transformation R and t so that: 60 | gt = R * est + t 61 | Using only the first poses of est and gt 62 | ''' 63 | 64 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :] 65 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :] 66 | 67 | g_rot = tfs.quaternion_matrix(q_gt_0) 68 | g_rot = g_rot[0:3, 0:3] 69 | est_rot = tfs.quaternion_matrix(q_es_0) 70 | est_rot = est_rot[0:3, 0:3] 71 | 72 | R = np.dot(g_rot, np.transpose(est_rot)) 73 | t = p_gt_0 - np.dot(R, p_es_0) 74 | 75 | return R, t 76 | 77 | 
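# --------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original file):
# aligning a scaled-and-shifted copy of a trajectory back onto its reference
# with the Sim(3) Umeyama alignment exposed by alignTrajectory() below.
# Array shapes follow the assertions in alignTrajectory(); for method='sim3'
# the quaternions are only shape-checked, so placeholder values suffice.
#
#   import numpy as np
#   import ATE.align_utils as au
#
#   p_gt = np.random.rand(100, 3)                  # reference positions (n x 3)
#   p_es = 0.5 * p_gt + np.array([1.0, 2.0, 3.0])  # scaled + translated estimate
#   q_id = np.tile([1.0, 0.0, 0.0, 0.0], (100, 1)) # placeholder quaternions (n x 4)
#
#   s, R, t = au.alignTrajectory(p_es, p_gt, q_id, q_id, method='sim3')
#   p_aligned = s * (R @ p_es.T).T + t             # gt ~= R * s * est + t
# --------------------------------------------------------------------------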
78 | def alignSE3(p_es, p_gt, q_es, q_gt, n_aligned=-1): 79 | ''' 80 | Calculate SE3 transformation R and t so that: 81 | gt = R * est + t 82 | ''' 83 | if n_aligned == 1: 84 | R, t = alignSE3Single(p_es, p_gt, q_es, q_gt) 85 | return R, t 86 | else: 87 | idxs = _getIndices(n_aligned, p_es.shape[0]) 88 | est_pos = p_es[idxs, 0:3] 89 | gt_pos = p_gt[idxs, 0:3] 90 | s, R, t = align.align_umeyama(gt_pos, est_pos, 91 | known_scale=True) # note the order 92 | t = np.array(t) 93 | t = t.reshape((3, )) 94 | R = np.array(R) 95 | return R, t 96 | 97 | 98 | # align by similarity transformation 99 | def alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned=-1): 100 | ''' 101 | calculate s, R, t so that: 102 | gt = R * s * est + t 103 | ''' 104 | idxs = _getIndices(n_aligned, p_es.shape[0]) 105 | est_pos = p_es[idxs, 0:3] 106 | gt_pos = p_gt[idxs, 0:3] 107 | s, R, t = align.align_umeyama(gt_pos, est_pos) # note the order 108 | return s, R, t 109 | 110 | 111 | # a general interface 112 | def alignTrajectory(p_es, p_gt, q_es, q_gt, method, n_aligned=-1): 113 | ''' 114 | calculate s, R, t so that: 115 | gt = R * s * est + t 116 | method can be: sim3, se3, posyaw, none; 117 | n_aligned: -1 means using all the frames 118 | ''' 119 | assert p_es.shape[1] == 3 120 | assert p_gt.shape[1] == 3 121 | assert q_es.shape[1] == 4 122 | assert q_gt.shape[1] == 4 123 | 124 | s = 1 125 | R = None 126 | t = None 127 | if method == 'sim3': 128 | assert n_aligned >= 2 or n_aligned == -1, "sim3 uses at least 2 frames" 129 | s, R, t = alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned) 130 | elif method == 'se3': 131 | R, t = alignSE3(p_es, p_gt, q_es, q_gt, n_aligned) 132 | elif method == 'posyaw': 133 | R, t = alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned) 134 | elif method == 'none': 135 | R = np.identity(3) 136 | t = np.zeros((3, )) 137 | else: 138 | assert False, 'unknown alignment method' 139 | 140 | return s, R, t 141 | 142 | 143 | if __name__ == '__main__': 144 | pass 145 | -------------------------------------------------------------------------------- /ATE/compute_trajectory_errors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | import os 4 | import numpy as np 5 | 6 | import ATE.trajectory_utils as tu 7 | import ATE.transformations as tf 8 | 9 | 10 | def compute_relative_error(p_es, q_es, p_gt, q_gt, T_cm, dist, max_dist_diff, 11 | accum_distances=[], 12 | scale=1.0): 13 | 14 | if len(accum_distances) == 0: 15 | accum_distances = tu.get_distance_from_start(p_gt) 16 | comparisons = tu.compute_comparison_indices_length( 17 | accum_distances, dist, max_dist_diff) 18 | 19 | n_samples = len(comparisons) 20 | print('number of samples = {0} '.format(n_samples)) 21 | if n_samples < 2: 22 | print("Too few samples! 
Will not compute.") 23 | return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]),\ 24 | np.array([]), np.array([]) 25 | 26 | T_mc = np.linalg.inv(T_cm) 27 | errors = [] 28 | for idx, c in enumerate(comparisons): 29 | if not c == -1: 30 | T_c1 = tu.get_rigid_body_trafo(q_es[idx, :], p_es[idx, :]) 31 | T_c2 = tu.get_rigid_body_trafo(q_es[c, :], p_es[c, :]) 32 | T_c1_c2 = np.dot(np.linalg.inv(T_c1), T_c2) 33 | T_c1_c2[:3, 3] *= scale 34 | 35 | T_m1 = tu.get_rigid_body_trafo(q_gt[idx, :], p_gt[idx, :]) 36 | T_m2 = tu.get_rigid_body_trafo(q_gt[c, :], p_gt[c, :]) 37 | T_m1_m2 = np.dot(np.linalg.inv(T_m1), T_m2) 38 | 39 | T_m1_m2_in_c1 = np.dot(T_cm, np.dot(T_m1_m2, T_mc)) 40 | T_error_in_c2 = np.dot(np.linalg.inv(T_m1_m2_in_c1), T_c1_c2) 41 | T_c2_rot = np.eye(4) 42 | T_c2_rot[0:3, 0:3] = T_c2[0:3, 0:3] 43 | T_error_in_w = np.dot(T_c2_rot, np.dot( 44 | T_error_in_c2, np.linalg.inv(T_c2_rot))) 45 | errors.append(T_error_in_w) 46 | 47 | error_trans_norm = [] 48 | error_trans_perc = [] 49 | error_yaw = [] 50 | error_gravity = [] 51 | e_rot = [] 52 | e_rot_deg_per_m = [] 53 | for e in errors: 54 | tn = np.linalg.norm(e[0:3, 3]) 55 | error_trans_norm.append(tn) 56 | error_trans_perc.append(tn / dist * 100) 57 | ypr_angles = tf.euler_from_matrix(e, 'rzyx') 58 | e_rot.append(tu.compute_angle(e)) 59 | error_yaw.append(abs(ypr_angles[0])*180.0/np.pi) 60 | error_gravity.append( 61 | np.sqrt(ypr_angles[1]**2+ypr_angles[2]**2)*180.0/np.pi) 62 | e_rot_deg_per_m.append(e_rot[-1] / dist) 63 | return errors, np.array(error_trans_norm), np.array(error_trans_perc),\ 64 | np.array(error_yaw), np.array(error_gravity), np.array(e_rot),\ 65 | np.array(e_rot_deg_per_m) 66 | 67 | 68 | def compute_absolute_error(p_es_aligned, q_es_aligned, p_gt, q_gt): 69 | e_trans_vec = (p_gt-p_es_aligned) 70 | e_trans = np.sqrt(np.sum(e_trans_vec**2, 1)) 71 | 72 | 73 | # orientation error 74 | e_rot = np.zeros((len(e_trans,))) 75 | e_ypr = np.zeros(np.shape(p_es_aligned)) 76 | for i in range(np.shape(p_es_aligned)[0]): 77 | R_we = tf.matrix_from_quaternion(q_es_aligned[i, :]) 78 | R_wg = tf.matrix_from_quaternion(q_gt[i, :]) 79 | e_R = np.dot(R_we, np.linalg.inv(R_wg)) 80 | e_ypr[i, :] = tf.euler_from_matrix(e_R, 'rzyx') 81 | e_rot[i] = np.rad2deg(np.linalg.norm(tf.logmap_so3(e_R[:3, :3]))) 82 | # scale drift 83 | motion_gt = np.diff(p_gt, 0) 84 | motion_es = np.diff(p_es_aligned, 0) 85 | dist_gt = np.sqrt(np.sum(np.multiply(motion_gt, motion_gt), 1)) 86 | dist_es = np.sqrt(np.sum(np.multiply(motion_es, motion_es), 1)) 87 | e_scale_perc = np.abs((np.divide(dist_es, dist_gt)-1.0) * 100) 88 | # ate = np.sqrt(np.mean(np.asarray(e_trans) ** 2)) 89 | return e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc 90 | -------------------------------------------------------------------------------- /ATE/results_writer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | import os 3 | # import yaml 4 | import numpy as np 5 | 6 | 7 | def compute_statistics(data_vec): 8 | stats = dict() 9 | if len(data_vec) > 0: 10 | stats['rmse'] = float( 11 | np.sqrt(np.dot(data_vec, data_vec) / len(data_vec))) 12 | stats['mean'] = float(np.mean(data_vec)) 13 | stats['median'] = float(np.median(data_vec)) 14 | stats['std'] = float(np.std(data_vec)) 15 | stats['min'] = float(np.min(data_vec)) 16 | stats['max'] = float(np.max(data_vec)) 17 | stats['num_samples'] = int(len(data_vec)) 18 | else: 19 | stats['rmse'] = 0 20 | stats['mean'] = 0 21 | stats['median'] = 0 22 | stats['std'] = 0 
23 | stats['min'] = 0 24 | stats['max'] = 0 25 | stats['num_samples'] = 0 26 | 27 | return stats 28 | 29 | 30 | # def update_and_save_stats(new_stats, label, yaml_filename): 31 | # stats = dict() 32 | # if os.path.exists(yaml_filename): 33 | # stats = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader) 34 | # stats[label] = new_stats 35 | # 36 | # with open(yaml_filename, 'w') as outfile: 37 | # outfile.write(yaml.dump(stats, default_flow_style=False)) 38 | # 39 | # return 40 | # 41 | # 42 | # def compute_and_save_statistics(data_vec, label, yaml_filename): 43 | # new_stats = compute_statistics(data_vec) 44 | # update_and_save_stats(new_stats, label, yaml_filename) 45 | # 46 | # return new_stats 47 | # 48 | # 49 | # def write_tex_table(list_values, rows, cols, outfn): 50 | # ''' 51 | # write list_values[row_idx][col_idx] to a table that is ready to be pasted 52 | # into latex source 53 | # 54 | # list_values is a list of row values 55 | # 56 | # The value should be string of desired format 57 | # ''' 58 | # 59 | # assert len(rows) >= 1 60 | # assert len(cols) >= 1 61 | # 62 | # with open(outfn, 'w') as f: 63 | # # write header 64 | # f.write(' & ') 65 | # for col_i in cols[:-1]: 66 | # f.write(col_i + ' & ') 67 | # f.write(' ' + cols[-1]+'\n') 68 | # 69 | # # write each row 70 | # for row_idx, row_i in enumerate(list_values): 71 | # f.write(rows[row_idx] + ' & ') 72 | # row_values = list_values[row_idx] 73 | # for col_idx in range(len(row_values) - 1): 74 | # f.write(row_values[col_idx] + ' & ') 75 | # f.write(' ' + row_values[-1]+' \n') 76 | -------------------------------------------------------------------------------- /ATE/trajectory_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | """ 3 | @author: Christian Forster 4 | """ 5 | 6 | import os 7 | import numpy as np 8 | import ATE.transformations as tf 9 | 10 | 11 | def get_rigid_body_trafo(quat, trans): 12 | T = tf.quaternion_matrix(quat) 13 | T[0:3, 3] = trans 14 | return T 15 | 16 | 17 | def get_distance_from_start(gt_translation): 18 | distances = np.diff(gt_translation[:, 0:3], axis=0) 19 | distances = np.sqrt(np.sum(np.multiply(distances, distances), 1)) 20 | distances = np.cumsum(distances) 21 | distances = np.concatenate(([0], distances)) 22 | return distances 23 | 24 | 25 | def compute_comparison_indices_length(distances, dist, max_dist_diff): 26 | max_idx = len(distances) 27 | comparisons = [] 28 | for idx, d in enumerate(distances): 29 | best_idx = -1 30 | error = max_dist_diff 31 | for i in range(idx, max_idx): 32 | if np.abs(distances[i]-(d+dist)) < error: 33 | best_idx = i 34 | error = np.abs(distances[i] - (d+dist)) 35 | if best_idx != -1: 36 | comparisons.append(best_idx) 37 | return comparisons 38 | 39 | 40 | def compute_angle(transform): 41 | """ 42 | Compute the rotation angle from a 4x4 homogeneous matrix. 
43 | """ 44 | # an invitation to 3-d vision, p 27 45 | return np.arccos( 46 | min(1, max(-1, (np.trace(transform[0:3, 0:3]) - 1)/2)))*180.0/np.pi 47 | -------------------------------------------------------------------------------- /DPT/dpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/nope-nerf/47c861f6259fb0f3921be3a19385feebfa769325/DPT/dpt/__init__.py -------------------------------------------------------------------------------- /DPT/dpt/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | # parameters = torch.load(path, map_location=torch.device("cuda")) 13 | 14 | if "optimizer" in parameters: 15 | parameters = parameters["model"] 16 | 17 | self.load_state_dict(parameters) 18 | -------------------------------------------------------------------------------- /DPT/dpt/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) # ViT-L/16 - 85.0% Top1 (backbone) 34 | elif backbone == "vitb_rn50_384": 35 | pretrained = _make_pretrained_vitb_rn50_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_vit_only=use_vit_only, 39 | use_readout=use_readout, 40 | enable_attention_hooks=enable_attention_hooks, 41 | ) 42 | scratch = _make_scratch( 43 | [256, 512, 768, 768], features, groups=groups, expand=expand 44 | ) # ViT-H/16 - 85.0% Top1 (backbone) 45 | elif backbone == "vitb16_384": 46 | pretrained = _make_pretrained_vitb16_384( 47 | use_pretrained, 48 | hooks=hooks, 49 | use_readout=use_readout, 50 | enable_attention_hooks=enable_attention_hooks, 51 | ) 52 | scratch = _make_scratch( 53 | [96, 192, 384, 768], features, groups=groups, expand=expand 54 | ) # ViT-B/16 - 84.6% Top1 (backbone) 55 | elif backbone == "resnext101_wsl": 56 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 57 | scratch = _make_scratch( 58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 59 | ) # efficientnet_lite3 60 | else: 61 | print(f"Backbone '{backbone}' not implemented") 62 | assert False 63 | 64 | return pretrained, scratch 65 | 66 | 67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 68 | scratch = nn.Module() 69 | 70 | out_shape1 = out_shape 71 | out_shape2 = out_shape 72 | out_shape3 = out_shape 73 | out_shape4 = out_shape 74 | if expand == True: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | out_shape4 = out_shape * 8 79 | 80 | scratch.layer1_rn = nn.Conv2d( 
81 | in_shape[0], 82 | out_shape1, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False, 87 | groups=groups, 88 | ) 89 | scratch.layer2_rn = nn.Conv2d( 90 | in_shape[1], 91 | out_shape2, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1, 95 | bias=False, 96 | groups=groups, 97 | ) 98 | scratch.layer3_rn = nn.Conv2d( 99 | in_shape[2], 100 | out_shape3, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | bias=False, 105 | groups=groups, 106 | ) 107 | scratch.layer4_rn = nn.Conv2d( 108 | in_shape[3], 109 | out_shape4, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1, 113 | bias=False, 114 | groups=groups, 115 | ) 116 | 117 | return scratch 118 | 119 | 120 | def _make_resnet_backbone(resnet): 121 | pretrained = nn.Module() 122 | pretrained.layer1 = nn.Sequential( 123 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 124 | ) 125 | 126 | pretrained.layer2 = resnet.layer2 127 | pretrained.layer3 = resnet.layer3 128 | pretrained.layer4 = resnet.layer4 129 | 130 | return pretrained 131 | 132 | 133 | def _make_pretrained_resnext101_wsl(use_pretrained): 134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 135 | return _make_resnet_backbone(resnet) 136 | 137 | 138 | class Interpolate(nn.Module): 139 | """Interpolation module.""" 140 | 141 | def __init__(self, scale_factor, mode, align_corners=False): 142 | """Init. 143 | 144 | Args: 145 | scale_factor (float): scaling 146 | mode (str): interpolation mode 147 | """ 148 | super(Interpolate, self).__init__() 149 | 150 | self.interp = nn.functional.interpolate 151 | self.scale_factor = scale_factor 152 | self.mode = mode 153 | self.align_corners = align_corners 154 | 155 | def forward(self, x): 156 | """Forward pass. 157 | 158 | Args: 159 | x (tensor): input 160 | 161 | Returns: 162 | tensor: interpolated data 163 | """ 164 | 165 | x = self.interp( 166 | x, 167 | scale_factor=self.scale_factor, 168 | mode=self.mode, 169 | align_corners=self.align_corners, 170 | ) 171 | 172 | return x 173 | 174 | 175 | class ResidualConvUnit(nn.Module): 176 | """Residual convolution module.""" 177 | 178 | def __init__(self, features): 179 | """Init. 180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__() 185 | 186 | self.conv1 = nn.Conv2d( 187 | features, features, kernel_size=3, stride=1, padding=1, bias=True 188 | ) 189 | 190 | self.conv2 = nn.Conv2d( 191 | features, features, kernel_size=3, stride=1, padding=1, bias=True 192 | ) 193 | 194 | self.relu = nn.ReLU(inplace=True) 195 | 196 | def forward(self, x): 197 | """Forward pass. 198 | 199 | Args: 200 | x (tensor): input 201 | 202 | Returns: 203 | tensor: output 204 | """ 205 | out = self.relu(x) 206 | out = self.conv1(out) 207 | out = self.relu(out) 208 | out = self.conv2(out) 209 | 210 | return out + x 211 | 212 | 213 | class FeatureFusionBlock(nn.Module): 214 | """Feature fusion block.""" 215 | 216 | def __init__(self, features): 217 | """Init. 218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super(FeatureFusionBlock, self).__init__() 223 | 224 | self.resConfUnit1 = ResidualConvUnit(features) 225 | self.resConfUnit2 = ResidualConvUnit(features) 226 | 227 | def forward(self, *xs): 228 | """Forward pass. 
229 | 230 | Returns: 231 | tensor: output 232 | """ 233 | output = xs[0] 234 | 235 | if len(xs) == 2: 236 | output += self.resConfUnit1(xs[1]) 237 | 238 | output = self.resConfUnit2(output) 239 | 240 | output = nn.functional.interpolate( 241 | output, scale_factor=2, mode="bilinear", align_corners=True 242 | ) 243 | 244 | return output 245 | 246 | 247 | class ResidualConvUnit_custom(nn.Module): 248 | """Residual convolution module.""" 249 | 250 | def __init__(self, features, activation, bn): 251 | """Init. 252 | 253 | Args: 254 | features (int): number of features 255 | """ 256 | super().__init__() 257 | 258 | self.bn = bn 259 | 260 | self.groups = 1 261 | 262 | self.conv1 = nn.Conv2d( 263 | features, 264 | features, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=not self.bn, 269 | groups=self.groups, 270 | ) 271 | 272 | self.conv2 = nn.Conv2d( 273 | features, 274 | features, 275 | kernel_size=3, 276 | stride=1, 277 | padding=1, 278 | bias=not self.bn, 279 | groups=self.groups, 280 | ) 281 | 282 | if self.bn == True: 283 | self.bn1 = nn.BatchNorm2d(features) 284 | self.bn2 = nn.BatchNorm2d(features) 285 | 286 | self.activation = activation 287 | 288 | self.skip_add = nn.quantized.FloatFunctional() 289 | 290 | def forward(self, x): 291 | """Forward pass. 292 | 293 | Args: 294 | x (tensor): input 295 | 296 | Returns: 297 | tensor: output 298 | """ 299 | 300 | out = self.activation(x) 301 | out = self.conv1(out) 302 | if self.bn == True: 303 | out = self.bn1(out) 304 | 305 | out = self.activation(out) 306 | out = self.conv2(out) 307 | if self.bn == True: 308 | out = self.bn2(out) 309 | 310 | if self.groups > 1: 311 | out = self.conv_merge(out) 312 | 313 | return self.skip_add.add(out, x) 314 | 315 | # return out + x 316 | 317 | 318 | class FeatureFusionBlock_custom(nn.Module): 319 | """Feature fusion block.""" 320 | 321 | def __init__( 322 | self, 323 | features, 324 | activation, 325 | deconv=False, 326 | bn=False, 327 | expand=False, 328 | align_corners=True, 329 | ): 330 | """Init. 331 | 332 | Args: 333 | features (int): number of features 334 | """ 335 | super(FeatureFusionBlock_custom, self).__init__() 336 | 337 | self.deconv = deconv 338 | self.align_corners = align_corners 339 | 340 | self.groups = 1 341 | 342 | self.expand = expand 343 | out_features = features 344 | if self.expand == True: 345 | out_features = features // 2 346 | 347 | self.out_conv = nn.Conv2d( 348 | features, 349 | out_features, 350 | kernel_size=1, 351 | stride=1, 352 | padding=0, 353 | bias=True, 354 | groups=1, 355 | ) 356 | 357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 359 | 360 | self.skip_add = nn.quantized.FloatFunctional() 361 | 362 | def forward(self, *xs): 363 | """Forward pass. 
364 | 365 | Returns: 366 | tensor: output 367 | """ 368 | output = xs[0] 369 | 370 | if len(xs) == 2: 371 | res = self.resConfUnit1(xs[1]) 372 | output = self.skip_add.add(output, res) 373 | # output += res 374 | 375 | output = self.resConfUnit2(output) 376 | 377 | output = nn.functional.interpolate( 378 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 379 | ) 380 | 381 | output = self.out_conv(output) 382 | 383 | return output 384 | -------------------------------------------------------------------------------- /DPT/dpt/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_large(BaseModel): 13 | """Network for monocular depth estimation.""" 14 | 15 | def __init__(self, path=None, features=256, non_negative=True): 16 | """Init. 17 | 18 | Args: 19 | path (str, optional): Path to saved model. Defaults to None. 20 | features (int, optional): Number of features. Defaults to 256. 21 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 22 | """ 23 | print("Loading weights: ", path) 24 | 25 | super(MidasNet_large, self).__init__() 26 | 27 | use_pretrained = False if path is None else True 28 | 29 | self.pretrained, self.scratch = _make_encoder( 30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained 31 | ) 32 | 33 | self.scratch.refinenet4 = FeatureFusionBlock(features) 34 | self.scratch.refinenet3 = FeatureFusionBlock(features) 35 | self.scratch.refinenet2 = FeatureFusionBlock(features) 36 | self.scratch.refinenet1 = FeatureFusionBlock(features) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear"), 41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 44 | nn.ReLU(True) if non_negative else nn.Identity(), 45 | ) 46 | 47 | if path: 48 | self.load(path) 49 | 50 | def forward(self, x): 51 | """Forward pass. 
52 | 53 | Args: 54 | x (tensor): input data (image) 55 | 56 | Returns: 57 | tensor: depth 58 | """ 59 | 60 | layer_1 = self.pretrained.layer1(x) 61 | layer_2 = self.pretrained.layer2(layer_1) 62 | layer_3 = self.pretrained.layer3(layer_2) 63 | layer_4 = self.pretrained.layer4(layer_3) 64 | 65 | layer_1_rn = self.scratch.layer1_rn(layer_1) 66 | layer_2_rn = self.scratch.layer2_rn(layer_2) 67 | layer_3_rn = self.scratch.layer3_rn(layer_3) 68 | layer_4_rn = self.scratch.layer4_rn(layer_4) 69 | 70 | path_4 = self.scratch.refinenet4(layer_4_rn) 71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 74 | 75 | out = self.scratch.output_conv(path_1) 76 | 77 | return torch.squeeze(out, dim=1) 78 | -------------------------------------------------------------------------------- /DPT/dpt/models.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | import string 3 | from xmlrpc.client import Boolean 4 | import cv2 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision.transforms import Compose 9 | # from pytorch_lightning import LightningModule 10 | 11 | from .base_model import BaseModel 12 | from .blocks import ( 13 | FeatureFusionBlock, 14 | FeatureFusionBlock_custom, 15 | Interpolate, 16 | _make_encoder, 17 | forward_vit, 18 | ) 19 | import sys 20 | def _make_fusion_block(features, use_bn): 21 | return FeatureFusionBlock_custom( 22 | features, 23 | nn.ReLU(False), 24 | deconv=False, 25 | bn=use_bn, 26 | expand=False, 27 | align_corners=True, 28 | ) 29 | 30 | 31 | class DepthLoss(nn.Module): 32 | def __init__(self, loss_type): 33 | """Calculate depth loss with masking scheme. 34 | Remove zero/negative target values 35 | 36 | Args: 37 | cfg (eDict): loss configuration 38 | - loss_type (str): the method of calculating loss 39 | - smL1 40 | - L1 41 | - L2 42 | - use_inv_depth (bool): use inverse depth 43 | """ 44 | super(DepthLoss, self).__init__() 45 | self.loss_type = loss_type 46 | self.use_inv_depth = False 47 | self.eps = 1e-6 48 | 49 | def forward(self, pred, target): 50 | """ 51 | Args: 52 | pred (Nx1xHxW): predicted depth map 53 | target (Nx1xHxW): GT depth map 54 | 55 | Returns: 56 | total_loss (dict): loss items 57 | """ 58 | losses = {} 59 | 60 | # compute mask 61 | non_zero_mask = target > 0 62 | mask = non_zero_mask 63 | 64 | # use inverse depth 65 | if self.use_inv_depth: 66 | target = 1. / (target + self.eps) 67 | pred = 1. 
/ (pred + self.eps) 68 | 69 | if len(target[mask]) != 0: 70 | # compute loss 71 | if self.loss_type in ['smL1', 'L1', 'L2']: 72 | diff = target[mask] - pred[mask] 73 | if self.loss_type == 'smL1': 74 | loss = ((diff / 2 )**2 + 1 ).pow(0.5) - 1 75 | elif self.loss_type == 'L1': 76 | loss = diff.abs() 77 | elif self.loss_type == "L2": 78 | loss = diff ** 2 79 | depth_loss = loss.mean() 80 | elif self.loss_type in ['eigen']: 81 | diff = torch.log(target[mask]) - torch.log(pred[mask]) 82 | loss1 = (diff**2).mean() 83 | loss2 = (diff.sum())**2/(len(diff)**2) 84 | depth_loss = loss1 + 0.5 * loss2 85 | 86 | else: 87 | ### set depth_loss to 0 ### 88 | depth_loss = (pred*0).sum() 89 | 90 | return depth_loss 91 | 92 | 93 | class DPT(BaseModel): 94 | def __init__( 95 | self, 96 | head, 97 | features=256, 98 | backbone="vitb_rn50_384", 99 | readout="project", 100 | channels_last=False, 101 | use_bn=False, 102 | enable_attention_hooks=False, 103 | freeze=True 104 | ): 105 | 106 | super(DPT, self).__init__() 107 | 108 | self.channels_last = channels_last 109 | 110 | hooks = { 111 | "vitb_rn50_384": [0, 1, 8, 11], 112 | "vitb16_384": [2, 5, 8, 11], 113 | "vitl16_384": [5, 11, 17, 23], 114 | } 115 | 116 | # Instantiate backbone and reassemble blocks 117 | self.pretrained, self.scratch = _make_encoder( 118 | backbone, 119 | features, 120 | False, # Set to true of you want to train from scratch, uses ImageNet weights 121 | groups=1, 122 | expand=False, 123 | exportable=False, 124 | hooks=hooks[backbone], 125 | use_readout=readout, 126 | enable_attention_hooks=enable_attention_hooks, 127 | ) 128 | 129 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 130 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 131 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 132 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 133 | 134 | self.scratch.output_conv = head 135 | 136 | 137 | if freeze: 138 | for name, p in self.named_parameters(): 139 | p.requires_grad = False 140 | 141 | 142 | def forward(self, x): 143 | if self.channels_last == True: 144 | x.contiguous(memory_format=torch.channels_last) 145 | 146 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 147 | 148 | layer_1_rn = self.scratch.layer1_rn(layer_1) 149 | layer_2_rn = self.scratch.layer2_rn(layer_2) 150 | layer_3_rn = self.scratch.layer3_rn(layer_3) 151 | layer_4_rn = self.scratch.layer4_rn(layer_4) 152 | 153 | path_4 = self.scratch.refinenet4(layer_4_rn) 154 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 155 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 156 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 157 | 158 | out = self.scratch.output_conv(path_1) 159 | return out 160 | 161 | 162 | class DPTDepthModel(DPT): 163 | def __init__( 164 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=True, freeze=True, **kwargs 165 | ): 166 | features = kwargs["features"] if "features" in kwargs else 256 167 | 168 | self.scale = scale 169 | self.shift = shift 170 | self.invert = invert 171 | 172 | head = nn.Sequential( 173 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 174 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 175 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 176 | nn.ReLU(True), 177 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 178 | nn.ReLU(True) if non_negative else nn.Identity(), 179 | nn.Identity(), 180 | ) 181 | # modified head 182 | # head = nn.Sequential( 
183 | # nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 184 | # Interpolate(scale_factor=4, mode="bilinear", align_corners=True), 185 | # nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 186 | # nn.ReLU(True), 187 | # nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 188 | # nn.ReLU(True) if non_negative else nn.Identity(), 189 | # nn.Identity(), 190 | # ) 191 | 192 | super().__init__(head, freeze=freeze, **kwargs) 193 | 194 | if path is not None: 195 | self.load(path) 196 | 197 | # if freeze: 198 | # for name, p in self.named_parameters(): 199 | # print(name) 200 | # p.requires_grad = False 201 | 202 | 203 | def forward(self, x): 204 | inv_depth = super().forward(x).squeeze(dim=1) 205 | 206 | if self.invert: 207 | depth = self.scale * inv_depth + self.shift 208 | depth[depth < 1e-8] = 1e-8 209 | depth = 1.0 / depth 210 | return depth 211 | else: 212 | return inv_depth 213 | 214 | 215 | # class LitDPTModule(LightningModule): 216 | 217 | # def __init__( 218 | # self, 219 | # path: string = None, 220 | # non_negative: bool = True, 221 | # scale: float = 0.000305, 222 | # shift: float = 0.1378, 223 | # invert: bool = False, 224 | # lr: float = 0.0001, 225 | # weight_decay: float = 0.005, 226 | # loss_type: string = "eigen", 227 | # ): 228 | # super().__init__() 229 | 230 | # # this line allows to access init params with 'self.hparams' attribute 231 | # # it also ensures init params will be stored in ckpt 232 | # self.save_hyperparameters(logger=False) 233 | 234 | # self.model = DPTDepthModel(path, non_negative, scale, shift, invert) 235 | 236 | # # loss function 237 | # self.criterion = DepthLoss(loss_type) 238 | 239 | # # self.automatic_optimization = False 240 | 241 | # def forward(self, x: torch.Tensor): 242 | # return self.model(x) 243 | 244 | # def step(self, batch: Any): 245 | # in_, mask, gt = batch['image'], batch['mask'], batch['depth'] 246 | # pred = self.forward(in_) 247 | # loss = self.criterion(pred, gt) 248 | # return loss, pred, gt 249 | 250 | # def training_step(self, batch: Any, batch_idx: int): 251 | 252 | # # opt = self.optimizers() 253 | # # opt.zero_grad() 254 | 255 | # loss, preds, targets = self.step(batch) 256 | # self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False) 257 | # # input_visual = rgb_unnormalize(batch['image'][0]) 258 | # # preds_visual = depth_visualization(preds[0]) 259 | # # gt_visual = depth_visualization(batch['depth'][0]) 260 | # # tensor_logger = self.logger.experiment[0] 261 | # # tensor_logger.add_image( 262 | # # 'train/input_rgb', input_visual, self.global_step 263 | # # ) 264 | # # tensor_logger.add_image( 265 | # # 'train/pred_depth', preds_visual, self.global_step 266 | # # ) 267 | # # tensor_logger.add_image( 268 | # # 'train/gt_depth', gt_visual, self.global_step 269 | # # ) 270 | 271 | # # we can return here dict with any tensors 272 | # # and then read it in some callback or in `training_epoch_end()`` below 273 | # # remember to always return loss from `training_step()` or else backpropagation will fail! 
274 | # # self.manual_backward(loss) 275 | 276 | # return loss 277 | 278 | # def training_epoch_end(self, outputs: List[Any]): 279 | # # `outputs` is a list of dicts returned from `training_step()` 280 | # pass 281 | 282 | # def validation_step(self, batch: Any, batch_idx: int): 283 | # loss, preds, targets = self.step(batch) 284 | # self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 285 | # input_visual = rgb_unnormalize(batch['image'][0]) 286 | # preds_visual = depth_visualization(preds[0]) 287 | # gt_visual = depth_visualization(batch['depth'][0]) 288 | # tensor_logger = self.logger.experiment[0] 289 | # tensor_logger.add_image( 290 | # f'val/input_rgb', input_visual, self.global_step 291 | # ) 292 | # tensor_logger.add_image( 293 | # f'val/pred_depth', preds_visual, self.global_step 294 | # ) 295 | # tensor_logger.add_image( 296 | # f'val/gt_depth', gt_visual, self.global_step 297 | # ) 298 | 299 | # return loss 300 | 301 | # def validation_epoch_end(self, outputs: List[Any]): 302 | # # acc = self.val_acc.compute() # get val accuracy from current epoch 303 | # # self.val_acc_best.update(acc) 304 | # # self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True) 305 | # pass 306 | 307 | # def test_step(self, batch: Any, batch_idx: int): 308 | # loss, preds, targets = self.step(batch) 309 | # self.log("test/loss", loss, on_step=False, on_epoch=True) 310 | 311 | # return loss 312 | 313 | # def test_epoch_end(self, outputs: List[Any]): 314 | # pass 315 | 316 | # def on_epoch_end(self): 317 | # # reset metrics at the end of every epoch 318 | # pass 319 | 320 | # def configure_optimizers(self): 321 | # """Choose what optimizers and learning-rate schedulers to use in your optimization. 322 | # Normally you'd need one. But in the case of GANs or similar you might have multiple. 323 | # See examples here: 324 | # https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers 325 | # """ 326 | # return torch.optim.Adam( 327 | # params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay 328 | # ) -------------------------------------------------------------------------------- /DPT/dpt/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 
8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height).""" 50 | 51 | def __init__( 52 | self, 53 | width, 54 | height, 55 | resize_target=True, 56 | keep_aspect_ratio=False, 57 | ensure_multiple_of=1, 58 | resize_method="lower_bound", 59 | ): 60 | """Init. 61 | 62 | Args: 63 | width (int): desired output width 64 | height (int): desired output height 65 | resize_target (bool, optional): 66 | True: Resize the full sample (image, mask, target). 67 | False: Resize image only. 68 | Defaults to True. 69 | keep_aspect_ratio (bool, optional): 70 | True: Keep the aspect ratio of the input sample. 71 | Output sample might not have the given width and height, and 72 | resize behaviour depends on the parameter 'resize_method'. 73 | Defaults to False. 74 | ensure_multiple_of (int, optional): 75 | Output width and height is constrained to be multiple of this parameter. 76 | Defaults to 1. 77 | resize_method (str, optional): 78 | "lower_bound": Output will be at least as large as the given size. 79 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 80 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 81 | Defaults to "lower_bound". 
82 | """ 83 | self.__width = width 84 | self.__height = height 85 | 86 | self.__resize_target = resize_target 87 | self.__keep_aspect_ratio = keep_aspect_ratio 88 | self.__multiple_of = ensure_multiple_of 89 | self.__resize_method = resize_method 90 | self.__image_interpolation_method = cv2.INTER_CUBIC 91 | 92 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 93 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 94 | 95 | if max_val is not None and y > max_val: 96 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 97 | 98 | if y < min_val: 99 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 100 | 101 | return y 102 | 103 | def get_size(self, width, height): 104 | # determine new height and width 105 | scale_height = self.__height / height 106 | scale_width = self.__width / width 107 | 108 | if self.__keep_aspect_ratio: 109 | if self.__resize_method == "lower_bound": 110 | # scale such that output size is lower bound 111 | if scale_width > scale_height: 112 | # fit width 113 | scale_height = scale_width 114 | else: 115 | # fit height 116 | scale_width = scale_height 117 | elif self.__resize_method == "upper_bound": 118 | # scale such that output size is upper bound 119 | if scale_width < scale_height: 120 | # fit width 121 | scale_height = scale_width 122 | else: 123 | # fit height 124 | scale_width = scale_height 125 | elif self.__resize_method == "minimal": 126 | # scale as least as possbile 127 | if abs(1 - scale_width) < abs(1 - scale_height): 128 | # fit width 129 | scale_height = scale_width 130 | else: 131 | # fit height 132 | scale_width = scale_height 133 | else: 134 | raise ValueError( 135 | f"resize_method {self.__resize_method} not implemented" 136 | ) 137 | 138 | if self.__resize_method == "lower_bound": 139 | new_height = self.constrain_to_multiple_of( 140 | scale_height * height, min_val=self.__height 141 | ) 142 | new_width = self.constrain_to_multiple_of( 143 | scale_width * width, min_val=self.__width 144 | ) 145 | elif self.__resize_method == "upper_bound": 146 | new_height = self.constrain_to_multiple_of( 147 | scale_height * height, max_val=self.__height 148 | ) 149 | new_width = self.constrain_to_multiple_of( 150 | scale_width * width, max_val=self.__width 151 | ) 152 | elif self.__resize_method == "minimal": 153 | new_height = self.constrain_to_multiple_of(scale_height * height) 154 | new_width = self.constrain_to_multiple_of(scale_width * width) 155 | else: 156 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 157 | 158 | return (new_width, new_height) 159 | 160 | def __call__(self, sample): 161 | width, height = self.get_size( 162 | sample["image"].shape[1], sample["image"].shape[0] 163 | ) 164 | 165 | # resize sample 166 | sample["image"] = cv2.resize( 167 | sample["image"], 168 | (width, height), 169 | interpolation=self.__image_interpolation_method, 170 | ) 171 | 172 | if self.__resize_target: 173 | if "disparity" in sample: 174 | sample["disparity"] = cv2.resize( 175 | sample["disparity"], 176 | (width, height), 177 | interpolation=cv2.INTER_NEAREST, 178 | ) 179 | 180 | if "depth" in sample: 181 | sample["depth"] = cv2.resize( 182 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 183 | ) 184 | if "mask" in sample: 185 | sample["mask"] = cv2.resize( 186 | sample["mask"].astype(np.float32), 187 | (width, height), 188 | interpolation=cv2.INTER_NEAREST, 189 | ) 190 | sample["mask"] = sample["mask"].astype(bool) 191 | 192 | return sample 
193 | 194 | 195 | class NormalizeImage(object): 196 | """Normlize image by given mean and std.""" 197 | 198 | def __init__(self, mean, std): 199 | self.__mean = mean 200 | self.__std = std 201 | 202 | def __call__(self, sample): 203 | sample["image"] = (sample["image"] - self.__mean) / self.__std 204 | 205 | return sample 206 | 207 | 208 | class PrepareForNet(object): 209 | """Prepare sample for usage as network input.""" 210 | 211 | def __init__(self): 212 | pass 213 | 214 | def __call__(self, sample): 215 | image = np.transpose(sample["image"], (2, 0, 1)) 216 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 217 | 218 | if "mask" in sample: 219 | sample["mask"] = sample["mask"].astype(np.float32) 220 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 221 | 222 | if "disparity" in sample: 223 | disparity = sample["disparity"].astype(np.float32) 224 | sample["disparity"] = np.ascontiguousarray(disparity) 225 | 226 | if "depth" in sample: 227 | depth = sample["depth"].astype(np.float32) 228 | sample["depth"] = np.ascontiguousarray(depth) 229 | 230 | return sample 231 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Wenjing Bian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NoPe-NeRF: Optimising Neural Radiance Field with No Pose Prior 2 | 3 | **[Project Page](https://nope-nerf.active.vision/) | [Arxiv](https://arxiv.org/abs/2212.07388) | [Data](https://www.robots.ox.ac.uk/~wenjing/Tanks.zip) | [Pretrained Model](https://www.robots.ox.ac.uk/~wenjing/pretrained_Tanks.zip)** 4 | 5 | Wenjing Bian, 6 | Zirui Wang, 7 | [Kejie Li](https://likojack.github.io/kejieli/#/home), 8 | [Jiawag Bian](https://jwbian.net/), 9 | [Victor Adrian Prisacariu](http://www.robots.ox.ac.uk/~victor/). (CVPR 2023 highlight) 10 | 11 | Active Vision Lab, University of Oxford. 
12 | 13 | 14 | ## Table of Content 15 | - [Installation](#Installation) 16 | - [Data](#Data) 17 | - [Usage](#Usage) 18 | - [Acknowledgement](#Acknowledgement) 19 | - [Citation](#citation) 20 | 21 | ## Installation 22 | 23 | ``` 24 | git clone https://github.com/ActiveVisionLab/nope-nerf.git 25 | cd nope-nerf 26 | conda env create -f environment.yaml 27 | conda activate nope-nerf 28 | ``` 29 | 30 | ## Data and Preprocessing 31 | 1. [Tanks and Temples](https://www.robots.ox.ac.uk/~wenjing/Tanks.zip): 32 | Our pre-processed Tanks and Temples data contains the 8 scenes shown in the paper. Each scene contains images, monocular depth estimations from DPT and COLMAP poses. You can download and unzip it to `data` directory. 33 | 34 | 2. [NeRF LLFF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1): 35 | We also provide config file for NeRF LLFF dataset. You can download the dataset and unzip it to `data` directory. One example of the config file is `configs/LLFF/fern.yaml`. 36 | 37 | 38 | 3. If you want to use your own image sequence with customised camera intrinsics, you need to add an `intrinsics.npz` file to the scene directory. One example of the config file is `configs/Test/images.yaml` (please add your own data to the `data/Test/images` directory). 39 | 40 | 41 | 42 | Monocular depth map generation: you can first download the pre-trained DPT model from [this link](https://drive.google.com/file/d/1dgcJEYYw1F8qirXhZxgNK8dWWz_8gZBD/view?usp=sharing) provided by [Vision Transformers for Dense Prediction](https://github.com/isl-org/DPT) to `DPT` directory, then run 43 | ``` 44 | python preprocess/dpt_depth.py configs/preprocess.yaml 45 | ``` 46 | to generate monocular depth maps. You need to modify the `cfg['dataloading']['path']` and `cfg['dataloading']['scene']` in `configs/preprocess.yaml` to your own image sequence. 47 | 48 | ## Training 49 | 50 | 1. Train a new model from scratch: 51 | 52 | ``` 53 | python train.py configs/Tanks/Ignatius.yaml 54 | ``` 55 | where you can replace `configs/Tanks/Ignatius.yaml` with other config files. 56 | 57 | You can monitor on the training process using [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard): 58 | ``` 59 | tensorboard --logdir ./out --port 6006 60 | ``` 61 | 62 | For available training options, please take a look at `configs/default.yaml`. 63 | ## Evaluation 64 | 1. Evaluate image quality and depth: 65 | ``` 66 | python evaluation/eval.py configs/Tanks/Ignatius.yaml 67 | ``` 68 | To evaluate depth: add `--depth` . Note that you need to add ground truth depth maps by yourself. 69 | 70 | 2. Evaluate poses: 71 | ``` 72 | python evaluation/eval_poses.py configs/Tanks/Ignatius.yaml 73 | ``` 74 | To visualise estimated & ground truth trajectories: add `--vis` 75 | 76 | 77 | ## More Visualisations 78 | Novel view synthesis 79 | ``` 80 | python vis/render.py configs/Tanks/Ignatius.yaml 81 | ``` 82 | Pose visualisation (estimated trajectory only) 83 | ``` 84 | python vis/vis_poses.py configs/Tanks/Ignatius.yaml 85 | ``` 86 | 87 | 88 | ## Acknowledgement 89 | We thank [Theo Costain](https://www.robots.ox.ac.uk/~costain/) and Michael Hobley for helpful comments and proofreading. We thank Shuai Chen and Xinghui Li for insightful discussions. Wenjing Bian is supported by China Scholarship Council (CSC). 
90 | 91 | We refer to [NeRFmm](https://github.com/ActiveVisionLab/nerfmm), [UNISURF](https://github.com/autonomousvision/unisurf), [Vision Transformers for Dense Prediction](https://github.com/isl-org/DPT), [kitti-odom-eval](https://github.com/Huangying-Zhan/kitti-odom-eval) and [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch). We thank the excellent code they provide. 92 | 93 | ## Citation 94 | ``` 95 | @inproceedings{bian2022nopenerf, 96 | author = {Wenjing Bian and Zirui Wang and Kejie Li and Jiawang Bian and Victor Adrian Prisacariu}, 97 | title = {NoPe-NeRF: Optimising Neural Radiance Field with No Pose Prior}, 98 | journal = {CVPR}, 99 | year = {2023} 100 | } 101 | ``` -------------------------------------------------------------------------------- /configs/LLFF/fern.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | rendering: 6 | depth_range: [0.0, 1.0] 7 | dist_alpha: True 8 | sample_option: ndc 9 | dataloading: 10 | path: data/nerf_llff_data 11 | scene: ['fern'] 12 | random_ref: 1 13 | resize_factor: 4 14 | training: 15 | out_dir: out/llff/fern 16 | vis_resolution: [75, 100] 17 | extract_images: 18 | resolution: [756, 1008] -------------------------------------------------------------------------------- /configs/Tanks/Ballroom.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Ballroom'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | out_dir: out/Tanks/Ballroom 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] 15 | -------------------------------------------------------------------------------- /configs/Tanks/Barn.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Barn'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | out_dir: out/Tanks/Barn 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Church.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Church'] 8 | customized_focal: False 9 | random_ref: 4 10 | training: 11 | out_dir: out/Tanks/Church 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Family.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Family'] 8 | customized_focal: False 9 | random_ref: 1 10 | sample_rate: 2 11 | training: 12 | out_dir: out/Tanks/Family 13 | auto_scheduler: True 14 | extract_images: 15 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Francis.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Francis'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | 
out_dir: out/Tanks/Francis 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Horse.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Horse'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | out_dir: out/Tanks/Horse 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Ignatius.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Ignatius'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | out_dir: out/Tanks/Ignatius 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Tanks/Museum.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Tanks 7 | scene: ['Museum'] 8 | customized_focal: False 9 | random_ref: 1 10 | training: 11 | out_dir: out/Tanks/Museum 12 | auto_scheduler: True 13 | extract_images: 14 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Test/images.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_pose: True 5 | dataloading: 6 | path: data/Test 7 | scene: ['images'] 8 | load_colmap_poses: False 9 | customized_focal: True 10 | training: 11 | out_dir: out/Test/images 12 | auto_scheduler: True 13 | eval_pose_every: -1 14 | extract_images: 15 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/Test/nerf.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: None 3 | pose: 4 | learn_R: False 5 | learn_t: False 6 | init_pose: True 7 | init_pose_type: gt 8 | dataloading: 9 | path: data/Tanks 10 | scene: ['Ignatius'] 11 | training: 12 | out_dir: out/Tanks/Ignatius_nerf 13 | auto_scheduler: False 14 | scheduling_start: 0 15 | annealing_epochs: 0 16 | extract_images: 17 | resolution: [540, 960] -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | num_layers: 8 3 | freeze_network: False 4 | use_image_feature: False 5 | network_type: official 6 | occ_activation: softplus 7 | hidden_dim: 256 8 | pos_enc_levels: 10 9 | dir_enc_levels: 4 10 | dataloading: 11 | dataset_name: any 12 | path: 13 | scene: [] 14 | batchsize: 1 15 | n_workers: 1 16 | img_size: 17 | path: 18 | with_depth: False 19 | with_mask: False 20 | spherify: True 21 | customized_poses: False #use poses other than colmap 22 | customized_focal: False #use focal other than colmap 23 | resize_factor: 24 | depth_net: dpt 25 | crop_size: 0 26 | random_ref: 1 27 | norm_depth: False 28 | load_colmap_poses: True 29 | shuffle: True 30 | sample_rate: 8 31 | 32 | rendering: 33 | type: nope_nerf # changed 34 | n_max_network_queries: 64000 35 | 
white_background: False 36 | radius: 4.0 37 | num_points: 128 38 | depth_range: [0.01, 10] 39 | dist_alpha: False 40 | use_ray_dir: True 41 | normalise_ray: True 42 | normal_loss: False 43 | sample_option: uniform 44 | outside_steps: 0 45 | depth: 46 | type: None 47 | path: DPT/dpt_hybrid-midas-501f0c75.pt 48 | non_negative: True 49 | scale: 0.000305 50 | shift: 0.1378 51 | invert: True 52 | freeze: True 53 | pose: 54 | learn_pose: True 55 | learn_R: True 56 | learn_t: True 57 | init_pose: False 58 | init_R_only: False 59 | learn_focal: False 60 | update_focal: True 61 | fx_only: False 62 | focal_order: 2 63 | init_pose_type: gt 64 | init_focal_type: gt 65 | distortion: 66 | learn_distortion: True 67 | fix_scaleN: True 68 | learn_scale: True 69 | learn_shift: True 70 | training: 71 | type: nope_nerf 72 | load_dir: model.pt 73 | load_pose_dir: model_pose.pt 74 | load_focal_dir: model_focal.pt 75 | load_distortion_dir: model_distortion.pt 76 | n_training_points: 1024 77 | scheduling_epoch: 10000 78 | batch_size: 1 79 | learning_rate: 0.001 80 | focal_lr: 0.001 81 | pose_lr: 0.0005 82 | distortion_lr: 0.0005 83 | weight_decay: 0.0 84 | scheduler_gamma_pose: 0.9 85 | scheduler_gamma: 0.9954 86 | scheduler_gamma_distortion: 0.9 87 | scheduler_gamma_focal: 0.9 88 | validate_every: -1 89 | visualize_every: 10000 90 | eval_pose_every: 1 # epoch 91 | eval_img_every: 1 # epoch 92 | print_every: 100 93 | backup_every: 10000 94 | checkpoint_every: 5000 95 | rgb_weight: [1.0, 1.0] 96 | depth_weight: [0.04, 0.0] 97 | weight_dist_2nd_loss: [0.0, 0.0] 98 | weight_dist_1st_loss: [0.0, 0.0] 99 | pc_weight: [1.0, 0.0] 100 | rgb_s_weight: [1.0, 0.0] 101 | depth_consistency_weight: [0.0, 0.0] 102 | rgb_loss_type: l1 103 | depth_loss_type: l1 104 | log_scale_shift_per_view: False 105 | with_auto_mask: False 106 | vis_geo: True 107 | vis_resolution: [54, 96] 108 | mode: train 109 | with_ssim: False 110 | use_gt_depth: False 111 | load_ckpt_model_only: False 112 | optim: Adam 113 | detach_gt_depth: False 114 | match_method: dense 115 | pc_ratio: 4 116 | shift_first: False 117 | detach_ref_img: True 118 | scheduling_start: 10000 119 | auto_scheduler: True 120 | length_smooth: 1000 121 | patient: 30 122 | scale_pcs: True 123 | detach_rgbs_scale: False 124 | scheduling_mode: 125 | vis_reprojection_every: 5000 126 | nearest_limit: 0.01 127 | annealing_epochs: 2000 #should be >=0 128 | extract_images: 129 | extraction_dir: extraction 130 | N_novel_imgs: 120 131 | traj_option: bspline 132 | use_learnt_poses: True 133 | use_learnt_focal: True 134 | resolution: 135 | model_file: model.pt 136 | model_file_pose: model_pose.pt 137 | model_file_focal: model_focal.pt 138 | eval_depth: False 139 | bspline_degree: 100 140 | eval_pose: 141 | n_points: 1024 142 | type: nope_nerf 143 | type_to_eval: eval 144 | opt_pose_epoch: 1000 145 | extraction_dir: extraction 146 | init_method: pre 147 | opt_eval_lr: 0.001 148 | 149 | 150 | -------------------------------------------------------------------------------- /configs/preprocess.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | type: DPT 3 | dataloading: 4 | path: data/nerf_llff_data 5 | scene: ['fern'] 6 | resize_factor: 7 | load_colmap_poses: False 8 | training: 9 | mode: 'all' -------------------------------------------------------------------------------- /dataloading/__init__.py: -------------------------------------------------------------------------------- 1 | from dataloading.dataloading import get_dataloader 2 | from 
dataloading.configloading import load_config 3 | -------------------------------------------------------------------------------- /dataloading/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | import imageio 6 | import cv2 7 | 8 | def _minify(basedir, factors=[], resolutions=[], img_folder='images'): 9 | needtoload = False 10 | for r in factors: 11 | imgdir = os.path.join(basedir, img_folder + '_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, img_folder + '_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, img_folder) 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = img_folder + '_{}'.format(r) 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | name = img_folder + '_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | print('Minifying', r, basedir) 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | print(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True, crop_size=0, load_colmap_poses=True): 60 | if load_colmap_poses: 61 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 62 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) # 3 x 5 x N 63 | bds = poses_arr[:, -2:].transpose([1,0]) 64 | img_folder = 'images' 65 | crop_ratio = 1 66 | focal_crop_factor = 1 67 | if crop_size!=0: 68 | img_folder = 'images_cropped' 69 | crop_dir = os.path.join(basedir, 'images_cropped') 70 | if not os.path.exists(crop_dir): 71 | os.makedirs(crop_dir) 72 | for f in sorted(os.listdir(os.path.join(basedir, 'images'))): 73 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png'): 74 | image = imageio.imread(os.path.join(basedir, 'images', f)) 75 | crop_size_H = crop_size 76 | H, W, _ = image.shape 77 | crop_size_W = int(crop_size_H * W/H) 78 | image_cropped = image[crop_size_H:H-crop_size_H, crop_size_W:W-crop_size_W] 79 | save_path = os.path.join(crop_dir, f) 80 | im = Image.fromarray(image_cropped) 81 | im = im.resize((W, H)) 82 | im.save(save_path) 83 | crop_ratio = crop_size_H / H 84 | print('=======images cropped=======') 85 | focal_crop_factor = (H - 2*crop_size_H) / H 86 | 87 | 88 | 89 | 90 | img0 = [os.path.join(basedir, img_folder, f) for f in sorted(os.listdir(os.path.join(basedir, img_folder))) \ 91 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 92 | sh = imageio.imread(img0).shape 93 | 94 | sfx = '' 95 | 96 | if factor is not None: 97 
| sfx = '_{}'.format(factor) 98 | _minify(basedir, factors=[factor], img_folder=img_folder) 99 | factor = factor 100 | elif height is not None: 101 | factor = sh[0] / float(height) 102 | width = int(sh[1] / factor) 103 | _minify(basedir, resolutions=[[height, width]], img_folder=img_folder) 104 | sfx = '_{}x{}'.format(width, height) 105 | elif width is not None: 106 | factor = sh[1] / float(width) 107 | height = int(sh[0] / factor) 108 | _minify(basedir, resolutions=[[height, width]], img_folder=img_folder) 109 | sfx = '_{}x{}'.format(width, height) 110 | else: 111 | factor = 1 112 | 113 | imgdir = os.path.join(basedir, img_folder + sfx) 114 | if not os.path.exists(imgdir): 115 | print( imgdir, 'does not exist, returning' ) 116 | return 117 | 118 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 119 | sh = imageio.imread(imgfiles[0]).shape 120 | if load_colmap_poses: 121 | if poses.shape[-1] != len(imgfiles): 122 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 123 | return 124 | 125 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 126 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 127 | 128 | if not load_imgs: 129 | return poses, bds 130 | 131 | def imread(f): 132 | if f.endswith('png'): 133 | return imageio.imread(f, ignoregamma=True) 134 | else: 135 | return imageio.imread(f) 136 | 137 | imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] 138 | imgs = np.stack(imgs, -1) 139 | 140 | if load_colmap_poses: 141 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 142 | else: 143 | print('Loaded image data', imgs.shape) 144 | poses=None 145 | bds=None 146 | # added 147 | imgnames = [f for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 148 | return poses, bds, imgs, imgnames, crop_ratio, focal_crop_factor 149 | def recenter_poses(poses): 150 | 151 | poses_ = poses+0 152 | bottom = np.reshape([0,0,0,1.], [1,4]) 153 | c2w = poses_avg(poses) 154 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 155 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 156 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 157 | 158 | poses = np.linalg.inv(c2w) @ poses 159 | poses_[:,:3,:4] = poses[:,:3,:4] 160 | poses = poses_ 161 | return poses 162 | def poses_avg(poses): 163 | 164 | hwf = poses[0, :3, -1:] 165 | 166 | center = poses[:, :3, 3].mean(0) 167 | vec2 = normalize(poses[:, :3, 2].sum(0)) 168 | up = poses[:, :3, 1].sum(0) 169 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 170 | 171 | return c2w 172 | def normalize(x): 173 | return x / np.linalg.norm(x) 174 | 175 | def viewmatrix(z, up, pos): 176 | vec2 = normalize(z) 177 | vec1_avg = up 178 | vec0 = normalize(np.cross(vec1_avg, vec2)) 179 | vec1 = normalize(np.cross(vec2, vec0)) 180 | m = np.stack([vec0, vec1, vec2, pos], 1) 181 | return m 182 | def spherify_poses(poses, bds): 183 | 184 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 185 | 186 | rays_d = poses[:,:3,2:3] 187 | rays_o = poses[:,:3,3:4] 188 | 189 | def min_line_dist(rays_o, rays_d): 190 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 191 | b_i = -A_i @ rays_o 192 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 193 | return pt_mindist 194 | 195 | pt_mindist = min_line_dist(rays_o, rays_d) 196 | 197 | center = pt_mindist 198 | up = 
(poses[:,:3,3] - center).mean(0) 199 | 200 | vec0 = normalize(up) 201 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 202 | vec2 = normalize(np.cross(vec0, vec1)) 203 | pos = center 204 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 205 | 206 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 207 | 208 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 209 | 210 | sc = 1./rad 211 | poses_reset[:,:3,3] *= sc 212 | bds *= sc 213 | rad *= sc 214 | 215 | centroid = np.mean(poses_reset[:,:3,3], 0) 216 | zh = centroid[2] 217 | radcircle = np.sqrt(rad**2-zh**2) 218 | new_poses = [] 219 | 220 | for th in np.linspace(0.,2.*np.pi, 120): 221 | 222 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 223 | up = np.array([0,0,-1.]) 224 | 225 | vec2 = normalize(camorigin) 226 | vec0 = normalize(np.cross(vec2, up)) 227 | vec1 = normalize(np.cross(vec2, vec0)) 228 | pos = camorigin 229 | p = np.stack([vec0, vec1, vec2, pos], 1) 230 | 231 | new_poses.append(p) 232 | 233 | new_poses = np.stack(new_poses, 0) 234 | 235 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 236 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 237 | 238 | return poses_reset, new_poses, bds 239 | 240 | 241 | def load_gt_depths(image_list, datadir, H=None, W=None, crop_ratio=1): 242 | depths = [] 243 | for image_name in image_list: 244 | frame_id = image_name.split('.')[0] 245 | depth_path = os.path.join(datadir, 'depth', '{}.png'.format(frame_id)) 246 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) 247 | depth = depth.astype(np.float32) / 1000 248 | if crop_ratio != 1: 249 | h, w = depth.shape 250 | crop_size_h = int(h*crop_ratio) 251 | crop_size_w = int(w*crop_ratio) 252 | depth = depth[crop_size_h:h-crop_size_h, crop_size_w:w-crop_size_w] 253 | 254 | if H is not None: 255 | # mask = (depth > 0).astype(np.uint8) 256 | depth_resize = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST) 257 | # mask_resize = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST) 258 | depths.append(depth_resize) 259 | # masks.append(mask_resize > 0.5) 260 | else: 261 | depths.append(depth) 262 | # masks.append(depth > 0) 263 | return np.stack(depths) 264 | def load_depths(image_list, datadir, H=None, W=None): 265 | depths = [] 266 | 267 | for image_name in image_list: 268 | frame_id = image_name.split('.')[0] 269 | depth_path = os.path.join(datadir, '{}_depth.npy'.format(frame_id)) 270 | if not os.path.exists(depth_path): 271 | depth_path = os.path.join(datadir, 'depth_{}.npy'.format(frame_id)) 272 | depth = np.load(depth_path) 273 | 274 | if H is not None: 275 | depth_resize = cv2.resize(depth, (W, H)) 276 | depths.append(depth_resize) 277 | else: 278 | depths.append(depth) 279 | return np.stack(depths) 280 | def load_images(image_list, datadir): 281 | images = [] 282 | 283 | for image_name in image_list: 284 | frame_id = image_name.split('.')[0] 285 | im_path = os.path.join(datadir, '{}.npy'.format(frame_id)) 286 | im = np.load(im_path) 287 | images.append(im) 288 | return np.stack(images) 289 | def load_depths_npz(image_list, datadir, H=None, W=None, norm=False): 290 | depths = [] 291 | 292 | for image_name in image_list: 293 | frame_id = image_name.split('.')[0] 294 | depth_path = os.path.join(datadir, 'depth_{}.npz'.format(frame_id)) 295 | depth = np.load(depth_path)['pred'] 296 | if depth.shape[0] == 1: 297 | depth = depth[0] 298 | 299 
| if H is not None: 300 | depth_resize = cv2.resize(depth, (W, H)) 301 | depths.append(depth_resize) 302 | else: 303 | depths.append(depth) 304 | depths = np.stack(depths) 305 | if norm: 306 | depths_n = [] 307 | t_all = np.median(depths) 308 | s_all = np.mean(np.abs(depths - t_all)) 309 | for depth in depths: 310 | t_i = np.median(depth) 311 | s_i = np.mean(np.abs(depth - t_i)) 312 | depth = s_all * (depth - t_i) / s_i + t_all 313 | depths_n.append(depth) 314 | depths = np.stack(depths_n) 315 | return depths -------------------------------------------------------------------------------- /dataloading/configloading.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | def load_config(path, default_path=None, inherit_from=None): 4 | ''' Loads config file. 5 | 6 | Args: 7 | path (str): path to config file 8 | default_path (bool): whether to use default path 9 | ''' 10 | # Load configuration from file itself 11 | with open(path, 'r') as f: 12 | cfg_special = yaml.load(f, Loader=yaml.Loader) 13 | 14 | # Check if we should inherit from a config 15 | # inherit_from = cfg_special.get('inherit_from') 16 | 17 | # If yes, load this config first as default 18 | # If no, use the default_path 19 | if inherit_from is not None: 20 | cfg = load_config(inherit_from, default_path) 21 | elif default_path is not None: 22 | with open(default_path, 'r') as f: 23 | cfg = yaml.load(f, Loader=yaml.Loader) 24 | else: 25 | cfg = dict() 26 | 27 | # Include main configuration 28 | update_recursive(cfg, cfg_special) 29 | 30 | return cfg 31 | 32 | 33 | def update_recursive(dict1, dict2): 34 | ''' Update two config dictionaries recursively. 35 | 36 | Args: 37 | dict1 (dict): first dictionary to be updated 38 | dict2 (dict): second dictionary which entries should be used 39 | 40 | ''' 41 | for k, v in dict2.items(): 42 | if k not in dict1: 43 | dict1[k] = dict() 44 | if isinstance(v, dict): 45 | update_recursive(dict1[k], v) 46 | else: 47 | dict1[k] = v -------------------------------------------------------------------------------- /dataloading/dataloading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import logging 5 | import torch 6 | from torch.utils import data 7 | import numpy as np 8 | from torchvision import transforms 9 | from .dataset import DataField 10 | from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet 11 | logger = logging.getLogger(__name__) 12 | 13 | def get_dataloader(cfg, mode='train', 14 | shuffle=True, n_views=None): 15 | ''' Return dataloader instance 16 | 17 | Instansiate dataset class and dataloader and 18 | return dataloader 19 | 20 | Args: 21 | cfg (dict): imported config for dataloading 22 | mode (str): tran/eval/render/all 23 | shuffle (bool): as name 24 | n_views (int): specify number of views during rendering 25 | ''' 26 | 27 | batch_size = cfg['dataloading']['batchsize'] 28 | n_workers = cfg['dataloading']['n_workers'] 29 | 30 | fields = get_data_fields(cfg, mode) 31 | if n_views is not None and mode=='render': 32 | n_views = n_views 33 | else: 34 | n_views = fields['img'].N_imgs 35 | ## get dataset 36 | dataset = OurDataset( 37 | fields, n_views=n_views, mode=mode) 38 | 39 | ## dataloader 40 | dataloader = torch.utils.data.DataLoader( 41 | dataset, batch_size=batch_size, num_workers=n_workers, 42 | shuffle=shuffle, pin_memory=True 43 | ) 44 | 45 | return dataloader, fields 46 | 47 | 48 | def get_data_fields(cfg, mode='train'): 49 
| ''' Returns the data fields. 50 | 51 | Args: 52 | cfg (dict): imported yaml config 53 | mode (str): the mode which is used 54 | 55 | Return: 56 | field (dict): datafield 57 | ''' 58 | use_DPT = (cfg['depth']['type']=='DPT') 59 | resize_img_transform = ResizeImage_mvs() # for dpt input images 60 | fields = {} 61 | load_ref_img = ((cfg['training']['pc_weight']!=0.0) or (cfg['training']['rgb_s_weight']!=0.0)) 62 | dataset_name = cfg['dataloading']['dataset_name'] 63 | if dataset_name=='any': 64 | img_field = DataField( 65 | model_path=cfg['dataloading']['path'], 66 | transform=resize_img_transform, 67 | with_camera=True, 68 | with_depth=cfg['dataloading']['with_depth'], 69 | scene_name=cfg['dataloading']['scene'], 70 | use_DPT=use_DPT, mode=mode,spherify=cfg['dataloading']['spherify'], 71 | load_ref_img=load_ref_img, customized_poses=cfg['dataloading']['customized_poses'], 72 | customized_focal=cfg['dataloading']['customized_focal'], 73 | resize_factor=cfg['dataloading']['resize_factor'], depth_net=cfg['dataloading']['depth_net'], 74 | crop_size=cfg['dataloading']['crop_size'], random_ref=cfg['dataloading']['random_ref'], norm_depth=cfg['dataloading']['norm_depth'], 75 | load_colmap_poses=cfg['dataloading']['load_colmap_poses'], sample_rate=cfg['dataloading']['sample_rate']) 76 | else: 77 | print(dataset_name, 'does not exist') 78 | fields['img'] = img_field 79 | return fields 80 | class ResizeImage_mvs(object): 81 | def __init__(self): 82 | net_w = net_h = 384 83 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 84 | self.transform = transforms.Compose( 85 | [ 86 | Resize( 87 | net_w, 88 | net_h, 89 | resize_target=True, 90 | keep_aspect_ratio=True, 91 | ensure_multiple_of=32, 92 | resize_method="minimal" 93 | ), 94 | normalization, 95 | PrepareForNet(), 96 | ] 97 | ) 98 | def __call__(self, img): 99 | img = self.transform(img) 100 | return img 101 | 102 | 103 | 104 | 105 | class OurDataset(data.Dataset): 106 | '''Dataset class 107 | ''' 108 | 109 | def __init__(self, fields, n_views=0, mode='train'): 110 | # Attributes 111 | self.fields = fields 112 | print(mode,': ', n_views, ' views') 113 | self.n_views = n_views 114 | 115 | def __len__(self): 116 | ''' Returns the length of the dataset. 117 | ''' 118 | return self.n_views 119 | 120 | def __getitem__(self, idx): 121 | ''' Returns an item of the dataset. 122 | 123 | Args: 124 | idx (int): ID of data point 125 | ''' 126 | data = {} 127 | for field_name, field in self.fields.items(): 128 | field_data = field.load(idx) 129 | 130 | if isinstance(field_data, dict): 131 | for k, v in field_data.items(): 132 | if k is None: 133 | data[field_name] = v 134 | else: 135 | data['%s.%s' % (field_name, k)] = v 136 | else: 137 | data[field_name] = field_data 138 | 139 | return data 140 | 141 | 142 | 143 | def collate_remove_none(batch): 144 | ''' Collater that puts each data field into a tensor with outer dimension 145 | batch size. 146 | 147 | Args: 148 | batch: batch 149 | ''' 150 | batch = list(filter(lambda x: x is not None, batch)) 151 | return data.dataloader.default_collate(batch) 152 | 153 | 154 | def worker_init_fn(worker_id): 155 | ''' Worker init function to ensure true randomness. 
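Each worker seeds numpy from os.urandom, so dataloader workers do not all inherit the same random state from the parent process.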
156 | ''' 157 | random_data = os.urandom(4) 158 | base_seed = int.from_bytes(random_data, byteorder="big") 159 | np.random.seed(base_seed + worker_id) 160 | -------------------------------------------------------------------------------- /dataloading/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import logging 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | import imageio 9 | import cv2 10 | from dataloading.common import _load_data, recenter_poses, spherify_poses, load_depths_npz, load_gt_depths 11 | logger = logging.getLogger(__name__) 12 | 13 | class DataField(object): 14 | def __init__(self, model_path, 15 | transform=None, 16 | with_camera=False, 17 | with_depth=False, 18 | use_DPT=False, scene_name=[' '], mode='train', spherify=False, 19 | load_ref_img=False,customized_poses=False, 20 | customized_focal=False,resize_factor=2, depth_net='dpt',crop_size=0, 21 | random_ref=False,norm_depth=False,load_colmap_poses=True, sample_rate=8, **kwargs): 22 | """load images, depth maps, etc. 23 | Args: 24 | model_path (str): path of dataset 25 | transform (class, optional): transform made to the image. Defaults to None. 26 | with_camera (bool, optional): load camera intrinsics. Defaults to False. 27 | with_depth (bool, optional): load gt depth maps (if available). Defaults to False. 28 | DPT (bool, optional): run DPT model. Defaults to False. 29 | scene_name (list, optional): scene folder name. Defaults to [' ']. 30 | mode (str, optional): train/eval/all/render. Defaults to 'train'. 31 | spherify (bool, optional): spherify colmap poses (no effect to training). Defaults to False. 32 | load_ref_img (bool, optional): load reference image. Defaults to False. 33 | customized_poses (bool, optional): use GT pose if available. Defaults to False. 34 | customized_focal (bool, optional): use GT focal if provided. Defaults to False. 35 | resize_factor (int, optional): image downsample factor. Defaults to 2. 36 | depth_net (str, optional): which depth estimator use. Defaults to 'dpt'. 37 | crop_size (int, optional): crop if images have black border. Defaults to 0. 38 | random_ref (bool/int, optional): if use a random reference image/number of neaest images. Defaults to False. 39 | norm_depth (bool, optional): normalise depth maps. Defaults to False. 40 | load_colmap_poses (bool, optional): load colmap poses. Defaults to True. 41 | sample_rate (int, optional): 1 in 'sample_rate' images as test set. Defaults to 8. 42 | """ 43 | self.transform = transform 44 | self.with_camera = with_camera 45 | self.with_depth = with_depth 46 | self.use_DPT = use_DPT 47 | self.mode = mode 48 | self.ref_img = load_ref_img 49 | self.random_ref = random_ref 50 | self.sample_rate = sample_rate 51 | 52 | load_dir = os.path.join(model_path, scene_name[0]) 53 | if crop_size!=0: 54 | depth_net = depth_net + '_' + str(crop_size) 55 | poses, bds, imgs, img_names, crop_ratio, focal_crop_factor = _load_data(load_dir, factor=resize_factor, crop_size=crop_size, load_colmap_poses=load_colmap_poses) 56 | if load_colmap_poses: 57 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 58 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 59 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 60 | bd_factor = 0.75 61 | # Rescale if bd_factor is provided 62 | sc = 1. 
if bd_factor is None else 1./(bds.min() * bd_factor) 63 | poses[:,:3,3] *= sc 64 | bds *= sc 65 | poses = recenter_poses(poses) 66 | if spherify: 67 | poses, render_poses, bds = spherify_poses(poses, bds) 68 | input_poses = poses.astype(np.float32) 69 | hwf = input_poses[0,:3,-1] 70 | self.hwf = input_poses[:,:3,:] 71 | input_poses = input_poses[:,:3,:4] 72 | H, W, focal = hwf 73 | H, W = int(H), int(W) 74 | poses_tensor = torch.from_numpy(input_poses) 75 | bottom = torch.FloatTensor([0, 0, 0, 1]).unsqueeze(0) 76 | bottom = bottom.repeat(poses_tensor.shape[0], 1, 1) 77 | c2ws_colmap = torch.cat([poses_tensor, bottom], 1) 78 | 79 | 80 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 81 | imgs = np.transpose(imgs, (0, 3, 1, 2)) 82 | _, _, h, w = imgs.shape 83 | 84 | if customized_focal: 85 | focal_gt = np.load(os.path.join(load_dir, 'intrinsics.npz'))['K'].astype(np.float32) 86 | if resize_factor is None: 87 | resize_factor = 1 88 | fx = focal_gt[0, 0] / resize_factor 89 | fy = focal_gt[1, 1] / resize_factor 90 | else: 91 | if load_colmap_poses: 92 | fx, fy = focal, focal 93 | else: 94 | print('No focal provided, use image size as default') 95 | fx, fy = w, h 96 | fx = fx / focal_crop_factor 97 | fy = fy / focal_crop_factor 98 | 99 | 100 | self.H, self.W, self.focal = h, w, fx 101 | self.K = np.array([[2*fx/w, 0, 0, 0], 102 | [0, -2*fy/h, 0, 0], 103 | [0, 0, -1, 0], 104 | [0, 0, 0, 1]]).astype(np.float32) 105 | ids = np.arange(imgs.shape[0]) 106 | i_test = ids[int(sample_rate/2)::sample_rate] 107 | i_train = np.array([i for i in ids if i not in i_test]) 108 | self.i_train = i_train 109 | self.i_test = i_test 110 | image_list_train = [img_names[i] for i in i_train] 111 | image_list_test = [img_names[i] for i in i_test] 112 | print('test set: ', image_list_test) 113 | 114 | if customized_poses: 115 | c2ws_gt = np.load(os.path.join(load_dir, 'gt_poses.npz'))['poses'].astype(np.float32) 116 | T = torch.tensor(np.array([[1, 0, 0, 0],[0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float32)) # ScanNet coordinate 117 | c2ws_gt = torch.from_numpy(c2ws_gt) 118 | c2ws = c2ws_gt @ T 119 | else: 120 | if load_colmap_poses: 121 | c2ws = c2ws_colmap 122 | else: 123 | c2ws = None 124 | 125 | 126 | self.N_imgs_train = len(i_train) 127 | self.N_imgs_test = len(i_test) 128 | 129 | pred_depth_path = os.path.join(load_dir, depth_net) 130 | self.dpt_depth = None 131 | if mode in ('train','eval_trained', 'render'): 132 | idx_list = i_train 133 | self.img_list = image_list_train 134 | elif mode=='eval': 135 | idx_list = i_test 136 | self.img_list = image_list_test 137 | elif mode=='all': 138 | idx_list = ids 139 | self.img_list = img_names 140 | 141 | self.imgs = imgs[idx_list] 142 | self.N_imgs = len(idx_list) 143 | if c2ws is not None: 144 | self.c2ws = c2ws[idx_list] 145 | if load_colmap_poses: 146 | self.c2ws_colmap = c2ws_colmap[i_train] 147 | if not use_DPT: 148 | self.dpt_depth = load_depths_npz(image_list_train, pred_depth_path, norm=norm_depth) 149 | if with_depth: 150 | self.depth = load_gt_depths(image_list_train, load_dir, crop_ratio=crop_ratio) 151 | 152 | 153 | 154 | 155 | def load(self, input_idx_img=None): 156 | ''' Loads the field. 
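input_idx_img selects the image index within the current split; when it is omitted, the first image (index 0) is loaded.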
157 | ''' 158 | return self.load_field(input_idx_img) 159 | 160 | def load_image(self, idx, data={}): 161 | image = self.imgs[idx] 162 | data[None] = image 163 | if self.use_DPT: 164 | data_in = {"image": np.transpose(image, (1, 2, 0))} 165 | data_in = self.transform(data_in) 166 | data['normalised_img'] = data_in['image'] 167 | data['idx'] = idx 168 | def load_ref_img(self, idx, data={}): 169 | if self.random_ref: 170 | if idx==self.N_imgs-1: 171 | ref_idx = idx-1 172 | else: 173 | ran_idx = random.randint(1, min(self.random_ref, self.N_imgs-idx-1)) 174 | ref_idx = idx + ran_idx 175 | image = self.imgs[ref_idx] 176 | if self.dpt_depth is not None: 177 | dpt = self.dpt_depth[ref_idx] 178 | data['ref_dpts'] = dpt 179 | if self.use_DPT: 180 | data_in = {"image": np.transpose(image, (1, 2, 0))} 181 | data_in = self.transform(data_in) 182 | normalised_ref_img = data_in['image'] 183 | data['normalised_ref_img'] = normalised_ref_img 184 | if self.with_depth: 185 | depth = self.depth[ref_idx] 186 | data['ref_depths'] = depth 187 | data['ref_imgs'] = image 188 | data['ref_idxs'] = ref_idx 189 | 190 | def load_depth(self, idx, data={}): 191 | depth = self.depth[idx] 192 | data['depth'] = depth 193 | def load_DPT_depth(self, idx, data={}): 194 | depth_dpt = self.dpt_depth[idx] 195 | data['dpt'] = depth_dpt 196 | 197 | def load_camera(self, idx, data={}): 198 | data['camera_mat'] = self.K 199 | data['scale_mat'] = np.array([[1, 0, 0, 0], [0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]]).astype(np.float32) 200 | data['idx'] = idx 201 | 202 | 203 | 204 | def load_field(self, input_idx_img=None): 205 | if input_idx_img is not None: 206 | idx_img = input_idx_img 207 | else: 208 | idx_img = 0 209 | # Load the data 210 | data = {} 211 | if not self.mode =='render': 212 | self.load_image(idx_img, data) 213 | if self.ref_img: 214 | self.load_ref_img(idx_img, data) 215 | if self.with_depth: 216 | self.load_depth(idx_img, data) 217 | if self.dpt_depth is not None: 218 | self.load_DPT_depth(idx_img, data) 219 | if self.with_camera: 220 | self.load_camera(idx_img, data) 221 | 222 | return data 223 | 224 | 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: nope-nerf 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - python=3.9 9 | - pytorch=1.7 10 | - torchvision=0.8.2 11 | - torchaudio 12 | - cudatoolkit=10.1 13 | - cffi 14 | - cython 15 | - imageio 16 | - numpy 17 | - scipy 18 | - matplotlib 19 | - pandas 20 | - tensorboard 21 | - yaml 22 | - pillow 23 | - wheel 24 | - pip 25 | - tqdm 26 | - pip: 27 | - ipdb 28 | - ipython 29 | - ipython-genutils 30 | - jedi 31 | - opencv-python 32 | - scikit-image 33 | - pyyaml 34 | - trimesh 35 | - imageio 36 | - matplotlib 37 | - plyfile 38 | - timm 39 | - lpips 40 | - setuptools 41 | - kornia==0.5.0 42 | - imageio-ffmpeg -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from re import L 3 | import sys 4 | import argparse 5 | import time 6 | import logging 7 | import torch 8 | 9 | sys.path.append(os.path.join(sys.path[0], '..')) 10 | from dataloading import get_dataloader, load_config 11 | from model.checkpoints import CheckpointIO 12 | from model.common import compute_errors 13 | from model.eval_images import Eval_Images 14 | 
import model as mdl 15 | import imageio 16 | import numpy as np 17 | import lpips as lpips_lib 18 | from utils_poses.align_traj import align_scale_c2b_use_a2b, align_ate_c2b_use_a2b 19 | from tqdm import tqdm 20 | from model.common import mse2psnr 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | def eval(cfg): 24 | torch.manual_seed(0) 25 | is_cuda = (torch.cuda.is_available()) 26 | device = torch.device("cuda" if is_cuda else "cpu") 27 | 28 | out_dir = cfg['training']['out_dir'] 29 | generation_dir = os.path.join(out_dir, cfg['eval_pose']['extraction_dir']) 30 | if not os.path.exists(generation_dir): 31 | os.makedirs(generation_dir) 32 | log_out_dir = os.path.join(out_dir, 'logs') 33 | writer = SummaryWriter(log_out_dir) 34 | logger = logging.getLogger() 35 | logger.setLevel(logging.WARNING) 36 | file_handler = logging.FileHandler(os.path.join(generation_dir, 'log.txt')) 37 | file_handler.setLevel(logging.INFO) 38 | logger.addHandler(file_handler) 39 | # logger.info(args) 40 | 41 | # Model 42 | network_type = cfg['model']['network_type'] 43 | if network_type=='official': 44 | model = mdl.OfficialStaticNerf(cfg) 45 | 46 | rendering_cfg = cfg['rendering'] 47 | renderer = mdl.Renderer(model, rendering_cfg, device=device) 48 | 49 | # init model 50 | nope_nerf = mdl.get_model(renderer, cfg, device=device) 51 | 52 | checkpoint_io = CheckpointIO(out_dir, model=nope_nerf) # changed 53 | # Dataloading 54 | train_loader, train_dataset = get_dataloader(cfg, mode='train', shuffle=False) 55 | eval_loader, eval_dataset = get_dataloader(cfg, mode='eval', shuffle=False) 56 | checkpoint_io.load(cfg['extract_images']['model_file']) 57 | use_learnt_poses = cfg['pose']['learn_pose'] 58 | use_learnt_focal = cfg['pose']['learn_focal'] 59 | num_epoch = cfg['eval_pose']['opt_pose_epoch'] 60 | init_method = cfg['eval_pose']['init_method'] 61 | opt_eval_lr = cfg['eval_pose']['opt_eval_lr'] 62 | 63 | if cfg['eval_pose']['type_to_eval'] == 'train': 64 | N_imgs = train_dataset['img'].N_imgs 65 | img_list = train_dataset['img'].img_list 66 | loader = train_loader 67 | render_dir = os.path.join(generation_dir, 'eval_trained') 68 | else: 69 | N_imgs = eval_dataset['img'].N_imgs 70 | img_list = eval_dataset['img'].img_list 71 | loader = eval_loader 72 | render_dir = os.path.join(generation_dir, 'eval', init_method) 73 | 74 | 75 | 76 | if use_learnt_focal: 77 | focal_net = mdl.LearnFocal(cfg['pose']['learn_focal'], cfg['pose']['fx_only'], order=cfg['pose']['focal_order']) 78 | checkpoint_io_focal = mdl.CheckpointIO(out_dir, model=focal_net) 79 | checkpoint_io_focal.load(cfg['extract_images']['model_file_focal']) 80 | fxfy = focal_net(0) 81 | print('learned fx: {0:.2f}, fy: {1:.2f}'.format(fxfy[0].item(), fxfy[1].item())) 82 | else: 83 | focal_net = None 84 | fxfy = None 85 | 86 | if use_learnt_poses: 87 | if cfg['pose']['init_pose']: 88 | init_pose = train_dataset['img'].c2ws # init with colmap 89 | else: 90 | init_pose = None 91 | learned_pose_param_net = mdl.LearnPose(train_dataset['img'].N_imgs, cfg['pose']['learn_R'], cfg['pose']['learn_t'], cfg=cfg,init_c2w=init_pose).to(device=device) 92 | checkpoint_io_pose = mdl.CheckpointIO(out_dir, model=learned_pose_param_net) 93 | checkpoint_io_pose.load(cfg['extract_images']['model_file_pose']) 94 | if cfg['eval_pose']['type_to_eval'] == 'train': 95 | eval_pose_param_net = learned_pose_param_net 96 | else: 97 | with torch.no_grad(): 98 | init_c2ws = eval_dataset['img'].c2ws.to(device) 99 | learned_c2ws_train = torch.stack([learned_pose_param_net(i) for i in 
range(train_dataset['img'].N_imgs_train)]) 100 | colmap_c2ws_train = train_dataset['img'].c2ws # (N, 4, 4) 101 | colmap_c2ws_train = colmap_c2ws_train.to(device) 102 | if init_method=='scale': 103 | init_c2ws, scale_colmap2est = align_scale_c2b_use_a2b(colmap_c2ws_train, learned_c2ws_train, init_c2ws.clone()) 104 | elif init_method=='ate': 105 | init_c2ws = align_ate_c2b_use_a2b(colmap_c2ws_train, learned_c2ws_train, init_c2ws) 106 | elif init_method=='pre': 107 | sample_rate = train_dataset['img'].sample_rate 108 | init_c2ws = learned_c2ws_train[int(sample_rate/2)-1::sample_rate-1][:N_imgs] 109 | elif init_method=='none': 110 | init_c2ws = None 111 | eval_pose_param_net = mdl.LearnPose(eval_dataset['img'].N_imgs, learn_R=True, learn_t=True, cfg=cfg, init_c2w=init_c2ws).to(device=device) 112 | optimizer_eval_pose = torch.optim.Adam(eval_pose_param_net.parameters(), lr=opt_eval_lr) 113 | scheduler_eval_pose = torch.optim.lr_scheduler.MultiStepLR(optimizer_eval_pose, 114 | milestones=list(range(0, int(num_epoch), int(num_epoch/5))), 115 | gamma=0.5) 116 | '''Optimise eval poses''' 117 | if cfg['eval_pose']['type_to_eval'] != 'train': 118 | eval_pose_cfg = cfg['eval_pose'] 119 | trainer = mdl.Trainer_pose(nope_nerf, eval_pose_cfg, device=device, optimizer_pose=optimizer_eval_pose, 120 | pose_param_net=eval_pose_param_net, focal_net=focal_net) 121 | for epoch_i in tqdm(range(num_epoch), desc='optimising eval'): 122 | L2_loss_epoch = [] 123 | psnr_epoch = [] 124 | for batch in eval_loader: 125 | losses = trainer.train_step(batch) 126 | L2_loss_epoch.append(losses['loss'].item()) 127 | L2_loss_mean = np.mean(L2_loss_epoch) 128 | opt_pose_psnr = mse2psnr(L2_loss_mean) 129 | scheduler_eval_pose.step() 130 | 131 | writer.add_scalar('opt/psnr', opt_pose_psnr, epoch_i) 132 | 133 | tqdm.write('{0:6d} ep: Opt: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(epoch_i, L2_loss_mean, opt_pose_psnr)) 134 | eval_pose_param_net.eval() 135 | eval_c2ws= torch.stack([eval_pose_param_net(i) for i in range(N_imgs)]) 136 | 137 | 138 | # Generator 139 | generator = Eval_Images( 140 | renderer, cfg,use_learnt_poses=use_learnt_poses, 141 | use_learnt_focal=use_learnt_focal, 142 | device=device, render_type=cfg['rendering']['type'], c2ws=eval_c2ws, img_list=img_list 143 | ) 144 | 145 | if not os.path.exists(render_dir): 146 | os.makedirs(render_dir) 147 | 148 | imgs = [] 149 | depths = [] 150 | eval_mse_list = [] 151 | eval_psnr_list = [] 152 | eval_ssim_list = [] 153 | eval_lpips_list = [] 154 | depth_gts = [] 155 | depth_preds = [] 156 | # init lpips loss. 
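# The LPIPS model constructed below comes from the `lpips` package (net='vgg'); called with normalize=True it expects RGB tensors of shape (1, 3, H, W) with values in [0, 1], which is how Eval_Images passes the rendered and ground-truth images.
# PSNR is derived from the per-image MSE through model.common.mse2psnr, which (assuming the usual NeRF convention, not shown in this file) is essentially psnr = -10.0 * np.log10(mse).
# min_depth and max_depth below bound the valid ground-truth depth range used by the optional depth evaluation.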
157 | lpips_metric = lpips_lib.LPIPS(net='vgg').to(device) 158 | min_depth=0.1 159 | max_depth=20 160 | for data in loader: 161 | out = generator.eval_images(data, render_dir, fxfy, lpips_metric, logger=logger, min_depth=min_depth, max_depth=max_depth) 162 | imgs.append(out['img']) 163 | depths.append(out['depth']) 164 | eval_mse_list.append(out['mse']) 165 | eval_psnr_list.append(out['psnr']) 166 | eval_ssim_list.append(out['ssim']) 167 | eval_lpips_list.append(out['lpips']) 168 | depth_preds.append(out['depth_pred']) 169 | depth_gts.append(out['depth_gt']) 170 | 171 | mean_mse = np.mean(eval_mse_list) 172 | mean_psnr = np.mean(eval_psnr_list) 173 | mean_ssim = np.mean(eval_ssim_list) 174 | mean_lpips = np.mean(eval_lpips_list) 175 | print('--------------------------') 176 | print('Mean MSE: {0:.2f}, PSNR: {1:.2f}, SSIM: {2:.2f}, LPIPS {3:.2f}'.format(mean_mse, mean_psnr, 177 | mean_ssim, mean_lpips)) 178 | 179 | print("{0:.2f}".format(mean_psnr),'&' "{0:.2f}".format(mean_ssim), '&', "{0:.2f}".format(mean_lpips)) 180 | 181 | if cfg['extract_images']['eval_depth']: 182 | depth_errors = [] 183 | ratio = np.median(np.concatenate(depth_gts)) / \ 184 | np.median(np.concatenate(depth_preds)) 185 | for i in range(len(depth_preds)): 186 | gt_depth = depth_gts[i] 187 | pred_depth = depth_preds[i] 188 | 189 | pred_depth *= ratio 190 | pred_depth[pred_depth < min_depth] = min_depth 191 | pred_depth[pred_depth > max_depth] = max_depth 192 | 193 | depth_errors.append(compute_errors(gt_depth, pred_depth)) 194 | 195 | 196 | mean_errors = np.array(depth_errors).mean(0) 197 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")) 198 | print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\") 199 | print("\n-> Done!") 200 | 201 | with open(os.path.join(generation_dir, 'depth_evaluation.txt'), 'a') as f: 202 | f.writelines(("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3") + '\n') 203 | f.writelines(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\") 204 | 205 | imgs = np.stack(imgs, axis=0) 206 | video_out_dir = os.path.join(render_dir, 'video_out') 207 | if not os.path.exists(video_out_dir): 208 | os.makedirs(video_out_dir) 209 | imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'), imgs, fps=30, quality=9) 210 | 211 | if __name__=='__main__': 212 | # Config 213 | parser = argparse.ArgumentParser( 214 | description='Extract images.' 215 | ) 216 | parser.add_argument('config', type=str, help='Path to config file.') 217 | args = parser.parse_args() 218 | cfg = load_config(args.config, 'configs/default.yaml') 219 | eval(cfg) 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /evaluation/eval_poses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | sys.path.append(os.path.join(sys.path[0], '..')) 6 | from dataloading import get_dataloader, load_config 7 | import model as mdl 8 | import numpy as np 9 | 10 | from utils_poses.vis_cam_traj import draw_camera_frustum_geometry 11 | from utils_poses.align_traj import align_ate_c2b_use_a2b 12 | from utils_poses.comp_ate import compute_rpe, compute_ATE 13 | import ATE.transformations as tf 14 | torch.manual_seed(0) 15 | 16 | # Config 17 | parser = argparse.ArgumentParser( 18 | description='Eval Poses.' 
19 | ) 20 | parser.add_argument('config', type=str, help='Path to config file.') 21 | parser.add_argument('--vis',action='store_true') 22 | args = parser.parse_args() 23 | cfg = load_config(args.config, 'configs/default.yaml') 24 | 25 | is_cuda = (torch.cuda.is_available()) 26 | device = torch.device("cuda" if is_cuda else "cpu") 27 | 28 | 29 | out_dir = cfg['training']['out_dir'] 30 | 31 | test_loader, field = get_dataloader(cfg, mode='train', shuffle=False) 32 | N_imgs = field['img'].N_imgs 33 | with torch.no_grad(): 34 | if cfg['pose']['init_pose']: 35 | if cfg['pose']['init_pose_type']=='gt': 36 | init_pose = field['img'].c2ws # init with colmap 37 | elif cfg['pose']['init_pose_type']=='colmap': 38 | init_pose = field['img'].c2ws_colmap 39 | else: 40 | init_pose = None 41 | pose_param_net = mdl.LearnPose(N_imgs, cfg['pose']['learn_R'], 42 | cfg['pose']['learn_t'], cfg=cfg, init_c2w=init_pose).to(device=device) 43 | checkpoint_io_pose = mdl.CheckpointIO(out_dir, model=pose_param_net) 44 | checkpoint_io_pose.load(cfg['extract_images']['model_file_pose'], device) 45 | learned_poses = torch.stack([pose_param_net(i) for i in range(N_imgs)]) 46 | 47 | H = field['img'].H 48 | W = field['img'].W 49 | gt_poses = field['img'].c2ws 50 | if cfg['pose']['learn_focal']: 51 | focal_net = mdl.LearnFocal(cfg['pose']['learn_focal'], cfg['pose']['fx_only'], order=cfg['pose']['focal_order']) 52 | checkpoint_io_focal = mdl.CheckpointIO(out_dir, model=focal_net) 53 | checkpoint_io_focal.load(cfg['extract_images']['model_file_focal'], device) 54 | fxfy = focal_net(0) 55 | fx = fxfy[0] * W / 2 56 | fy = fxfy[1] * H / 2 57 | else: 58 | fx = field['img'].focal 59 | fy = field['img'].focal 60 | 61 | 62 | '''Define camera frustums''' 63 | frustum_length = 0.1 64 | est_traj_color = np.array([39, 125, 161], dtype=np.float32) / 255 65 | cmp_traj_color = np.array([249, 65, 68], dtype=np.float32) / 255 66 | 67 | '''Align est traj to colmap traj''' 68 | c2ws_est_to_draw_align2cmp = learned_poses.clone() 69 | ATE_align = True 70 | 71 | if ATE_align: # Align learned poses to colmap poses 72 | c2ws_est_aligned = align_ate_c2b_use_a2b(learned_poses, gt_poses) # (N, 4, 4) 73 | c2ws_est_to_draw_align2cmp = c2ws_est_aligned 74 | 75 | # compute ate 76 | ate = compute_ATE(gt_poses.cpu().numpy(), c2ws_est_aligned.cpu().numpy()) 77 | rpe_trans, rpe_rot = compute_rpe(gt_poses.cpu().numpy(), c2ws_est_aligned.cpu().numpy()) 78 | print("{0:.3f}".format(rpe_trans*100),'&' "{0:.3f}".format(rpe_rot * 180 / np.pi), '&', "{0:.3f}".format(ate)) 79 | 80 | 81 | if args.vis: 82 | import open3d as o3d 83 | frustum_est_list = draw_camera_frustum_geometry(c2ws_est_to_draw_align2cmp.cpu().numpy(), H, W, 84 | fx, fy, 85 | frustum_length, est_traj_color) 86 | frustum_colmap_list = draw_camera_frustum_geometry(gt_poses.cpu().numpy(), H, W, 87 | fx, fy, 88 | frustum_length, cmp_traj_color) 89 | 90 | geometry_to_draw = [] 91 | geometry_to_draw.append(frustum_est_list) 92 | geometry_to_draw.append(frustum_colmap_list) 93 | 94 | '''o3d for line drawing''' 95 | t_est_list = c2ws_est_to_draw_align2cmp[:, :3, 3] 96 | t_cmp_list = gt_poses[:, :3, 3] 97 | 98 | '''line set to note pose correspondence between two trajs''' 99 | line_points = torch.cat([t_est_list, t_cmp_list], dim=0).cpu().numpy() # (2N, 3) 100 | line_ends = [[i, i+N_imgs] for i in range(N_imgs)] # (N, 2) connect two end points. 
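# The Open3D block below stacks the estimated and COLMAP camera centres into one (2N, 3) point array and joins pose i to its counterpart i + N_imgs through an o3d.geometry.LineSet (Vector3dVector for the points, Vector2iVector for the index pairs), so corresponding cameras of the two aligned trajectories are connected by a line.
# The metrics printed above scale rpe_trans by 100 and convert rpe_rot from radians to degrees before formatting; ATE is reported in the aligned scene units.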
101 | 102 | 103 | line_set = o3d.geometry.LineSet() 104 | line_set.points = o3d.utility.Vector3dVector(line_points) 105 | line_set.lines = o3d.utility.Vector2iVector(line_ends) 106 | unit_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=1.0, resolution=2) 107 | unit_sphere = o3d.geometry.LineSet.create_from_triangle_mesh(unit_sphere) 108 | unit_sphere.paint_uniform_color((0, 1, 0)) 109 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame() 110 | 111 | geometry_to_draw.append(line_set) 112 | 113 | o3d.visualization.draw_geometries(geometry_to_draw) 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.checkpoints import CheckpointIO 2 | from model.network import nope_nerf 3 | from model.training import Trainer 4 | from model.rendering import Renderer 5 | from model.config import get_model 6 | from model.official_nerf import OfficialStaticNerf 7 | from model.poses import LearnPose 8 | from model.intrinsics import LearnFocal 9 | from model.eval_pose_one_epoch import Trainer_pose 10 | from model.distortions import Learn_Distortion 11 | -------------------------------------------------------------------------------- /model/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | import shutil 6 | import datetime 7 | 8 | 9 | class CheckpointIO(object): 10 | ''' CheckpointIO class. 11 | 12 | It handles saving and loading checkpoints. 13 | 14 | Args: 15 | checkpoint_dir (str): path where checkpoints are saved 16 | ''' 17 | 18 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 19 | self.module_dict = kwargs 20 | self.checkpoint_dir = checkpoint_dir 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def register_modules(self, **kwargs): 25 | ''' Registers modules in current module dictionary. 26 | ''' 27 | self.module_dict.update(kwargs) 28 | 29 | def save(self, filename, **kwargs): 30 | ''' Saves the current module dictionary. 31 | 32 | Args: 33 | filename (str): name of output file 34 | ''' 35 | if not os.path.isabs(filename): 36 | filename = os.path.join(self.checkpoint_dir, filename) 37 | 38 | outdict = kwargs 39 | for k, v in self.module_dict.items(): 40 | outdict[k] = v.state_dict() 41 | torch.save(outdict, filename) 42 | 43 | def backup_model_best(self, filename, **kwargs): 44 | if not os.path.isabs(filename): 45 | filename = os.path.join(self.checkpoint_dir, filename) 46 | if os.path.exists(filename): 47 | # Backup model 48 | backup_dir = os.path.join(self.checkpoint_dir, 'backup_model_best') 49 | if not os.path.exists(backup_dir): 50 | os.makedirs(backup_dir) 51 | ts = datetime.datetime.now().timestamp() 52 | filename_backup = os.path.join(backup_dir, '%s.pt' % ts) 53 | shutil.copy(filename, filename_backup) 54 | 55 | def load(self, filename, device=None, load_model_only=False): 56 | '''Loads a module dictionary from local file or url. 57 | 58 | Args: 59 | filename (str): name of saved module dictionary 60 | ''' 61 | if is_url(filename): 62 | return self.load_url(filename) 63 | else: 64 | return self.load_file(filename, device, load_model_only) 65 | 66 | def load_file(self, filename, device=None, load_model_only=False): 67 | '''Loads a module dictionary from file. 
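device is forwarded to torch.load as map_location, and load_model_only restricts the restored state to the 'model' entry of the checkpoint.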
68 | 69 | Args: 70 | filename (str): name of saved module dictionary 71 | ''' 72 | 73 | if not os.path.isabs(filename): 74 | filename = os.path.join(self.checkpoint_dir, filename) 75 | if os.path.exists(filename): 76 | print(filename) 77 | print('=> Loading checkpoint from local file...') 78 | # state_dict = torch.load(filename) 79 | if device is not None: 80 | state_dict = torch.load(filename, map_location=device) 81 | else: 82 | state_dict = torch.load(filename) 83 | if load_model_only: 84 | state_dict_model = {} 85 | state_dict_model['model'] = state_dict['model'] 86 | else: 87 | state_dict_model = state_dict 88 | scalars = self.parse_state_dict(state_dict_model) 89 | return scalars 90 | else: 91 | raise FileExistsError 92 | 93 | def load_url(self, url): 94 | '''Load a module dictionary from url. 95 | 96 | Args: 97 | url (str): url to saved model 98 | ''' 99 | print(url) 100 | print('=> Loading checkpoint from url...') 101 | state_dict = model_zoo.load_url(url, progress=True, check_hash=False) 102 | scalars = self.parse_state_dict(state_dict) 103 | return scalars 104 | 105 | def parse_state_dict(self, state_dict): 106 | '''Parse state_dict of model and return scalars. 107 | 108 | Args: 109 | state_dict (dict): State dict of model 110 | ''' 111 | 112 | for k, v in self.module_dict.items(): 113 | if k in state_dict: 114 | # v.load_state_dict(state_dict[k], strict=False) 115 | v.load_state_dict(state_dict[k]) 116 | else: 117 | print('Warning: Could not find %s in checkpoint!' % k) 118 | scalars = {k: v for k, v in state_dict.items() 119 | if k not in self.module_dict} 120 | return scalars 121 | 122 | 123 | def is_url(url): 124 | ''' Checks if input string is a URL. 125 | 126 | Args: 127 | url (string): URL 128 | ''' 129 | scheme = urllib.parse.urlparse(url).scheme 130 | return scheme in ('http', 'https') 131 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import model as mdl 2 | from DPT.dpt.models import DPTDepthModel 3 | 4 | def get_model(renderer, cfg, device=None, **kwargs): 5 | depth_estimator = cfg['depth']['type'] 6 | if depth_estimator== 'DPT': 7 | path = cfg['depth']['path'] 8 | non_negative = cfg['depth']['non_negative'] 9 | scale = cfg['depth']['scale'] 10 | shift = cfg['depth']['shift'] 11 | invert = cfg['depth']['invert'] 12 | freeze = cfg['depth']['freeze'] 13 | depth_estimator = DPTDepthModel(path, non_negative, scale, shift, invert, freeze) 14 | else: 15 | depth_estimator = None 16 | model = mdl.nope_nerf(cfg, renderer, depth_estimator, device) 17 | 18 | return model -------------------------------------------------------------------------------- /model/distortions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Learn_Distortion(nn.Module): 5 | def __init__(self, num_cams, learn_scale, learn_shift, cfg): 6 | """depth distortion parameters 7 | 8 | Args: 9 | num_cams (int): number of cameras 10 | learn_scale (bool): whether to update scale 11 | learn_shift (bool): whether to update shift 12 | cfg (dict): argument options 13 | """ 14 | super(Learn_Distortion, self).__init__() 15 | self.global_scales=nn.Parameter(torch.ones(size=(num_cams, 1), dtype=torch.float32), requires_grad=learn_scale) 16 | self.global_shifts=nn.Parameter(torch.zeros(size=(num_cams, 1), dtype=torch.float32), requires_grad=learn_shift) 17 | self.fix_scaleN = 
cfg['distortion']['fix_scaleN'] 18 | self.num_cams = num_cams 19 | def forward(self, cam_id): 20 | scale = self.global_scales[cam_id] 21 | if scale<0.01: 22 | scale = torch.tensor(0.01, device=self.global_scales.device) 23 | if self.fix_scaleN and cam_id ==(self.num_cams-1): 24 | scale = torch.tensor(1, device=self.global_scales.device) 25 | shift = self.global_shifts[cam_id] 26 | 27 | return scale, shift -------------------------------------------------------------------------------- /model/eval_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | import logging 6 | import numpy as np 7 | import cv2 8 | import imageio 9 | from model.common import mse2psnr 10 | from third_party import pytorch_ssim 11 | from skimage import metrics 12 | from model.common import ( 13 | get_tensor_values, arange_pixels 14 | ) 15 | logger_py = logging.getLogger(__name__) 16 | class Eval_Images(object): 17 | 18 | def __init__(self, renderer, cfg, points_batch_size=100000, use_learnt_poses=True, use_learnt_focal=True, device=None,render_type=None, c2ws=None, img_list=None): 19 | self.points_batch_size = points_batch_size 20 | self.renderer = renderer 21 | self.resolution = cfg['extract_images']['resolution'] 22 | self.device = device 23 | self.use_learnt_poses = use_learnt_poses 24 | self.use_learnt_focal = use_learnt_focal 25 | self.render_type = render_type 26 | self.c2ws = c2ws 27 | self.img_list = img_list 28 | 29 | def process_data_dict(self, data): 30 | ''' Processes the data dictionary and returns respective tensors 31 | 32 | Args: 33 | data (dictionary): data dictionary 34 | ''' 35 | device = self.device 36 | img = data.get('img').to(device) 37 | batch_size, _, h, w = img.shape 38 | depth_img = data.get('img.depth', torch.ones(batch_size, h, w)) 39 | img_idx = data.get('img.idx') 40 | camera_mat = data.get('img.camera_mat').to(device) 41 | scale_mat = data.get('img.scale_mat').to(device) 42 | 43 | return (img, depth_img, camera_mat, scale_mat, img_idx) 44 | 45 | def eval_images(self, data, render_dir, fxfy, lpips_vgg_fn, logger, min_depth=0.1, max_depth=20, it=0): 46 | self.renderer.eval() 47 | (img_gt, depth_gt, camera_mat, scale_mat, img_idx) = self.process_data_dict(data) 48 | img_idx = int(img_idx) 49 | img_gt = img_gt.squeeze(0).permute(1, 2, 0) 50 | 51 | depth_gt = depth_gt.squeeze(0).numpy() 52 | mask = (depth_gt > min_depth) * (depth_gt < max_depth) 53 | 54 | if self.use_learnt_poses: 55 | c2w = self.c2ws[img_idx] 56 | world_mat = torch.inverse(c2w).unsqueeze(0) 57 | if self.use_learnt_focal: 58 | camera_mat = torch.tensor([[[fxfy[0], 0, 0, 0], 59 | [0, -fxfy[1], 0, 0], 60 | [0, 0, -1, 0], 61 | [0, 0, 0, 1]]]).to(self.device) 62 | h, w = self.resolution 63 | 64 | p_loc, pixels = arange_pixels(resolution=(h, w)) 65 | 66 | pixels = pixels.to(self.device) 67 | 68 | # redundancy, set depth_input values to ones to avoid masking 69 | depth_input = torch.zeros(1, 1, h, w).to(self.device) 70 | depth_input = get_tensor_values(depth_input, pixels.clone(), mode='nearest', scale=True, detach=False) 71 | depth_input = torch.ones_like(depth_input) 72 | 73 | with torch.no_grad(): 74 | rgb_pred = [] 75 | depth_pred = [] 76 | for i, (pixels_i, depth_i) in enumerate(zip(torch.split(pixels, self.points_batch_size, dim=1), torch.split(depth_input, self.points_batch_size, dim=1))): 77 | out_dict = self.renderer(pixels_i, depth_i, camera_mat, world_mat, scale_mat, 78 | 
self.render_type, eval_=True, it=it, add_noise=False) 79 | rgb_pred_i = out_dict['rgb'] 80 | rgb_pred.append(rgb_pred_i) 81 | depth_pred_i = out_dict['depth_pred'] 82 | depth_pred.append(depth_pred_i) 83 | rgb_pred = torch.cat(rgb_pred, dim=1) 84 | img_out = rgb_pred.view(h, w, 3) 85 | depth_pred = torch.cat(depth_pred, dim=0) 86 | depth_pred = depth_pred.view(h, w).detach().cpu().numpy() 87 | depth_out = depth_pred 88 | 89 | 90 | # mse for the entire image 91 | mse = F.mse_loss(img_out, img_gt).item() 92 | psnr = mse2psnr(mse) 93 | ssim = pytorch_ssim.ssim(img_out.permute(2, 0, 1).unsqueeze(0), img_gt.permute(2, 0, 1).unsqueeze(0)).item() 94 | 95 | lpips_loss = lpips_vgg_fn(img_out.permute(2, 0, 1).unsqueeze(0).contiguous(), 96 | img_gt.permute(2, 0, 1).unsqueeze(0).contiguous(), normalize=True).item() 97 | 98 | tqdm.write('{0:4d} img: PSNR: {1:.2f}, SSIM: {2:.2f}, LPIPS {3:.2f}'.format(img_idx, psnr, ssim, lpips_loss)) 99 | 100 | 101 | gt_height, gt_width = depth_gt.shape[:2] 102 | depth_out = cv2.resize(depth_out, (gt_width, gt_height), interpolation=cv2.INTER_NEAREST) 103 | 104 | img_out_dir = os.path.join(render_dir, 'img_out') 105 | depth_out_dir = os.path.join(render_dir, 'depth_out') 106 | img_gt_dir = os.path.join(render_dir, 'img_gt_out') 107 | if not os.path.exists(img_out_dir): 108 | os.makedirs(img_out_dir) 109 | if not os.path.exists(depth_out_dir): 110 | os.makedirs(depth_out_dir) 111 | if not os.path.exists(img_gt_dir): 112 | os.makedirs(img_gt_dir) 113 | 114 | 115 | depth_out = (np.clip(255.0 / depth_out.max() * (depth_out - depth_out.min()), 0, 255)).astype(np.uint8) 116 | img_out = (img_out.cpu().numpy() * 255).astype(np.uint8) 117 | img_gt = (img_gt.cpu().numpy() * 255).astype(np.uint8) 118 | imageio.imwrite(os.path.join(img_out_dir, str(img_idx).zfill(4) + '.png'), img_out) 119 | imageio.imwrite(os.path.join(depth_out_dir, str(img_idx).zfill(4) + '.png'), depth_out) 120 | imageio.imwrite(os.path.join(img_gt_dir, str(img_idx).zfill(4) + '.png'), img_gt) 121 | 122 | depth_out = depth_out[mask] 123 | depth_gt = depth_gt[mask] 124 | # frame_id = self.img_list[img_idx].split('.')[0] 125 | # filename = os.path.join(depth_out_dir, '{}_depth.npy'.format(frame_id)) 126 | # np.save(filename, depth_out) 127 | # filename = os.path.join(img_out_dir, '{}.npy'.format(frame_id)) 128 | # np.save(filename, img_out.cpu().numpy()) 129 | img_dict = {'img': img_out, 130 | 'depth': depth_out, 131 | 'mse': mse, 132 | 'psnr': psnr, 133 | 'ssim': ssim, 134 | 'lpips': lpips_loss, 135 | 'depth_pred': depth_out, 136 | 'depth_gt': depth_gt} 137 | return img_dict 138 | 139 | -------------------------------------------------------------------------------- /model/eval_pose_one_epoch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from model.common import ( 4 | arange_pixels 5 | ) 6 | import logging 7 | from model.losses import Loss_Eval 8 | logger_py = logging.getLogger(__name__) 9 | 10 | class Trainer_pose(object): 11 | 12 | def __init__(self, model, cfg, device=None, optimizer_pose=None, pose_param_net=None, 13 | focal_net=None, **kwargs): 14 | self.model = model 15 | self.device = device 16 | self.optimizer_pose = optimizer_pose 17 | self.pose_param_net = pose_param_net 18 | self.focal_net = focal_net 19 | self.n_points = cfg['n_points'] 20 | self.rendering_technique = cfg['type'] 21 | 22 | self.loss = Loss_Eval() 23 | 24 | 25 | def train_step(self, data, it=100000): 26 | ''' Performs a training step. 
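Only the evaluation-time pose parameters are optimised: the NeRF model and the focal network stay in eval mode, and the gradient step is applied by optimizer_pose alone.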
27 | 28 | Args: 29 | data (dict): data dictionary 30 | it (int): training iteration 31 | ''' 32 | self.model.eval() 33 | self.pose_param_net.train() 34 | self.optimizer_pose.zero_grad() 35 | if self.focal_net is not None: 36 | self.focal_net.eval() 37 | loss_dict = self.compute_loss(data, it=it) 38 | loss = loss_dict['loss'] 39 | loss.backward() 40 | self.optimizer_pose.step() 41 | return loss_dict 42 | 43 | 44 | def process_data_dict(self, data): 45 | ''' Processes the data dictionary and returns respective tensors 46 | 47 | Args: 48 | data (dictionary): data dictionary 49 | ''' 50 | device = self.device 51 | 52 | # Get "ordinary" data 53 | 54 | img = data.get('img').to(device) 55 | img_idx = data.get('img.idx').to(device) 56 | batch_size, _, h, w = img.shape 57 | depth_img = data.get('img.depth', torch.ones(batch_size, h, w)).unsqueeze(1).to(device) # add for nope_nerf 58 | camera_mat = data.get('img.camera_mat').to(device) 59 | scale_mat = data.get('img.scale_mat').to(device) 60 | return (img, depth_img, camera_mat, scale_mat, img_idx) 61 | 62 | def compute_loss(self, data, eval_mode=False, it=100000): 63 | ''' Compute the loss. 64 | 65 | Args: 66 | data (dict): data dictionary 67 | eval_mode (bool): whether to use eval mode 68 | it (int): training iteration 69 | ''' 70 | n_points = self.n_points 71 | (img, depth_img, camera_mat, scale_mat, img_idx) = self.process_data_dict(data) 72 | # Shortcuts 73 | device = self.device 74 | batch_size, _, h, w = img.shape 75 | c2w = self.pose_param_net(img_idx) 76 | world_mat = torch.inverse(c2w).unsqueeze(0) 77 | if self.focal_net is not None: 78 | fxfy = self.focal_net(0) 79 | pad = torch.zeros(4) 80 | one = torch.tensor([1]) 81 | camera_mat = torch.cat([fxfy[0:1], pad, -fxfy[1:2], pad, -one, pad, one]).to(device) 82 | camera_mat = camera_mat.view(1, 4, 4) 83 | 84 | 85 | ray_idx = torch.randperm(h*w,device=device)[:n_points] 86 | img_flat = img.view(batch_size, 3, h*w).permute(0,2,1) 87 | rgb_gt = img_flat[:,ray_idx] 88 | p_full = arange_pixels((h, w), batch_size)[1].to(device) 89 | p = p_full[:, ray_idx] 90 | pix = ray_idx 91 | 92 | out_dict = self.model( 93 | p, pix, camera_mat, world_mat, scale_mat, 94 | self.rendering_technique, it=it, 95 | eval_mode=True, depth_img=depth_img, add_noise=False,img_size=(h, w) 96 | ) 97 | loss_dict = self.loss(out_dict['rgb'], rgb_gt) 98 | return loss_dict 99 | -------------------------------------------------------------------------------- /model/extracting_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import defaultdict 4 | from model.common import ( 5 | get_tensor_values, arange_pixels 6 | ) 7 | from tqdm import tqdm 8 | import logging 9 | import numpy as np 10 | logger_py = logging.getLogger(__name__) 11 | from PIL import Image 12 | import imageio 13 | class Extract_Images(object): 14 | def __init__(self, renderer, cfg, use_learnt_poses=True, use_learnt_focal=True, device=None,render_type=None): 15 | self.points_batch_size = 100000 16 | self.renderer = renderer 17 | self.resolution = cfg['extract_images']['resolution'] 18 | self.device = device 19 | self.use_learnt_poses = use_learnt_poses 20 | self.use_learnt_focal = use_learnt_focal 21 | self.render_type = render_type 22 | 23 | def process_data_dict(self, data): 24 | ''' Processes the data dictionary and returns respective tensors 25 | 26 | Args: 27 | data (dictionary): data dictionary 28 | ''' 29 | device = self.device 30 | 31 | img_idx = data.get('img.idx') 
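            # Note: the ground-truth world matrix is not read here (see the commented-out
            # line below); generate_images() instead builds world_mat by inverting the
            # learned camera-to-world pose selected by img_idx.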
32 | # world_mat = data.get('img.world_mat').to(device) 33 | camera_mat = data.get('img.camera_mat').to(device) 34 | scale_mat = data.get('img.scale_mat').to(device) 35 | 36 | return (camera_mat, scale_mat, img_idx) 37 | 38 | def generate_images(self, data, render_dir, c2ws, fxfy, it, output_geo): 39 | self.renderer.eval() 40 | (camera_mat, scale_mat, img_idx) = self.process_data_dict(data) 41 | img_idx = int(img_idx) 42 | if self.use_learnt_poses: 43 | c2w = c2ws[img_idx] 44 | world_mat = torch.inverse(c2w).unsqueeze(0) 45 | if self.use_learnt_focal: 46 | camera_mat = torch.tensor([[[fxfy[0], 0, 0, 0], 47 | [0, -fxfy[1], 0, 0], 48 | [0, 0, -1, 0], 49 | [0, 0, 0, 1]]]).to(self.device) 50 | h, w = self.resolution 51 | 52 | p_loc, pixels = arange_pixels(resolution=(h, w)) 53 | 54 | pixels = pixels.to(self.device) 55 | 56 | 57 | # redundancy, set depth_input values to ones to avoid masking 58 | depth_input = torch.zeros(1, 1, h, w).to(self.device) 59 | depth_input = get_tensor_values(depth_input, pixels.clone(), mode='nearest', scale=True, detach=False) 60 | depth_input = torch.ones_like(depth_input) 61 | with torch.no_grad(): 62 | rgb_pred = [] 63 | depth_pred = [] 64 | for ii, (pixels_i, depth_i) in enumerate(zip(torch.split(pixels, self.points_batch_size, dim=1), torch.split(depth_input, self.points_batch_size, dim=1))): 65 | out_dict = self.renderer(pixels_i, depth_i, camera_mat, world_mat, scale_mat, 66 | self.render_type, eval_=True, it=it, add_noise=False) 67 | rgb_pred_i = out_dict['rgb'] 68 | rgb_pred.append(rgb_pred_i) 69 | depth_pred_i = out_dict['depth_pred'] 70 | depth_pred.append(depth_pred_i) 71 | rgb_pred = torch.cat(rgb_pred, dim=1) 72 | rgb_pred = rgb_pred.view(h, w, 3).detach().cpu().numpy() 73 | depth_pred = torch.cat(depth_pred, dim=0) 74 | depth_pred = depth_pred.view(h, w).detach().cpu().numpy() 75 | 76 | img_out = (rgb_pred * 255).astype(np.uint8) 77 | depth_out = depth_pred 78 | 79 | 80 | if output_geo: 81 | with torch.no_grad(): 82 | mask_pred = torch.ones(pixels.shape[0], pixels.shape[1]).bool() 83 | rgb_pred = \ 84 | [self.renderer( 85 | pixels_i, None, camera_mat, world_mat, scale_mat, 86 | 'phong_renderer', eval_=True, it=it, add_noise=False)['rgb'] 87 | for ii, pixels_i in enumerate(torch.split(pixels, 1024, dim=1))] 88 | 89 | rgb_pred = torch.cat(rgb_pred, dim=1).cpu() 90 | p_loc1 = p_loc[mask_pred] 91 | geo_out = (255 * np.zeros((h, w, 3))).astype(np.uint8) 92 | 93 | if mask_pred.sum() > 0: 94 | rgb_hat = rgb_pred[mask_pred].detach().cpu().numpy() 95 | rgb_hat = (rgb_hat * 255).astype(np.uint8) 96 | geo_out[p_loc1[:, 1], p_loc1[:, 0]] = rgb_hat 97 | geo_out_dir = os.path.join(render_dir, 'geo_out') 98 | if not os.path.exists(geo_out_dir): 99 | os.makedirs(geo_out_dir) 100 | imageio.imwrite(os.path.join(geo_out_dir, str(img_idx).zfill(4) + '.png'), geo_out) 101 | else: 102 | geo_out = None 103 | 104 | img_out_dir = os.path.join(render_dir, 'img_out') 105 | depth_out_dir = os.path.join(render_dir, 'depth_out') 106 | 107 | if not os.path.exists(img_out_dir): 108 | os.makedirs(img_out_dir) 109 | if not os.path.exists(depth_out_dir): 110 | os.makedirs(depth_out_dir) 111 | 112 | filename = os.path.join(depth_out_dir, '{}.npy'.format(img_idx)) 113 | np.save(filename, depth_out) 114 | 115 | depth_out = (np.clip(255.0 / depth_out.max() * (depth_out - depth_out.min()), 0, 255)).astype(np.uint8) 116 | 117 | imageio.imwrite(os.path.join(img_out_dir, str(img_idx).zfill(4) + '.png'), img_out) 118 | imageio.imwrite(os.path.join(depth_out_dir, str(img_idx).zfill(4) + '.png'), 
depth_out) 119 | 120 | 121 | img_dict = {'img': img_out, 122 | 'depth': depth_out, 123 | 'geo': geo_out} 124 | return img_dict -------------------------------------------------------------------------------- /model/intrinsics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class LearnFocal(nn.Module): 6 | def __init__(self, req_grad, fx_only, order=2, init_focal=None): 7 | super(LearnFocal, self).__init__() 8 | 9 | self.fx_only = fx_only # If True, output [fx, fx]. If False, output [fx, fy] 10 | self.order = order # check our supplementary section. 11 | 12 | if self.fx_only: 13 | if init_focal is None: 14 | self.fx = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad) # (1, ) 15 | else: 16 | if self.order == 2: 17 | # a**2 * W = fx ---> a**2 = fx / W 18 | coe_x = torch.tensor(np.sqrt(init_focal), requires_grad=False).float() 19 | elif self.order == 1: 20 | # a * W = fx ---> a = fx / W 21 | coe_x = torch.tensor(init_focal, requires_grad=False).float() 22 | else: 23 | print('Focal init order need to be 1 or 2. Exit') 24 | exit() 25 | self.fx = nn.Parameter(coe_x, requires_grad=req_grad) # (1, ) 26 | else: 27 | if init_focal is None: 28 | self.fx = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad) # (1, ) 29 | self.fy = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad) # (1, ) 30 | elif isinstance(init_focal, list): 31 | if self.order == 2: 32 | # a**2 * W = fx ---> a**2 = fx / W 33 | coe_x = torch.tensor(np.sqrt(init_focal[0]), requires_grad=False).float() 34 | coe_y = torch.tensor(np.sqrt(init_focal[1]), requires_grad=False).float() 35 | elif self.order == 1: 36 | # a * W = fx ---> a = fx / W 37 | coe_x = torch.tensor(init_focal[0], requires_grad=False).float() 38 | coe_y = torch.tensor(init_focal[1], requires_grad=False).float() 39 | else: 40 | print('Focal init order need to be 1 or 2. Exit') 41 | exit() 42 | self.fx = nn.Parameter(coe_x, requires_grad=req_grad) # (1, ) 43 | self.fy = nn.Parameter(coe_y, requires_grad=req_grad) # (1, ) 44 | else: 45 | if self.order == 2: 46 | # a**2 * W = fx ---> a**2 = fx / W 47 | coe_x = torch.tensor(np.sqrt(init_focal), requires_grad=False).float() 48 | coe_y = torch.tensor(np.sqrt(init_focal), requires_grad=False).float() 49 | elif self.order == 1: 50 | # a * W = fx ---> a = fx / W 51 | coe_x = torch.tensor(init_focal, requires_grad=False).float() 52 | coe_y = torch.tensor(init_focal, requires_grad=False).float() 53 | else: 54 | print('Focal init order need to be 1 or 2. 
Exit') 55 | exit() 56 | self.fx = nn.Parameter(coe_x, requires_grad=req_grad) # (1, ) 57 | self.fy = nn.Parameter(coe_y, requires_grad=req_grad) # (1, ) 58 | 59 | def forward(self, i=None): # the i=None is just to enable multi-gpu training 60 | if self.fx_only: 61 | if self.order == 2: 62 | fxfy = torch.stack([self.fx ** 2, self.fx ** 2]) 63 | else: 64 | fxfy = torch.stack([self.fx, self.fx]) 65 | else: 66 | if self.order == 2: 67 | fxfy = torch.stack([self.fx**2, self.fy**2]) 68 | else: 69 | fxfy = torch.stack([self.fx, self.fy]) 70 | return fxfy 71 | -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | class Loss_Eval(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | def forward(self, rgb_pred, rgb_gt): 10 | loss = F.mse_loss(rgb_pred, rgb_gt) 11 | return_dict = { 12 | 'loss': loss 13 | } 14 | return return_dict 15 | 16 | class Loss(nn.Module): 17 | def __init__(self, cfg=None): 18 | super().__init__() 19 | 20 | self.depth_loss_type = cfg['depth_loss_type'] 21 | 22 | self.l1_loss = nn.L1Loss(reduction='sum') 23 | self.l2_loss = nn.MSELoss(reduction='sum') 24 | 25 | self.cfg = cfg 26 | 27 | def get_rgb_full_loss(self, rgb_values, rgb_gt, rgb_loss_type='l2'): 28 | if rgb_loss_type == 'l1': 29 | rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(rgb_values.shape[1]) 30 | elif rgb_loss_type == 'l2': 31 | rgb_loss = self.l2_loss(rgb_values, rgb_gt) / float(rgb_values.shape[1]) 32 | return rgb_loss 33 | 34 | def depth_loss_dpt(self, pred_depth, gt_depth, weight=None): 35 | """ 36 | :param pred_depth: (H, W) 37 | :param gt_depth: (H, W) 38 | :param weight: (H, W) 39 | :return: scalar 40 | """ 41 | 42 | t_pred = torch.median(pred_depth) 43 | s_pred = torch.mean(torch.abs(pred_depth - t_pred)) 44 | 45 | t_gt = torch.median(gt_depth) 46 | s_gt = torch.mean(torch.abs(gt_depth - t_gt)) 47 | 48 | pred_depth_n = (pred_depth - t_pred) / s_pred 49 | gt_depth_n = (gt_depth - t_gt) / s_gt 50 | 51 | if weight is not None: 52 | loss = F.mse_loss(pred_depth_n, gt_depth_n, reduction='none') 53 | loss = loss * weight 54 | loss = loss.sum() / (weight.sum() + 1e-8) 55 | else: 56 | loss = F.mse_loss(pred_depth_n, gt_depth_n) 57 | return loss 58 | 59 | def get_depth_loss(self, depth_pred, depth_gt): 60 | if self.depth_loss_type == 'l1': 61 | loss = self.l1_loss(depth_pred, depth_gt) / float(depth_pred.shape[0]) 62 | elif self.depth_loss_type=='invariant': 63 | loss = self.depth_loss_dpt(depth_pred, depth_gt) 64 | return loss 65 | def get_reprojection_loss(self, rgb, rgb_refs, valid_points, rgb_refs_ori): 66 | cfg = self.cfg 67 | loss = 0 68 | for (rgb_ref, rgb_ref_ori) in zip(rgb_refs, rgb_refs_ori): 69 | diff_img = (rgb - rgb_ref).abs() 70 | if cfg['with_auto_mask'] == True: 71 | auto_mask = (diff_img.mean(dim=-1, keepdim=True) < (rgb - rgb_ref_ori).abs().mean(dim=-1, keepdim=True)).float() * valid_points 72 | valid_points = auto_mask 73 | loss = loss + self.mean_on_mask(diff_img, valid_points) 74 | loss = loss / len(rgb_refs) 75 | return loss 76 | # compute mean value given a binary mask 77 | def mean_on_mask(self, diff, valid_mask): 78 | mask = valid_mask.expand_as(diff) 79 | if mask.sum() > 0: 80 | mean_value = (diff[mask]).sum() / mask.sum() 81 | # mean_value = (diff * mask).sum() / mask.sum() 82 | else: 83 | print('============invalid mask==========') 84 | 
mean_value = torch.tensor(0).float().cuda() 85 | return mean_value 86 | def get_DPT_reprojection_loss(self, rgb, rgb_refs, valid_points, rgb_img_refs_ori): 87 | cfg = self.cfg 88 | loss = 0 89 | for rgb_ref, rgb_img_ref_ori in zip(rgb_refs, rgb_img_refs_ori): 90 | diff_img = (rgb - rgb_ref).abs() 91 | diff_img = diff_img.clamp(0, 1) 92 | if cfg['with_auto_mask'] == True: 93 | auto_mask = (diff_img.mean(dim=1, keepdim=True) < (rgb - rgb_img_ref_ori).abs().mean(dim=1, keepdim=True)).float() 94 | auto_mask = auto_mask* valid_points 95 | valid_points = auto_mask 96 | 97 | if cfg['with_ssim'] == True: 98 | ssim_map = compute_ssim_loss(rgb, rgb_ref) 99 | diff_img = (0.15 * diff_img + 0.85 * ssim_map) 100 | loss = loss + self.mean_on_mask(diff_img, valid_points) 101 | loss = loss / len(rgb_refs) 102 | return loss 103 | def get_weight_dist_loss(self, t_list): 104 | dist = t_list - t_list.roll(shifts=1, dims=0) 105 | dist = dist[1:] # the first dist is meaningless 106 | dist = dist.norm(dim=1) # (N-1, ) 107 | dist_diff = dist - dist.roll(shifts=1) 108 | dist_diff = dist_diff[1:] # (N-2, ) 109 | 110 | loss_dist_1st = dist.mean() 111 | loss_dist_2nd = dist_diff.pow(2.0).mean() 112 | return loss_dist_1st, loss_dist_2nd 113 | 114 | def get_pc_loss(self, Xt, Yt): 115 | # compute error 116 | match_method = self.cfg['match_method'] 117 | if match_method=='dense': 118 | loss1 = self.comp_point_point_error(Xt[0].permute(1, 0), Yt[0].permute(1, 0)) 119 | loss2= self.comp_point_point_error(Yt[0].permute(1, 0), Xt[0].permute(1, 0)) 120 | loss = loss1 + loss2 121 | return loss 122 | def get_depth_consistency_loss(self, d1_proj, d2, d2_proj=None, d1=None): 123 | loss = self.l1_loss(d1_proj, d2) / float(d1_proj.shape[1]) 124 | if d2_proj is not None: 125 | loss = 0.5 * loss + 0.5 * self.l1_loss(d2_proj, d1) / float(d2_proj.shape[1]) 126 | return loss 127 | def comp_closest_pts_idx_with_split(self, pts_src, pts_des): 128 | """ 129 | :param pts_src: (3, S) 130 | :param pts_des: (3, D) 131 | :param num_split: 132 | :return: 133 | """ 134 | pts_src_list = torch.split(pts_src, 500000, dim=1) 135 | idx_list = [] 136 | for pts_src_sec in pts_src_list: 137 | diff = pts_src_sec[:, :, np.newaxis] - pts_des[:, np.newaxis, :] # (3, S, 1) - (3, 1, D) -> (3, S, D) 138 | dist = torch.linalg.norm(diff, dim=0) # (S, D) 139 | closest_idx = torch.argmin(dist, dim=1) # (S,) 140 | idx_list.append(closest_idx) 141 | closest_idx = torch.cat(idx_list) 142 | return closest_idx 143 | def comp_point_point_error(self, Xt, Yt): 144 | closest_idx = self.comp_closest_pts_idx_with_split(Xt, Yt) 145 | pt_pt_vec = Xt - Yt[:, closest_idx] # (3, S) - (3, S) -> (3, S) 146 | pt_pt_dist = torch.linalg.norm(pt_pt_vec, dim=0) 147 | eng = torch.mean(pt_pt_dist) 148 | return eng 149 | 150 | def get_rgb_s_loss(self, rgb1, rgb2, valid_points): 151 | diff_img = (rgb1 - rgb2).abs() 152 | diff_img = diff_img.clamp(0, 1) 153 | if self.cfg['with_ssim'] == True: 154 | ssim_map = compute_ssim_loss(rgb1, rgb2) 155 | diff_img = (0.15 * diff_img + 0.85 * ssim_map) 156 | loss = self.mean_on_mask(diff_img, valid_points) 157 | return loss 158 | def forward(self, rgb_pred, rgb_gt, depth_pred=None, depth_gt=None, 159 | t_list=None, X=None, Y=None, rgb_pc1=None, 160 | rgb_pc1_proj=None, valid_points=None, 161 | d1_proj=None, d2=None, d2_proj=None, d1=None, weights={}, rgb_loss_type='l2', **kwargs): 162 | rgb_gt = rgb_gt.cuda() 163 | 164 | if weights['rgb_weight'] != 0.0: 165 | rgb_full_loss = self.get_rgb_full_loss(rgb_pred, rgb_gt, rgb_loss_type) 166 | else: 167 | 
rgb_full_loss = torch.tensor(0.0).cuda().float() 168 | if weights['depth_weight'] != 0.0: 169 | depth_loss = self.get_depth_loss(depth_pred, depth_gt) 170 | else: 171 | depth_loss = torch.tensor(0.0).cuda().float() 172 | 173 | if weights['weight_dist_2nd_loss'] !=0.0 or weights['weight_dist_1st_loss'] !=0.0: 174 | loss_dist_1st, loss_dist_2nd = self.get_weight_dist_loss(t_list) 175 | else: 176 | loss_dist_1st, loss_dist_2nd = torch.tensor(0.0).cuda().float(), torch.tensor(0.0).cuda().float() 177 | if weights['pc_weight']!=0.0: 178 | pc_loss = self.get_pc_loss(X, Y) 179 | else: 180 | pc_loss = torch.tensor(0.0).cuda().float() 181 | if weights['rgb_s_weight']!=0.0: 182 | rgb_s_loss = self.get_rgb_s_loss(rgb_pc1, rgb_pc1_proj, valid_points) 183 | else: 184 | rgb_s_loss = torch.tensor(0.0).cuda().float() 185 | if weights['depth_consistency_weight'] != 0.0: 186 | depth_consistency_loss = self.get_depth_consistency_loss(d1_proj, d2, d2_proj, d1) 187 | else: 188 | depth_consistency_loss = torch.tensor(0.0).cuda().float() 189 | 190 | 191 | if (weights['rgb_weight']!=0.0) or (weights['depth_weight'] !=0.0): 192 | rgb_l2_mean = F.mse_loss(rgb_pred, rgb_gt) 193 | else: 194 | rgb_l2_mean = torch.tensor(0.0).cuda().float() 195 | 196 | loss = weights['rgb_weight'] * rgb_full_loss + \ 197 | weights['depth_weight'] * depth_loss +\ 198 | weights['weight_dist_1st_loss'] * loss_dist_1st+\ 199 | weights['weight_dist_2nd_loss'] * loss_dist_2nd+\ 200 | weights['pc_weight'] * pc_loss+\ 201 | weights['rgb_s_weight'] * rgb_s_loss+\ 202 | weights['depth_consistency_weight'] * depth_consistency_loss 203 | 204 | if torch.isnan(loss): 205 | breakpoint() 206 | return_dict = { 207 | 'loss': loss, 208 | 'loss_rgb': rgb_full_loss, 209 | 'loss_depth': depth_loss, 210 | 'l2_mean': rgb_l2_mean, 211 | 'loss_dist_1st':loss_dist_1st, 212 | 'loss_dist_2nd': loss_dist_2nd, 213 | 'loss_pc': pc_loss, 214 | 'loss_rgb_s': rgb_s_loss, 215 | 'loss_depth_consistency': depth_consistency_loss 216 | } 217 | 218 | return return_dict 219 | 220 | 221 | 222 | class SSIM(nn.Module): 223 | """Layer to compute the SSIM loss between a pair of images 224 | """ 225 | 226 | def __init__(self): 227 | super(SSIM, self).__init__() 228 | self.mu_x_pool = nn.AvgPool2d(3, 1) 229 | self.mu_y_pool = nn.AvgPool2d(3, 1) 230 | self.sig_x_pool = nn.AvgPool2d(3, 1) 231 | self.sig_y_pool = nn.AvgPool2d(3, 1) 232 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 233 | 234 | self.refl = nn.ReflectionPad2d(1) 235 | 236 | self.C1 = 0.01 ** 2 237 | self.C2 = 0.03 ** 2 238 | 239 | def forward(self, x, y): 240 | x = self.refl(x) 241 | y = self.refl(y) 242 | 243 | mu_x = self.mu_x_pool(x) 244 | mu_y = self.mu_y_pool(y) 245 | 246 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 247 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 248 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 249 | 250 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 251 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 252 | 253 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 254 | compute_ssim_loss = SSIM().to('cuda') -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import cv2 6 | 7 | class nope_nerf(nn.Module): 8 | def __init__(self, cfg, renderer, depth_estimator=None, device=None, **kwargs): 9 | super().__init__() 10 | 11 
| self.renderer = renderer.to(device) 12 | 13 | if depth_estimator is not None: 14 | self.depth_estimator = depth_estimator.to(device) 15 | else: 16 | self.depth_estimator = None 17 | 18 | self.device = device 19 | def forward(self, p, ray_idx, camera_mat, world_mat, scale_mat, rendering_technique, it=0, eval_mode=False, depth_img=None, 20 | add_noise=True, img_size=None): 21 | if rendering_technique=='nope_nerf': 22 | depth_img_resized = F.interpolate(depth_img, img_size ,mode='nearest') 23 | depth_img_resized = depth_img_resized.view(1, 1, -1).permute(0, 2, 1) 24 | depth = depth_img_resized[:,ray_idx] 25 | else: 26 | depth = None 27 | out_dict = self.renderer( 28 | p, depth, camera_mat, world_mat, scale_mat, 29 | rendering_technique, eval_=eval_mode, it=it, add_noise=add_noise 30 | ) 31 | 32 | 33 | return out_dict 34 | 35 | -------------------------------------------------------------------------------- /model/official_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | class OfficialStaticNerf(nn.Module): 9 | def __init__(self, cfg): 10 | super(OfficialStaticNerf, self).__init__() 11 | D = cfg['model']['hidden_dim'] 12 | pos_enc_levels = cfg['model']['pos_enc_levels'] 13 | dir_enc_levels = cfg['model']['dir_enc_levels'] 14 | pos_in_dims = (2 * pos_enc_levels + 1) * 3 # (2L + 0 or 1) * 3 15 | dir_in_dims = (2 * dir_enc_levels + 1) * 3 # (2L + 0 or 1) * 3 16 | self.white_bkgd = cfg['rendering']['white_background'] 17 | self.dist_alpha = cfg['rendering']['dist_alpha'] 18 | self.occ_activation = cfg['model']['occ_activation'] 19 | 20 | self.layers0 = nn.Sequential( 21 | nn.Linear(pos_in_dims, D), nn.ReLU(), 22 | nn.Linear(D, D), nn.ReLU(), 23 | nn.Linear(D, D), nn.ReLU(), 24 | nn.Linear(D, D), nn.ReLU(), 25 | ) 26 | 27 | self.layers1 = nn.Sequential( 28 | nn.Linear(D + pos_in_dims, D), nn.ReLU(), # short cut 29 | nn.Linear(D, D), nn.ReLU(), 30 | nn.Linear(D, D), nn.ReLU(), 31 | nn.Linear(D, D), nn.ReLU(), 32 | ) 33 | 34 | self.fc_density = nn.Linear(D, 1) 35 | self.fc_feature = nn.Linear(D, D) 36 | self.rgb_layers = nn.Sequential(nn.Linear(D + dir_in_dims, D//2), nn.ReLU()) 37 | self.fc_rgb = nn.Linear(D//2, 3) 38 | 39 | self.fc_density.bias.data = torch.tensor([0.1]).float() 40 | self.sigmoid = nn.Sigmoid() 41 | if self.white_bkgd: 42 | self.fc_rgb.bias.data = torch.tensor([0.8, 0.8, 0.8]).float() 43 | else: 44 | self.fc_rgb.bias.data = torch.tensor([0.02, 0.02, 0.02]).float() 45 | 46 | def gradient(self, p, it): 47 | with torch.enable_grad(): 48 | p.requires_grad_(True) 49 | _, y = self.infer_occ(p) 50 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 51 | gradients = torch.autograd.grad( 52 | outputs=y, 53 | inputs=p, 54 | grad_outputs=d_output, 55 | create_graph=True, 56 | retain_graph=True, 57 | only_inputs=True, allow_unused=True)[0] 58 | return -gradients.unsqueeze(1) 59 | 60 | def infer_occ(self, p): 61 | pos_enc = encode_position(p, levels=10, inc_input=True) 62 | x = self.layers0(pos_enc) # (H, W, N_sample, D) 63 | x = torch.cat([x, pos_enc], dim=-1) # (H, W, N_sample, D+pos_in_dims) 64 | x = self.layers1(x) # (H, W, N_sample, D) 65 | 66 | density = self.fc_density(x) # (H, W, N_sample, 1) 67 | return x, density 68 | 69 | def forward(self, p, ray_d=None, only_occupancy=False, return_logits=False,return_addocc=False, 70 | noise=False, it=100000, **kwargs): 71 | 
""" 72 | :param pos_enc: (H, W, N_sample, pos_in_dims) 73 | :param dir_enc: (H, W, N_sample, dir_in_dims) 74 | :return: rgb_density (H, W, N_sample, 4) 75 | """ 76 | x, density = self.infer_occ(p) 77 | if self.occ_activation=='softplus': 78 | density = F.softplus(density) 79 | else: 80 | density = density.relu() 81 | 82 | if not self.dist_alpha: 83 | density = 1 - torch.exp(-1.0 * density) 84 | if only_occupancy: 85 | return density 86 | elif ray_d is not None: 87 | dir_enc = encode_position(ray_d, levels=4, inc_input=True) 88 | feat = self.fc_feature(x) # (H, W, N_sample, D) 89 | x = torch.cat([feat, dir_enc], dim=-1) # (H, W, N_sample, D+dir_in_dims) 90 | x = self.rgb_layers(x) # (H, W, N_sample, D/2) 91 | rgb = self.fc_rgb(x) # (H, W, N_sample, 3) 92 | rgb = self.sigmoid(rgb) 93 | if return_addocc: 94 | return rgb, density 95 | else: 96 | return rgb 97 | 98 | 99 | def encode_position(input, levels, inc_input): 100 | """ 101 | For each scalar, we encode it using a series of sin() and cos() functions with different frequency. 102 | - With L pairs of sin/cos function, each scalar is encoded to a vector that has 2L elements. Concatenating with 103 | itself results in 2L+1 elements. 104 | - With C channels, we get C(2L+1) channels output. 105 | 106 | :param input: (..., C) torch.float32 107 | :param levels: scalar L int 108 | :return: (..., C*(2L+1)) torch.float32 109 | """ 110 | 111 | # this is already doing 'log_sampling' in the official code. 112 | result_list = [input] if inc_input else [] 113 | for i in range(levels): 114 | temp = 2.0**i * input # (..., C) 115 | result_list.append(torch.sin(temp)) # (..., C) 116 | result_list.append(torch.cos(temp)) # (..., C) 117 | 118 | result_list = torch.cat(result_list, dim=-1) # (..., C*(2L+1)) The list has (2L+1) elements, with (..., C) shape each. 
119 | return result_list # (..., C*(2L+1)) 120 | 121 | 122 | -------------------------------------------------------------------------------- /model/poses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.common import make_c2w, convert3x4_4x4 4 | 5 | 6 | class LearnPose(nn.Module): 7 | def __init__(self, num_cams, learn_R, learn_t, cfg, init_c2w=None): 8 | """ 9 | :param num_cams: 10 | :param learn_R: True/False 11 | :param learn_t: True/False 12 | :param cfg: config argument options 13 | :param init_c2w: (N, 4, 4) torch tensor 14 | """ 15 | super(LearnPose, self).__init__() 16 | self.num_cams = num_cams 17 | self.init_c2w = None 18 | if init_c2w is not None: 19 | self.init_c2w = nn.Parameter(init_c2w, requires_grad=False) 20 | self.r = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=learn_R) # (N, 3) 21 | self.t = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=learn_t) # (N, 3) 22 | 23 | def forward(self, cam_id): 24 | cam_id = int(cam_id) 25 | r = self.r[cam_id] # (3, ) axis-angle 26 | t = self.t[cam_id] # (3, ) 27 | c2w = make_c2w(r, t) # (4, 4) 28 | # learn a delta pose between init pose and target pose, if a init pose is provided 29 | if self.init_c2w is not None: 30 | c2w = c2w @ self.init_c2w[cam_id] 31 | return c2w 32 | def get_t(self): 33 | return self.t 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /model/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from model.losses import Loss 5 | import numpy as np 6 | from PIL import Image 7 | import imageio 8 | from torch.nn import functional as F 9 | from model.common import ( 10 | get_tensor_values, 11 | arange_pixels, project_to_cam, transform_to_world, 12 | ) 13 | logger_py = logging.getLogger(__name__) 14 | class Trainer(object): 15 | def __init__(self, model, optimizer, cfg, device=None, optimizer_pose=None, pose_param_net=None, 16 | optimizer_focal=None, focal_net=None, optimizer_distortion=None,distortion_net=None, **kwargs): 17 | """model trainer 18 | 19 | Args: 20 | model (nn.Module): model 21 | optimizer (optimizer):pytorch optimizer object 22 | cfg (dict): config argument options 23 | device (device): Pytorch device option. Defaults to None. 24 | optimizer_pose (optimizer, optional): pytorch optimizer for poses. Defaults to None. 25 | pose_param_net (nn.Module, optional): model with pose parameters. Defaults to None. 26 | optimizer_focal (optimizer, optional): pytorch optimizer for focal. Defaults to None. 27 | focal_net (nn.Module, optional): model with focal parameters. Defaults to None. 28 | optimizer_distortion (optimizer, optional): pytorch optimizer for depth distortion. Defaults to None. 29 | distortion_net (nn.Module, optional): model with distortion parameters. Defaults to None. 
30 | """ 31 | self.model = model 32 | self.optimizer = optimizer 33 | self.device = device 34 | self.optimizer_pose = optimizer_pose 35 | self.pose_param_net = pose_param_net 36 | self.focal_net = focal_net 37 | self.optimizer_focal = optimizer_focal 38 | self.distortion_net = distortion_net 39 | self.optimizer_distortion = optimizer_distortion 40 | 41 | self.n_training_points = cfg['n_training_points'] 42 | self.rendering_technique = cfg['type'] 43 | self.vis_geo = cfg['vis_geo'] 44 | 45 | self.detach_gt_depth = cfg['detach_gt_depth'] 46 | self.pc_ratio = cfg['pc_ratio'] 47 | self.match_method = cfg['match_method'] 48 | self.shift_first = cfg['shift_first'] 49 | self.detach_ref_img = cfg['detach_ref_img'] 50 | self.scale_pcs = cfg['scale_pcs'] 51 | self.detach_rgbs_scale = cfg['detach_rgbs_scale'] 52 | self.vis_reprojection_every = cfg['vis_reprojection_every'] 53 | self.nearest_limit = cfg['nearest_limit'] 54 | self.annealing_epochs = cfg['annealing_epochs'] 55 | 56 | self.pc_weight = cfg['pc_weight'] 57 | self.rgb_s_weight = cfg['rgb_s_weight'] 58 | self.rgb_weight = cfg['rgb_weight'] 59 | self.depth_weight = cfg['depth_weight'] 60 | self.weight_dist_2nd_loss = cfg['weight_dist_2nd_loss'] 61 | self.weight_dist_1st_loss = cfg['weight_dist_1st_loss'] 62 | self.depth_consistency_weight = cfg['depth_consistency_weight'] 63 | 64 | 65 | self.loss = Loss(cfg) 66 | 67 | def train_step(self, data, it=None, epoch=None,scheduling_start=None, render_path=None): 68 | ''' Performs a training step. 69 | 70 | Args: 71 | data (dict): data dictionary 72 | it (int): training iteration 73 | epoch(int): current number of epochs 74 | scheduling_start(int): num of epochs to start scheduling 75 | ''' 76 | self.model.train() 77 | self.optimizer.zero_grad() 78 | if self.pose_param_net: 79 | self.pose_param_net.train() 80 | self.optimizer_pose.zero_grad() 81 | if self.focal_net: 82 | self.focal_net.train() 83 | self.optimizer_focal.zero_grad() 84 | if self.distortion_net: 85 | self.distortion_net.train() 86 | self.optimizer_distortion.zero_grad() 87 | loss_dict = self.compute_loss(data, it=it, epoch=epoch, scheduling_start=scheduling_start, out_render_path=render_path) 88 | loss = loss_dict['loss'] 89 | loss.backward() 90 | self.optimizer.step() 91 | if self.optimizer_pose: 92 | self.optimizer_pose.step() 93 | if self.optimizer_focal: 94 | self.optimizer_focal.step() 95 | if self.optimizer_distortion: 96 | self.optimizer_distortion.step() 97 | return loss_dict 98 | 99 | 100 | def render_visdata(self, data, resolution, it, out_render_path): 101 | (img, dpt, camera_mat, scale_mat, img_idx) = self.process_data_dict(data) 102 | h, w = resolution 103 | if self.pose_param_net: 104 | c2w = self.pose_param_net(img_idx) 105 | world_mat = torch.inverse(c2w).unsqueeze(0) 106 | if self.optimizer_focal: 107 | fxfy = self.focal_net(0) 108 | camera_mat = torch.tensor([[[fxfy[0], 0, 0, 0], 109 | [0, -fxfy[1], 0, 0], 110 | [0, 0, -1, 0], 111 | [0, 0, 0, 1]]]).to(self.device) 112 | p_idx = torch.arange(h*w).to(self.device) 113 | p_loc, pixels = arange_pixels(resolution=(h, w)) 114 | 115 | pixels = pixels.to(self.device) 116 | depth_input = dpt 117 | 118 | with torch.no_grad(): 119 | rgb_pred = [] 120 | depth_pred = [] 121 | for i, (pixels_i, p_idx_i) in enumerate(zip(torch.split(pixels, 1024, dim=1), torch.split(p_idx, 1024, dim=0))): 122 | out_dict = self.model( 123 | pixels_i, p_idx_i, camera_mat, world_mat, scale_mat, self.rendering_technique, 124 | add_noise=False, eval_mode=True, it=it, depth_img=depth_input, img_size=(h, w)) 
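                # The image is rendered in chunks of 1024 pixels to bound memory; the
                # per-chunk rgb / depth predictions are concatenated again further below.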
125 | rgb_pred_i = out_dict['rgb'] 126 | rgb_pred.append(rgb_pred_i) 127 | depth_pred_i = out_dict['depth_pred'] 128 | depth_pred.append(depth_pred_i) 129 | 130 | rgb_pred = torch.cat(rgb_pred, dim=1) 131 | depth_pred = torch.cat(depth_pred, dim=0) 132 | 133 | rgb_pred = rgb_pred.view(h, w, 3).detach().cpu().numpy() 134 | img_out = (rgb_pred * 255).astype(np.uint8) 135 | depth_pred_out = depth_pred.view(h, w).detach().cpu().numpy() 136 | imageio.imwrite(os.path.join(out_render_path,'%04d_depth.png'% img_idx), 137 | np.clip(255.0 / depth_pred_out.max() * (depth_pred_out - depth_pred_out.min()), 0, 255).astype(np.uint8)) 138 | 139 | img1 = Image.fromarray( 140 | (img_out).astype(np.uint8) 141 | ).convert("RGB").save( 142 | os.path.join(out_render_path, '%04d_img.png' % img_idx) 143 | ) 144 | if self.vis_geo: 145 | with torch.no_grad(): 146 | rgb_pred = \ 147 | [self.model( 148 | pixels_i, None, camera_mat, world_mat, scale_mat, 'phong_renderer', 149 | add_noise=False, eval_mode=True, it=it, depth_img=depth_input, img_size=(h, w))['rgb'] 150 | for i, pixels_i in enumerate(torch.split(pixels, 1024, dim=1))] 151 | 152 | rgb_pred = torch.cat(rgb_pred, dim=1).cpu() 153 | rgb_pred = rgb_pred.view(h, w, 3).detach().cpu().numpy() 154 | img_out = (rgb_pred * 255).astype(np.uint8) 155 | 156 | 157 | img1 = Image.fromarray( 158 | (img_out).astype(np.uint8) 159 | ).convert("RGB").save( 160 | os.path.join(out_render_path, '%04d_geo.png' % img_idx) 161 | ) 162 | 163 | return img_out.astype(np.uint8) 164 | def process_data_dict(self, data): 165 | ''' Processes the data dictionary and returns respective tensors 166 | Args: 167 | data (dictionary): data dictionary 168 | ''' 169 | device = self.device 170 | img = data.get('img').to(device) 171 | img_idx = data.get('img.idx') 172 | dpt = data.get('img.dpt').to(device).unsqueeze(1) 173 | camera_mat = data.get('img.camera_mat').to(device) 174 | scale_mat = data.get('img.scale_mat').to(device) 175 | 176 | return (img, dpt, camera_mat, scale_mat, img_idx) 177 | def process_data_reference(self, data): 178 | ''' Processes the data dictionary and returns respective tensors 179 | Args: 180 | data (dictionary): data dictionary 181 | ''' 182 | device = self.device 183 | ref_imgs = data.get('img.ref_imgs').to(device) 184 | ref_dpts = data.get('img.ref_dpts').to(device).unsqueeze(1) 185 | ref_idxs = data.get('img.ref_idxs') 186 | return ( ref_imgs, ref_dpts, ref_idxs) 187 | def anneal(self, start_weight, end_weight, anneal_start_epoch, anneal_epoches, current): 188 | """Anneal the weight from start_weight to end_weight 189 | """ 190 | if current <= anneal_start_epoch: 191 | return start_weight 192 | elif current >= anneal_start_epoch + anneal_epoches: 193 | return end_weight 194 | else: 195 | return start_weight + (end_weight - start_weight) * (current - anneal_start_epoch) / anneal_epoches 196 | 197 | def compute_loss(self, data, eval_mode=False, it=None, epoch=None, scheduling_start=None, out_render_path=None): 198 | ''' Compute the loss. 
199 | 200 | Args: 201 | data (dict): data dictionary 202 | eval_mode (bool): whether to use eval mode 203 | it (int): training iteration 204 | epoch(int): current number of epochs 205 | scheduling_start(int): num of epochs to start scheduling 206 | out_render_path(str): path to save rendered images 207 | ''' 208 | weights = {} 209 | weights_name_list = ['rgb_weight', 'depth_weight', 'pc_weight', 'rgb_s_weight', 'depth_consistency_weight', 'weight_dist_2nd_loss', 'weight_dist_1st_loss'] 210 | weights_list = [self.anneal(getattr(self, w)[0], getattr(self, w)[1], scheduling_start, self.annealing_epochs, epoch) for w in weights_name_list] # loss weights 211 | rgb_loss_type = 'l1' if epoch < self.annealing_epochs + scheduling_start else 'l2' 212 | 213 | for i, weight in enumerate(weights_list): 214 | weight_name = weights_name_list[i] 215 | weights[weight_name] = weight 216 | render_model=(weights['rgb_weight']!=0.0) or (weights['depth_weight']!=0.0) 217 | use_ref_imgs = ((weights['pc_weight']!=0.0) or (weights['rgb_s_weight']!=0.0)) 218 | 219 | n_points = self.n_training_points 220 | nl = self.nearest_limit 221 | (img, depth_input, camera_mat_gt, scale_mat, img_idx) = self.process_data_dict(data) 222 | if use_ref_imgs: 223 | (ref_img, depth_ref, ref_idx) = self.process_data_reference(data) 224 | 225 | device = self.device 226 | batch_size, _, h, w = img.shape 227 | batch_size, _, h_depth, w_depth = depth_input.shape 228 | kwargs = dict() 229 | kwargs['t_list']=self.pose_param_net.get_t() 230 | kwargs['weights'] = weights 231 | kwargs['rgb_loss_type'] = rgb_loss_type 232 | 233 | 234 | 235 | if self.pose_param_net is not None: 236 | num_cams = self.pose_param_net.num_cams 237 | c2w = self.pose_param_net(img_idx) 238 | world_mat = torch.inverse(c2w).unsqueeze(0) 239 | 240 | if self.distortion_net is not None: 241 | scale_input,shift_input = self.distortion_net(img_idx) 242 | if self.shift_first: 243 | depth_input = (depth_input + shift_input) * scale_input 244 | else: 245 | depth_input = depth_input * scale_input + shift_input 246 | 247 | if self.optimizer_focal: 248 | fxfy = self.focal_net(0) 249 | pad = torch.zeros(4).to(device) 250 | one = torch.tensor([1]).to(device) 251 | camera_mat = torch.cat([fxfy[0:1], pad, -fxfy[1:2], pad, -one, pad, one]) 252 | camera_mat = camera_mat.view(1, 4, 4) 253 | else: 254 | camera_mat = camera_mat_gt 255 | 256 | # Sample pixels 257 | ray_idx = torch.randperm(h*w,device=device)[:n_points] 258 | img_flat = img.view(batch_size, 3, h*w).permute(0,2,1) 259 | rgb_gt = img_flat[:,ray_idx] 260 | p_full = arange_pixels((h, w), batch_size, device=device)[1] 261 | p = p_full[:, ray_idx] 262 | pix = ray_idx 263 | 264 | 265 | 266 | if render_model: 267 | out_dict = self.model( 268 | p, pix, camera_mat, world_mat, scale_mat, 269 | self.rendering_technique, it=it, 270 | eval_mode=eval_mode, depth_img=depth_input, 271 | img_size=(h, w)) 272 | rendered_rgb = out_dict['rgb'] 273 | rendered_depth = out_dict['depth_pred'] 274 | gt_depth = out_dict['depth_gt'] 275 | else: 276 | rendered_rgb = None 277 | rendered_depth = None 278 | gt_depth = None 279 | 280 | if use_ref_imgs: 281 | c2w_ref = self.pose_param_net(ref_idx) 282 | if self.distortion_net is not None: 283 | scale_ref, shift_ref = self.distortion_net(ref_idx) 284 | if self.shift_first: 285 | depth_ref = scale_ref * (depth_ref + shift_ref) 286 | else: 287 | depth_ref = scale_ref * depth_ref + shift_ref 288 | if self.detach_ref_img: 289 | c2w_ref = c2w_ref.detach() 290 | scale_ref = scale_ref.detach() 291 | shift_ref = 
shift_ref.detach() 292 | depth_ref = depth_ref.detach() 293 | ref_Rt = torch.inverse(c2w_ref).unsqueeze(0) 294 | 295 | 296 | if img_idx < (num_cams-1): 297 | d1 = depth_input 298 | d2 = depth_ref 299 | img1 = img 300 | img2 = ref_img 301 | Rt_rel_12 = ref_Rt @ torch.inverse(world_mat) 302 | R_rel_12 = Rt_rel_12[:, :3, :3] 303 | t_rel_12 = Rt_rel_12[:, :3, 3] 304 | scale2 = scale_ref 305 | else: 306 | d1 = depth_ref 307 | d2 = depth_input 308 | img1 = ref_img 309 | img2 = img 310 | Rt_rel_12 = world_mat @ torch.inverse(ref_Rt) 311 | R_rel_12 = Rt_rel_12[:, :3, :3] 312 | t_rel_12 = Rt_rel_12[:, :3, 3] 313 | scale2 = scale_input 314 | 315 | ratio = self.pc_ratio 316 | sample_resolution = (int(h_depth/ratio), int(w_depth/ratio)) 317 | pixel_locations, p_pc = arange_pixels(resolution=sample_resolution, device=device) 318 | d1 = F.interpolate(d1, sample_resolution ,mode='nearest') 319 | d2 = F.interpolate(d2, sample_resolution ,mode='nearest') 320 | d1[d10: 200 | gt_poses = train_dataset['img'].c2ws.to(device) 201 | # for epoch_it in tqdm(range(epoch_start+1, exit_after), desc='epochs'): 202 | while epoch_it < (scheduling_start + scheduling_epoch): 203 | epoch_it +=1 204 | L2_loss_epoch = [] 205 | pc_loss_epoch = [] 206 | rgb_s_loss_epoch = [] 207 | for batch in train_loader: 208 | it += 1 209 | idx = batch.get('img.idx') 210 | loss_dict = trainer.train_step(batch, it, epoch_it, scheduling_start, render_path) 211 | loss = loss_dict['loss'] 212 | L2_loss_epoch.append(loss_dict['l2_mean'].item()) 213 | pc_loss_epoch.append(loss_dict['loss_pc'].item()) 214 | rgb_s_loss_epoch.append(loss_dict['loss_rgb_s'].item()) 215 | scale_dict['view %02d' % (idx)] = loss_dict['scale'] 216 | shift_dict['view %02d' % (idx)] = loss_dict['shift'] 217 | if print_every > 0 and (it % print_every) == 0: 218 | tqdm.write('[Epoch %02d] it=%03d, loss=%.8f, time=%.4f' 219 | % (epoch_it, it, loss, time.time() - t0b)) 220 | logger_py.info('[Epoch %02d] it=%03d, loss=%.4f, time=%.4f' 221 | % (epoch_it, it, loss, time.time() - t0b)) 222 | t0b = time.time() 223 | for l, num in loss_dict.items(): 224 | logger.add_scalar('train/'+l, num.detach().cpu(), it) 225 | if log_scale_shift_per_view: 226 | for l, num in scale_dict.items(): 227 | logger.add_scalar('train/scale'+l, num, it) 228 | for l, num in shift_dict.items(): 229 | logger.add_scalar('train/shift'+l, num, it) 230 | 231 | if visualize_every > 0 and (it % visualize_every)==0: 232 | logger_py.info("Rendering") 233 | out_render_path = os.path.join(render_path, '%04d_vis' % it) 234 | if not os.path.exists(out_render_path): 235 | os.makedirs(out_render_path) 236 | val_rgb = trainer.render_visdata( 237 | data_test, 238 | cfg['training']['vis_resolution'], 239 | it, out_render_path) 240 | #logger.add_image('rgb', val_rgb, it) 241 | # Run validation 242 | if validate_every > 0 and (it % validate_every) == 0: 243 | eval_dict = trainer.evaluate(test_loader) 244 | 245 | for k, v in eval_dict.items(): 246 | logger.add_scalar('val/%s' % k, v, it) 247 | 248 | # Save checkpoint 249 | if (checkpoint_every > 0 and (it % checkpoint_every) == 0): 250 | logger_py.info('Saving checkpoint') 251 | print('Saving checkpoint') 252 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 253 | loss_val_best=metric_val_best, scheduling_start=scheduling_start, patient_count=patient_count) 254 | if cfg['pose']['learn_pose']: 255 | checkpoint_io_pose.save('model_pose.pt', epoch_it=epoch_it, it=it) 256 | if cfg['pose']['learn_focal']: 257 | checkpoint_io_focal.save('model_focal.pt', epoch_it=epoch_it, 
it=it) 258 | if cfg['distortion']['learn_distortion']: 259 | checkpoint_io_distortion.save('model_distortion.pt', epoch_it=epoch_it, it=it) 260 | 261 | # Backup if necessary 262 | if (backup_every > 0 and (it % backup_every) == 0): 263 | logger_py.info('Backup checkpoint') 264 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it, 265 | loss_val_best=metric_val_best, scheduling_start=scheduling_start, patient_count=patient_count) 266 | if cfg['pose']['learn_pose']: 267 | checkpoint_io_pose.save('model_pose_%d.pt' % it, epoch_it=epoch_it, it=it) 268 | if cfg['pose']['learn_focal']: 269 | checkpoint_io_focal.save('model_focal_%d.pt' % it, epoch_it=epoch_it, it=it) 270 | if cfg['distortion']['learn_distortion']: 271 | checkpoint_io_distortion.save('model_distortion_%d.pt' % it, epoch_it=epoch_it, it=it) 272 | 273 | pc_loss_epoch = np.mean(pc_loss_epoch) 274 | logger.add_scalar('train/loss_pc_epoch', pc_loss_epoch, it) 275 | rgb_s_loss_epoch = np.mean(rgb_s_loss_epoch) 276 | logger.add_scalar('train/loss_rgbs_epoch', rgb_s_loss_epoch, it) 277 | if (eval_pose_every>0 and (epoch_it % eval_pose_every) == 0): 278 | with torch.no_grad(): 279 | learned_poses = torch.stack([pose_param_net(i) for i in range(n_views)]) 280 | c2ws_est_aligned = align_ate_c2b_use_a2b(learned_poses, gt_poses) 281 | ate = compute_ATE(gt_poses.cpu().numpy(), c2ws_est_aligned.cpu().numpy()) 282 | rpe_trans, rpe_rot = compute_rpe(gt_poses.cpu().numpy(), c2ws_est_aligned.cpu().numpy()) 283 | tqdm.write('{0:6d} ep: Train: ATE: {1:.3f} RPE_r: {2:.3f}'.format(epoch_it, ate, rpe_rot* 180 / np.pi)) 284 | eval_dict = { 285 | 'ate_trans': ate, 286 | 'rpe_trans': rpe_trans*100, 287 | 'rpe_rot': rpe_rot* 180 / np.pi 288 | } 289 | for l, num in eval_dict.items(): 290 | logger.add_scalar('eval/'+l, num, it) 291 | if (eval_img_every>0 and (epoch_it % eval_img_every) == 0): 292 | L2_loss_mean = np.mean(L2_loss_epoch) 293 | psnr = mse2psnr(L2_loss_mean) 294 | tqdm.write('{0:6d} ep: Train: PSNR: {1:.3f}'.format(epoch_it, psnr)) 295 | logger.add_scalar('train/psnr', psnr, it) 296 | 297 | if not auto_scheduler: 298 | scheduler.step() 299 | new_lr = scheduler.get_lr()[0] 300 | if cfg['pose']['learn_pose']: 301 | scheduler_pose.step() 302 | new_lr_pose = scheduler_pose.get_lr()[0] 303 | if cfg['pose']['learn_focal']: 304 | scheduler_focal.step() 305 | new_lr_focal = scheduler_focal.get_lr()[0] 306 | if cfg['distortion']['learn_distortion']: 307 | scheduler_distortion.step() 308 | new_lr_distortion = scheduler_distortion.get_lr()[0] 309 | else: 310 | psnr_window.append(psnr) 311 | if len(psnr_window) >= length_smooth: 312 | psnr_window = psnr_window[-length_smooth:] 313 | metric_val = np.array(psnr_window).mean() 314 | if (metric_val - metric_val_best) >= 0: 315 | metric_val_best = metric_val 316 | else: 317 | patient_count = patient_count + 1 318 | if (patient_count == patient): 319 | scheduling_start = epoch_it 320 | if epoch_it < scheduling_start: 321 | new_lr = cfg['training']['learning_rate'] 322 | new_lr_pose = cfg['training']['pose_lr'] 323 | new_lr_focal = cfg['training']['focal_lr'] 324 | new_lr_distortion = cfg['training']['distortion_lr'] 325 | else: 326 | new_lr = cfg['training']['learning_rate'] * ((cfg['training']['scheduler_gamma'])**int((epoch_it-scheduling_start)/10)) 327 | for param_group in optimizer.param_groups: 328 | param_group['lr'] = new_lr 329 | if cfg['pose']['learn_pose']: 330 | new_lr_pose = cfg['training']['pose_lr'] * ((cfg['training']['scheduler_gamma_pose'])**int((epoch_it-scheduling_start)/100)) 331 | 
for param_group in optimizer_pose.param_groups: 332 | param_group['lr'] = new_lr_pose 333 | if cfg['pose']['learn_focal']: 334 | new_lr_focal = cfg['training']['focal_lr'] * ((cfg['training']['scheduler_gamma_focal'])**int((epoch_it-scheduling_start)/100)) 335 | for param_group in optimizer_focal.param_groups: 336 | param_group['lr'] = new_lr_focal 337 | if cfg['distortion']['learn_distortion']: 338 | new_lr_distortion = cfg['training']['distortion_lr'] * ((cfg['training']['scheduler_gamma_distortion'])**int((epoch_it-scheduling_start)/100)) 339 | for param_group in optimizer_distortion.param_groups: 340 | param_group['lr'] = new_lr_distortion 341 | if scheduling_mode=='reset' and epoch_it == scheduling_start: 342 | for module in model.modules(): 343 | if isinstance(module, nn.Linear): 344 | module.reset_parameters() 345 | 346 | logger.add_scalar('train/lr', new_lr, it) 347 | if cfg['pose']['learn_pose']: 348 | logger.add_scalar('train/lr_pose', new_lr_pose, it) 349 | if cfg['pose']['learn_focal']: 350 | logger.add_scalar('train/lr_focal', new_lr_focal, it) 351 | if cfg['distortion']['learn_distortion']: 352 | logger.add_scalar('train/lr_distortion', new_lr_distortion, it) 353 | 354 | if __name__=='__main__': 355 | # Arguments 356 | parser = argparse.ArgumentParser( 357 | description='Training of nope-nerf model' 358 | ) 359 | parser.add_argument('config', type=str, help='Path to config file.') 360 | args = parser.parse_args() 361 | cfg = dl.load_config(args.config, 'configs/default.yaml') 362 | # backup model 363 | backup(cfg['training']['out_dir'], args.config) 364 | train(cfg=cfg) 365 | -------------------------------------------------------------------------------- /utils_poses/align_traj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ATE.align_utils import alignTrajectory 5 | from utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4 6 | 7 | 8 | def pts_dist_max(pts): 9 | """ 10 | :param pts: (N, 3) torch or np 11 | :return: scalar 12 | """ 13 | if torch.is_tensor(pts): 14 | dist = pts.unsqueeze(0) - pts.unsqueeze(1) # (1, N, 3) - (N, 1, 3) -> (N, N, 3) 15 | dist = dist[0] # (N, 3) 16 | dist = dist.norm(dim=1) # (N, ) 17 | max_dist = dist.max() 18 | else: 19 | dist = pts[None, :, :] - pts[:, None, :] # (1, N, 3) - (N, 1, 3) -> (N, N, 3) 20 | dist = dist[0] # (N, 3) 21 | dist = np.linalg.norm(dist, axis=1) # (N, ) 22 | max_dist = dist.max() 23 | return max_dist 24 | 25 | 26 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None): 27 | """Align c to b using the sim3 from a to b. 28 | :param traj_a: (N0, 3/4, 4) torch tensor 29 | :param traj_b: (N0, 3/4, 4) torch tensor 30 | :param traj_c: None or (N1, 3/4, 4) torch tensor 31 | :return: (N1, 4, 4) torch tensor 32 | """ 33 | device = traj_a.device 34 | if traj_c is None: 35 | traj_c = traj_a.clone() 36 | 37 | traj_a = traj_a.float().cpu().numpy() 38 | traj_b = traj_b.float().cpu().numpy() 39 | traj_c = traj_c.float().cpu().numpy() 40 | 41 | R_a = traj_a[:, :3, :3] # (N0, 3, 3) 42 | t_a = traj_a[:, :3, 3] # (N0, 3) 43 | quat_a = SO3_to_quat(R_a) # (N0, 4) 44 | 45 | R_b = traj_b[:, :3, :3] # (N0, 3, 3) 46 | t_b = traj_b[:, :3, 3] # (N0, 3) 47 | quat_b = SO3_to_quat(R_b) # (N0, 4) 48 | 49 | # This function works in quaternion. 50 | # scalar, (3, 3), (3, ) gt = R * s * est + t. 
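    # Illustrative note: a single sim(3) transform (s, R, t) is estimated from trajectory a
    # to trajectory b, then applied to every pose of trajectory c below as
    # R_c <- R @ R_c and t_c <- s * (R @ t_c) + t (rotations themselves are not scaled).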
51 | s, R, t = alignTrajectory(t_a, t_b, quat_a, quat_b, method='sim3') 52 | 53 | # reshape tensors 54 | R = R[None, :, :].astype(np.float32) # (1, 3, 3) 55 | t = t[None, :, None].astype(np.float32) # (1, 3, 1) 56 | s = float(s) 57 | 58 | R_c = traj_c[:, :3, :3] # (N1, 3, 3) 59 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1) 60 | 61 | R_c_aligned = R @ R_c # (N1, 3, 3) 62 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1) 63 | traj_c_aligned = np.concatenate([R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4) 64 | 65 | # append the last row 66 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4) 67 | 68 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device) 69 | return traj_c_aligned # (N1, 4, 4) 70 | 71 | 72 | 73 | def align_scale_c2b_use_a2b(traj_a, traj_b, traj_c=None): 74 | '''Scale c to b using the scale from a to b. 75 | :param traj_a: (N0, 3/4, 4) torch tensor 76 | :param traj_b: (N0, 3/4, 4) torch tensor 77 | :param traj_c: None or (N1, 3/4, 4) torch tensor 78 | :return: 79 | scaled_traj_c (N1, 4, 4) torch tensor 80 | scale scalar 81 | ''' 82 | if traj_c is None: 83 | traj_c = traj_a.clone() 84 | 85 | t_a = traj_a[:, :3, 3] # (N, 3) 86 | t_b = traj_b[:, :3, 3] # (N, 3) 87 | 88 | # scale estimated poses to colmap scale 89 | # s_a2b: a*s ~ b 90 | scale_a2b = pts_dist_max(t_b) / pts_dist_max(t_a) 91 | 92 | traj_c[:, :3, 3] *= scale_a2b 93 | 94 | if traj_c.shape[1] == 3: 95 | traj_c = convert3x4_4x4(traj_c) # (N, 4, 4) 96 | 97 | return traj_c, scale_a2b # (N, 4, 4) 98 | -------------------------------------------------------------------------------- /utils_poses/comp_ate.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import ATE.trajectory_utils as tu 5 | import ATE.transformations as tf 6 | def rotation_error(pose_error): 7 | """Compute rotation error 8 | Args: 9 | pose_error (4x4 array): relative pose error 10 | Returns: 11 | rot_error (float): rotation error 12 | """ 13 | a = pose_error[0, 0] 14 | b = pose_error[1, 1] 15 | c = pose_error[2, 2] 16 | d = 0.5*(a+b+c-1.0) 17 | rot_error = np.arccos(max(min(d, 1.0), -1.0)) 18 | return rot_error 19 | 20 | def translation_error(pose_error): 21 | """Compute translation error 22 | Args: 23 | pose_error (4x4 array): relative pose error 24 | Returns: 25 | trans_error (float): translation error 26 | """ 27 | dx = pose_error[0, 3] 28 | dy = pose_error[1, 3] 29 | dz = pose_error[2, 3] 30 | trans_error = np.sqrt(dx**2+dy**2+dz**2) 31 | return trans_error 32 | 33 | def compute_rpe(gt, pred): 34 | trans_errors = [] 35 | rot_errors = [] 36 | for i in range(len(gt)-1): 37 | gt1 = gt[i] 38 | gt2 = gt[i+1] 39 | gt_rel = np.linalg.inv(gt1) @ gt2 40 | 41 | pred1 = pred[i] 42 | pred2 = pred[i+1] 43 | pred_rel = np.linalg.inv(pred1) @ pred2 44 | rel_err = np.linalg.inv(gt_rel) @ pred_rel 45 | 46 | trans_errors.append(translation_error(rel_err)) 47 | rot_errors.append(rotation_error(rel_err)) 48 | rpe_trans = np.mean(np.asarray(trans_errors)) 49 | rpe_rot = np.mean(np.asarray(rot_errors)) 50 | return rpe_trans, rpe_rot 51 | 52 | def compute_ATE(gt, pred): 53 | """Compute RMSE of ATE 54 | Args: 55 | gt: ground-truth poses 56 | pred: predicted poses 57 | """ 58 | errors = [] 59 | 60 | for i in range(len(pred)): 61 | # cur_gt = np.linalg.inv(gt_0) @ gt[i] 62 | cur_gt = gt[i] 63 | gt_xyz = cur_gt[:3, 3] 64 | 65 | # cur_pred = np.linalg.inv(pred_0) @ pred[i] 66 | cur_pred = pred[i] 67 | pred_xyz = cur_pred[:3, 3] 68 | 69 | align_err = gt_xyz - pred_xyz 70 | 71 | 
errors.append(np.sqrt(np.sum(align_err ** 2))) 72 | ate = np.sqrt(np.mean(np.asarray(errors) ** 2)) 73 | return ate 74 | 75 | -------------------------------------------------------------------------------- /utils_poses/lie_group_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.spatial.transform import Rotation as RotLib 4 | 5 | 6 | def SO3_to_quat(R): 7 | """ 8 | :param R: (N, 3, 3) or (3, 3) np 9 | :return: (N, 4, ) or (4, ) np 10 | """ 11 | x = RotLib.from_matrix(R) 12 | quat = x.as_quat() 13 | return quat 14 | 15 | 16 | def quat_to_SO3(quat): 17 | """ 18 | :param quat: (N, 4, ) or (4, ) np 19 | :return: (N, 3, 3) or (3, 3) np 20 | """ 21 | x = RotLib.from_quat(quat) 22 | R = x.as_matrix() 23 | return R 24 | 25 | 26 | def convert3x4_4x4(input): 27 | """ 28 | :param input: (N, 3, 4) or (3, 4) torch or np 29 | :return: (N, 4, 4) or (4, 4) torch or np 30 | """ 31 | if torch.is_tensor(input): 32 | if len(input.shape) == 3: 33 | output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4) 34 | output[:, 3, 3] = 1.0 35 | else: 36 | output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 37 | else: 38 | if len(input.shape) == 3: 39 | output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 40 | output[:, 3, 3] = 1.0 41 | else: 42 | output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4) 43 | output[3, 3] = 1.0 44 | return output 45 | 46 | 47 | def vec2skew(v): 48 | """ 49 | :param v: (3, ) torch tensor 50 | :return: (3, 3) 51 | """ 52 | zero = torch.zeros(1, dtype=torch.float32, device=v.device) 53 | skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1) 54 | skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]]) 55 | skew_v2 = torch.cat([-v[1:2], v[0:1], zero]) 56 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3) 57 | return skew_v # (3, 3) 58 | 59 | 60 | def Exp(r): 61 | """so(3) vector to SO(3) matrix 62 | :param r: (3, ) axis-angle, torch tensor 63 | :return: (3, 3) 64 | """ 65 | skew_r = vec2skew(r) # (3, 3) 66 | norm_r = r.norm() + 1e-15 67 | eye = torch.eye(3, dtype=torch.float32, device=r.device) 68 | R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r) 69 | return R 70 | 71 | 72 | def make_c2w(r, t): 73 | """ 74 | :param r: (3, ) axis-angle torch tensor 75 | :param t: (3, ) translation vector torch tensor 76 | :return: (4, 4) 77 | """ 78 | R = Exp(r) # (3, 3) 79 | c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4) 80 | c2w = convert3x4_4x4(c2w) # (4, 4) 81 | return c2w 82 | -------------------------------------------------------------------------------- /utils_poses/vis_cam_traj.py: -------------------------------------------------------------------------------- 1 | # This file is modified from NeRF++: https://github.com/Kai-46/nerfplusplus 2 | 3 | import numpy as np 4 | 5 | try: 6 | import open3d as o3d 7 | except ImportError: 8 | pass 9 | 10 | 11 | def frustums2lineset(frustums): 12 | N = len(frustums) 13 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 14 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 15 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 16 | 17 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 18 | merged_points[i*5:(i+1)*5, :] = frustum_points 19 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 20 | 
merged_colors[i*8:(i+1)*8, :] = frustum_colors 21 | 22 | lineset = o3d.geometry.LineSet() 23 | lineset.points = o3d.utility.Vector3dVector(merged_points) 24 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 25 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 26 | 27 | return lineset 28 | 29 | 30 | def get_camera_frustum_opengl_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])): 31 | '''X right, Y up, Z backward to the observer. 32 | :param H, W: 33 | :param fx, fy: 34 | :param W2C: (4, 4) matrix 35 | :param frustum_length: scalar: scale the frustum 36 | :param color: (3,) list, frustum line color 37 | :return: 38 | frustum_points: (5, 3) frustum points in world coordinate 39 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index. 40 | frustum_colors: (8, 3) colors for 8 lines. 41 | ''' 42 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.) 43 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.) 44 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 45 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 46 | 47 | # build view frustum in camera space in homogenous coordinate (5, 4) 48 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin 49 | [-half_w, half_h, -frustum_length, 1.0], # top-left image corner 50 | [half_w, half_h, -frustum_length, 1.0], # top-right image corner 51 | [half_w, -half_h, -frustum_length, 1.0], # bottom-right image corner 52 | [-half_w, -half_h, -frustum_length, 1.0]]) # bottom-left image corner 53 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2) 54 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3) 55 | 56 | # transform view frustum from camera space to world space 57 | C2W = np.linalg.inv(W2C) 58 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4) 59 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogenous coordinate 60 | return frustum_points, frustum_lines, frustum_colors 61 | 62 | def get_camera_frustum_opencv_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])): 63 | '''X right, Y up, Z backward to the observer. 64 | :param H, W: 65 | :param fx, fy: 66 | :param W2C: (4, 4) matrix 67 | :param frustum_length: scalar: scale the frustum 68 | :param color: (3,) list, frustum line color 69 | :return: 70 | frustum_points: (5, 3) frustum points in world coordinate 71 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index. 72 | frustum_colors: (8, 3) colors for 8 lines. 73 | ''' 74 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.) 75 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.) 
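    # Illustrative numbers: with W=640, fx=600 and frustum_length=0.5 this gives
    # hfov = 2*atan(320/600) ~= 56.1 deg and half_w = 0.5*tan(hfov/2) ~= 0.267,
    # i.e. half_w reduces to frustum_length * (W / 2) / fx (and half_h analogously).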
76 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 77 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 78 | 79 | # build view frustum in camera space in homogeneous coordinates (5, 4) 80 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin 81 | [-half_w, -half_h, frustum_length, 1.0], # top-left image corner 82 | [ half_w, -half_h, frustum_length, 1.0], # top-right image corner 83 | [ half_w, half_h, frustum_length, 1.0], # bottom-right image corner 84 | [-half_w, +half_h, frustum_length, 1.0]]) # bottom-left image corner 85 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2) 86 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3) 87 | 88 | # transform view frustum from camera space to world space 89 | C2W = np.linalg.inv(W2C) 90 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4) 91 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogeneous coordinate 92 | return frustum_points, frustum_lines, frustum_colors 93 | 94 | 95 | 96 | def draw_camera_frustum_geometry(c2ws, H, W, fx=600.0, fy=600.0, frustum_length=0.5, 97 | color=np.array([29.0, 53.0, 87.0])/255.0, draw_now=False, coord='opengl'): 98 | ''' 99 | :param c2ws: (N, 4, 4) np.array of camera-to-world matrices 100 | :param H: scalar, image height in pixels 101 | :param W: scalar, image width in pixels 102 | :param fx: scalar, focal length in pixels 103 | :param fy: scalar, focal length in pixels 104 | :param frustum_length: scalar 105 | :param color: (N, 3) or (3, ) or (1, 3) or (3, 1) np array 106 | :param draw_now: if True, open an Open3D visualizer window immediately 107 | :return: 108 | ''' 109 | N = c2ws.shape[0] 110 | 111 | num_ele = color.flatten().shape[0] 112 | if num_ele == 3: 113 | color = color.reshape(1, 3) 114 | color = np.tile(color, (N, 1)) 115 | 116 | frustum_list = [] 117 | if coord == 'opengl': 118 | for i in range(N): 119 | frustum_list.append(get_camera_frustum_opengl_coord(H, W, fx, fy, 120 | W2C=np.linalg.inv(c2ws[i]), 121 | frustum_length=frustum_length, 122 | color=color[i])) 123 | elif coord == 'opencv': 124 | for i in range(N): 125 | frustum_list.append(get_camera_frustum_opencv_coord(H, W, fx, fy, 126 | W2C=np.linalg.inv(c2ws[i]), 127 | frustum_length=frustum_length, 128 | color=color[i])) 129 | else: 130 | print('Undefined coordinate system. Exiting.') 131 | exit() 132 | 133 | frustums_geometry = frustums2lineset(frustum_list) 134 | 135 | if draw_now: 136 | o3d.visualization.draw_geometries([frustums_geometry]) 137 | 138 | return frustums_geometry # this is an o3d geometry object. 139 | -------------------------------------------------------------------------------- /vis/render.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import torch 8 | 9 | sys.path.append(os.path.join(sys.path[0], '..')) 10 | from dataloading import get_dataloader, load_config 11 | from model.checkpoints import CheckpointIO 12 | from model.common import convert3x4_4x4, interp_poses, interp_poses_bspline, generate_spiral_nerf 13 | from model.extracting_images import Extract_Images 14 | import model as mdl 15 | import imageio 16 | import numpy as np 17 | 18 | torch.manual_seed(0) 19 | 20 | # Config 21 | parser = argparse.ArgumentParser( 22 | description='Extract images.'
23 | ) 24 | parser.add_argument('config', type=str, help='Path to config file.') 25 | args = parser.parse_args() 26 | cfg = load_config(args.config, 'configs/default.yaml') 27 | is_cuda = (torch.cuda.is_available()) 28 | device = torch.device("cuda" if is_cuda else "cpu") 29 | out_dir = cfg['training']['out_dir'] 30 | generation_dir = os.path.join(out_dir, cfg['extract_images']['extraction_dir']) 31 | 32 | # Model 33 | model_cfg = cfg['model'] 34 | network_type = cfg['model']['network_type'] 35 | if network_type=='official': 36 | model = mdl.OfficialStaticNerf(cfg) 37 | 38 | rendering_cfg = cfg['rendering'] 39 | renderer = mdl.Renderer(model, rendering_cfg, device=device) 40 | 41 | # init model 42 | nope_nerf = mdl.get_model(renderer, cfg, device=device) 43 | 44 | checkpoint_io = CheckpointIO(out_dir, model=nope_nerf) 45 | load_dict = checkpoint_io.load(cfg['extract_images']['model_file']) 46 | it = load_dict.get('it', -1) 47 | 48 | op = cfg['extract_images']['traj_option'] 49 | N_novel_imgs = cfg['extract_images']['N_novel_imgs'] 50 | 51 | train_loader, train_dataset = get_dataloader(cfg, mode='render', shuffle=False, n_views=N_novel_imgs) 52 | n_views = train_dataset['img'].N_imgs 53 | 54 | if cfg['pose']['learn_pose']: 55 | if cfg['pose']['init_pose']: 56 | init_pose = train_dataset['img'].c2ws 57 | else: 58 | init_pose = None 59 | pose_param_net = mdl.LearnPose(n_views, cfg['pose']['learn_R'], cfg['pose']['learn_t'], cfg=cfg, init_c2w=init_pose).to(device=device) 60 | checkpoint_io_pose = mdl.CheckpointIO(out_dir, model=pose_param_net) 61 | checkpoint_io_pose.load(cfg['extract_images']['model_file_pose']) 62 | learned_poses = torch.stack([pose_param_net(i) for i in range(n_views)]) 63 | 64 | if op=='sprial': 65 | bds = np.array([2., 4.]) 66 | hwf = train_dataset['img'].hwf 67 | c2ws = generate_spiral_nerf(learned_poses, bds, N_novel_imgs, hwf) 68 | c2ws = convert3x4_4x4(c2ws) 69 | elif op =='interp': 70 | c2ws = interp_poses(learned_poses.detach().cpu(), N_novel_imgs) 71 | elif op=='bspline': 72 | i_train = train_dataset['img'].i_train 73 | degree=cfg['extract_images']['bspline_degree'] 74 | c2ws = interp_poses_bspline(learned_poses.detach().cpu(), N_novel_imgs, i_train,degree) 75 | 76 | c2ws = c2ws.to(device) 77 | if cfg['pose']['learn_focal']: 78 | focal_net = mdl.LearnFocal(cfg['pose']['learn_focal'], cfg['pose']['fx_only'], order=cfg['pose']['focal_order']) 79 | checkpoint_io_focal = mdl.CheckpointIO(out_dir, model=focal_net) 80 | checkpoint_io_focal.load(cfg['extract_images']['model_file_focal']) 81 | fxfy = focal_net(0) 82 | print('learned fx: {0:.2f}, fy: {1:.2f}'.format(fxfy[0].item(), fxfy[1].item())) 83 | else: 84 | fxfy = None 85 | # Generator 86 | generator = Extract_Images( 87 | renderer,cfg,use_learnt_poses=cfg['pose']['learn_pose'], 88 | use_learnt_focal=cfg['pose']['learn_focal'], 89 | device=device, render_type=cfg['rendering']['type'] 90 | ) 91 | 92 | # Generate 93 | model.eval() 94 | 95 | render_dir = os.path.join(generation_dir, 'extracted_images', op) 96 | if not os.path.exists(render_dir): 97 | os.makedirs(render_dir) 98 | 99 | imgs = [] 100 | depths = [] 101 | geos = [] 102 | output_geo = False 103 | for data in train_loader: 104 | out = generator.generate_images(data, render_dir, c2ws, fxfy, it, output_geo) 105 | imgs.append(out['img']) 106 | depths.append(out['depth']) 107 | geos.append(out['geo']) 108 | imgs = np.stack(imgs, axis=0) 109 | depths = np.stack(depths, axis=0) 110 | 111 | video_out_dir = os.path.join(render_dir, 'video_out') 112 | if not 
os.path.exists(video_out_dir): 113 | os.makedirs(video_out_dir) 114 | imageio.mimwrite(os.path.join(video_out_dir, 'img.mp4'), imgs, fps=30, quality=9) 115 | imageio.mimwrite(os.path.join(video_out_dir, 'depth.mp4'), depths, fps=30, quality=9) 116 | if output_geo: 117 | geos = np.stack(geos, axis=0) 118 | imageio.mimwrite(os.path.join(video_out_dir, 'geo.mp4'), geos, fps=30, quality=9) 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /vis/vis_poses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | sys.path.append(os.path.join(sys.path[0], '..')) 6 | from dataloading import get_dataloader, load_config 7 | import model as mdl 8 | import numpy as np 9 | 10 | from utils_poses.vis_cam_traj import draw_camera_frustum_geometry 11 | from utils_poses.align_traj import pts_dist_max 12 | import open3d as o3d 13 | torch.manual_seed(0) 14 | 15 | # Config 16 | parser = argparse.ArgumentParser( 17 | description='Visualize poses.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | args = parser.parse_args() 21 | cfg = load_config(args.config, 'configs/default.yaml') 22 | 23 | is_cuda = (torch.cuda.is_available()) 24 | device = torch.device("cuda" if is_cuda else "cpu") 25 | 26 | 27 | out_dir = cfg['training']['out_dir'] 28 | 29 | test_loader, field = get_dataloader(cfg, mode='train', shuffle=False) 30 | N_imgs = field['img'].N_imgs 31 | with torch.no_grad(): 32 | if cfg['pose']['init_pose']: 33 | if cfg['pose']['init_pose_type']=='gt': 34 | init_pose = field['img'].c2ws # init with ground-truth poses 35 | elif cfg['pose']['init_pose_type']=='colmap': 36 | init_pose = field['img'].c2ws_colmap 37 | else: 38 | init_pose = None 39 | pose_param_net = mdl.LearnPose(N_imgs, cfg['pose']['learn_R'], 40 | cfg['pose']['learn_t'], cfg=cfg, init_c2w=init_pose).to(device=device) 41 | checkpoint_io_pose = mdl.CheckpointIO(out_dir, model=pose_param_net) 42 | checkpoint_io_pose.load(cfg['extract_images']['model_file_pose'], device) 43 | learned_poses = torch.stack([pose_param_net(i) for i in range(N_imgs)]) 44 | 45 | H = field['img'].H 46 | W = field['img'].W 47 | gt_poses = field['img'].c2ws 48 | if cfg['pose']['learn_focal']: 49 | focal_net = mdl.LearnFocal(cfg['pose']['learn_focal'], cfg['pose']['fx_only'], order=cfg['pose']['focal_order']) 50 | checkpoint_io_focal = mdl.CheckpointIO(out_dir, model=focal_net) 51 | checkpoint_io_focal.load(cfg['extract_images']['model_file_focal'], device) 52 | fxfy = focal_net(0) 53 | fx = fxfy[0] * W / 2 54 | fy = fxfy[1] * H / 2 55 | else: 56 | fx = field['img'].focal 57 | fy = field['img'].focal 58 | 59 | # scale estimated poses to unit sphere 60 | ts_est = learned_poses[:, :3, 3] # (N, 3) 61 | learned_poses[:, :3, 3] /= pts_dist_max(ts_est) 62 | learned_poses[:, :3, 3] *= 2.0 63 | 64 | '''Define camera frustums''' 65 | frustum_length = 0.1 66 | est_traj_color = np.array([39, 125, 161], dtype=np.float32) / 255 67 | 68 | 69 | frustum_est_list = draw_camera_frustum_geometry(learned_poses.cpu().numpy(), H, W, 70 | fx, fy, 71 | frustum_length, est_traj_color) 72 | 73 | 74 | geometry_to_draw = [] 75 | geometry_to_draw.append(frustum_est_list) 76 | 77 | 78 | unit_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=1.0, resolution=2) 79 | unit_sphere = o3d.geometry.LineSet.create_from_triangle_mesh(unit_sphere) 80 | unit_sphere.paint_uniform_color((0, 1, 0)) 81 | coord =
o3d.geometry.TriangleMesh.create_coordinate_frame() 82 | 83 | 84 | o3d.visualization.draw_geometries(geometry_to_draw) 85 | 86 | 87 | 88 | 89 | 90 | --------------------------------------------------------------------------------
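Usage note (not part of the repository): the pose helpers in utils_poses/lie_group_helper.py compose as axis-angle -> rotation -> 4x4 camera-to-world matrix. The sketch below is a minimal, hypothetical example of that flow; the axis-angle and translation values are made up for illustration, and it assumes the repository root is on PYTHONPATH.

import numpy as np
import torch
from utils_poses.lie_group_helper import Exp, make_c2w, convert3x4_4x4, SO3_to_quat, quat_to_SO3

r = torch.tensor([0.0, 0.3, 0.0])   # illustrative axis-angle: 0.3 rad about the Y axis
t = torch.tensor([0.1, 0.0, 2.0])   # illustrative camera centre in world coordinates

R = Exp(r)                          # (3, 3) rotation from Rodrigues' formula
c2w = make_c2w(r, t)                # (4, 4) camera-to-world matrix
assert torch.allclose(c2w[:3, :3], R)
assert torch.allclose(c2w[3], torch.tensor([0., 0., 0., 1.]))

# convert3x4_4x4 also pads batched poses: (N, 3, 4) -> (N, 4, 4)
print(convert3x4_4x4(torch.rand(5, 3, 4)).shape)   # torch.Size([5, 4, 4])

# round trip through the quaternion helpers (scipy uses x, y, z, w ordering)
quat = SO3_to_quat(R.numpy())
assert np.allclose(quat_to_SO3(quat), R.numpy(), atol=1e-6)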
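Usage note (not part of the repository): vis/vis_poses.py drives utils_poses/vis_cam_traj.py with poses loaded from a checkpoint; the sketch below makes the same call on synthetic poses so it can run standalone. It assumes Open3D is installed and the repository root is on PYTHONPATH; the image size, focal lengths, and poses are made up for illustration, and draw_now=True opens an interactive Open3D window.

import numpy as np
from utils_poses.vis_cam_traj import draw_camera_frustum_geometry

N, H, W = 5, 480, 640
c2ws = np.tile(np.eye(4), (N, 1, 1))        # identity camera-to-world poses ...
c2ws[:, 0, 3] = np.linspace(-1.0, 1.0, N)   # ... spread along the world X axis

frustums = draw_camera_frustum_geometry(c2ws, H, W, fx=500.0, fy=500.0,
                                        frustum_length=0.2,
                                        color=np.array([0.2, 0.5, 0.8]),
                                        draw_now=True, coord='opengl')
print(type(frustums))                       # a single open3d LineSet containing all N frustums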