├── media └── teaser.png ├── data_generation ├── replica_render_config_vMAP.yaml ├── README.md ├── extract_inst_obj.py ├── transformation.py ├── settings.py └── habitat_renderer.py ├── vis.py ├── metric ├── metrics.py ├── eval_3D_scene.py └── eval_3D_obj.py ├── image_transforms.py ├── configs ├── ScanNet │ ├── config_scannet0000_iMAP.json │ ├── config_scannet0024_iMAP.json │ ├── config_scannet0000_vMAP.json │ └── config_scannet0024_vMAP.json └── Replica │ ├── config_replica_room0_iMAP.json │ └── config_replica_room0_vMAP.json ├── .gitignore ├── model.py ├── loss.py ├── environment.yml ├── embedding.py ├── trainer.py ├── cfg.py ├── render_rays.py ├── README.md ├── LICENSE ├── dataset.py ├── utils.py ├── train.py └── vmap.py /media/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kxhit/vMAP/HEAD/media/teaser.png -------------------------------------------------------------------------------- /data_generation/replica_render_config_vMAP.yaml: -------------------------------------------------------------------------------- 1 | # Agent settings 2 | default_agent: 0 3 | gpu_id: 0 4 | width: 1200 #1280 5 | height: 680 #960 6 | sensor_height: 0 7 | 8 | color_sensor: true # RGB sensor 9 | semantic_sensor: true # Semantic sensor 10 | depth_sensor: true # Depth sensor 11 | enable_semantics: true 12 | 13 | # room_0 14 | scene_file: "/home/xin/data/vmap/room_0/habitat/mesh_semantic.ply" 15 | instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json" 16 | save_path: "/home/xin/data/vmap/room_0/vmap/00/" 17 | pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt" 18 | ## HDR texture 19 | ## issue https://github.com/facebookresearch/Replica-Dataset/issues/41#issuecomment-566251467 20 | #scene_file: "/home/xin/data/vmap/room_0/mesh.ply" 21 | #instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json" 22 | #save_path: "/home/xin/data/vmap/room_0/vmap/00/" 23 | #pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt" -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import skimage.measure 2 | import trimesh 3 | import open3d as o3d 4 | import numpy as np 5 | 6 | def marching_cubes(occupancy, level=0.5): 7 | try: 8 | vertices, faces, vertex_normals, _ = skimage.measure.marching_cubes( #marching_cubes_lewiner( #marching_cubes( 9 | occupancy, level=level, gradient_direction='ascent') 10 | except (RuntimeError, ValueError): 11 | return None 12 | 13 | dim = occupancy.shape[0] 14 | vertices = vertices / (dim - 1) 15 | mesh = trimesh.Trimesh(vertices=vertices, 16 | vertex_normals=vertex_normals, 17 | faces=faces) 18 | 19 | return mesh 20 | 21 | def trimesh_to_open3d(src): 22 | dst = o3d.geometry.TriangleMesh() 23 | dst.vertices = o3d.utility.Vector3dVector(src.vertices) 24 | dst.triangles = o3d.utility.Vector3iVector(src.faces) 25 | vertex_colors = src.visual.vertex_colors[:, :3].astype(np.float) / 255.0 26 | dst.vertex_colors = o3d.utility.Vector3dVector(vertex_colors) 27 | dst.compute_vertex_normals() 28 | 29 | return dst -------------------------------------------------------------------------------- /metric/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import cKDTree as KDTree 3 | 4 | def completion_ratio(gt_points, rec_points, dist_th=0.01): 5 | gen_points_kd_tree = 
KDTree(rec_points) 6 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) 7 | completion = np.mean((one_distances < dist_th).astype(float)) 8 | return completion 9 | 10 | 11 | def accuracy(gt_points, rec_points): 12 | gt_points_kd_tree = KDTree(gt_points) 13 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points) 14 | gen_to_gt_chamfer = np.mean(two_distances) 15 | return gen_to_gt_chamfer 16 | 17 | 18 | def completion(gt_points, rec_points): 19 | rec_points_kd_tree = KDTree(rec_points) 20 | one_distances, one_vertex_ids = rec_points_kd_tree.query(gt_points) 21 | gt_to_gen_chamfer = np.mean(one_distances) 22 | return gt_to_gen_chamfer 23 | 24 | 25 | def chamfer(gt_points, rec_points): 26 | # one direction 27 | gen_points_kd_tree = KDTree(rec_points) 28 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) 29 | gt_to_gen_chamfer = np.mean(one_distances) 30 | 31 | # other direction 32 | gt_points_kd_tree = KDTree(gt_points) 33 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points) 34 | gen_to_gt_chamfer = np.mean(two_distances) 35 | 36 | return (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2. 37 | 38 | -------------------------------------------------------------------------------- /image_transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class BGRtoRGB(object): 6 | """bgr format to rgb""" 7 | 8 | def __call__(self, image): 9 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 10 | return image 11 | 12 | 13 | class DepthScale(object): 14 | """scale depth to meters""" 15 | 16 | def __init__(self, scale): 17 | self.scale = scale 18 | 19 | def __call__(self, depth): 20 | depth = depth.astype(np.float32) 21 | return depth * self.scale 22 | 23 | 24 | class DepthFilter(object): 25 | """zero out depth readings beyond max_depth""" 26 | 27 | def __init__(self, max_depth): 28 | self.max_depth = max_depth 29 | 30 | def __call__(self, depth): 31 | far_mask = depth > self.max_depth 32 | depth[far_mask] = 0. 33 | return depth 34 | 35 | 36 | class Undistort(object): 37 | """undistort images given camera intrinsics and distortion coefficients""" 38 | 39 | def __init__(self, 40 | w, h, 41 | fx, fy, cx, cy, 42 | k1, k2, k3, k4, k5, k6, 43 | p1, p2, 44 | interpolation): 45 | self.interpolation = interpolation 46 | K = np.array([[fx, 0., cx], 47 | [0., fy, cy], 48 | [0., 0., 1.]]) 49 | 50 | self.map1x, self.map1y = cv2.initUndistortRectifyMap( 51 | K, 52 | np.array([k1, k2, p1, p2, k3, k4, k5, k6]), 53 | np.eye(3), 54 | K, 55 | (w, h), 56 | cv2.CV_32FC1) 57 | 58 | def __call__(self, im): 59 | im = cv2.remap(im, self.map1x, self.map1y, self.interpolation) 60 | return im 61 | -------------------------------------------------------------------------------- /data_generation/README.md: -------------------------------------------------------------------------------- 1 | ## Replica Data Generation 2 | 3 | ### Download Replica Dataset 4 | Download the 3D models and info files from [Replica](https://github.com/facebookresearch/Replica-Dataset). 5 | 6 | ### 3D Object Mesh Extraction 7 | Change the input path in `./data_generation/extract_inst_obj.py` and run 8 | ```bash 9 | python ./data_generation/extract_inst_obj.py 10 | ``` 11 | 12 | ### Camera Trajectory Generation 13 | Please refer to [Semantic-NeRF](https://github.com/Harry-Zhi/semantic_nerf/issues/25#issuecomment-1340595427) for more details. The random trajectory generation only works for single-room scenes; multi-room scenes additionally require collision checking.
Contributions are welcome. 14 | 15 | ### Rendering 2D Images 16 | Given a camera trajectory T_wc (set `pose_file` in the config), we use [Habitat-Sim](https://github.com/facebookresearch/habitat-sim) to render RGB, depth, semantic, and instance images. 17 | 18 | #### Install Habitat-Sim 0.2.1 19 | We recommend using conda to install Habitat-Sim 0.2.1. 20 | ```bash 21 | conda create -n habitat python=3.8.12 cmake=3.14.0 22 | conda activate habitat 23 | conda install habitat-sim=0.2.1 withbullet -c conda-forge -c aihabitat 24 | conda install numba=0.54.1 25 | ``` 26 | 27 | #### Run rendering with configs 28 | ```bash 29 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml 30 | ``` 31 | Note that to obtain HDR images, render with `mesh.ply` instead of `mesh_semantic.ply` (change the paths in the config), re-run the command below, and copy the resulting `rgb` folder over the previously rendered, over-exposed RGB images. 32 | ```bash 33 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml 34 | ``` -------------------------------------------------------------------------------- /data_generation/extract_inst_obj.py: -------------------------------------------------------------------------------- 1 | # reference https://github.com/facebookresearch/Replica-Dataset/issues/17#issuecomment-538757418 2 | 3 | from plyfile import * 4 | import numpy as np 5 | import trimesh 6 | 7 | 8 | # path_in = 'path/to/mesh_semantic.ply' 9 | path_in = '/home/xin/data/vmap/room_0_debug/habitat/mesh_semantic.ply' 10 | 11 | print("Reading input...") 12 | mesh = trimesh.load(path_in) 13 | # mesh.show() 14 | file_in = PlyData.read(path_in) 15 | vertices_in = file_in.elements[0] 16 | faces_in = file_in.elements[1] 17 | 18 | print("Filtering data...") 19 | objects = {} 20 | sub_mesh_indices = {} 21 | for i, f in enumerate(faces_in): 22 | object_id = f[1] 23 | if object_id not in objects: 24 | objects[object_id] = [] 25 | sub_mesh_indices[object_id] = [] 26 | objects[object_id].append((f[0],)) 27 | sub_mesh_indices[object_id].append(i) 28 | sub_mesh_indices[object_id].append(i+faces_in.data.shape[0]) 29 | 30 | 31 | print("Writing data...") 32 | for object_id, faces in objects.items(): 33 | path_out = path_in + f"_{object_id}.ply" 34 | # print("sub_mesh_indices[object_id] ", sub_mesh_indices[object_id]) 35 | obj_mesh = mesh.submesh([sub_mesh_indices[object_id]], append=True) 36 | in_n = len(sub_mesh_indices[object_id]) 37 | out_n = obj_mesh.faces.shape[0] 38 | # print("obj id ", object_id) 39 | # print("in_n ", in_n) 40 | # print("out_n ", out_n) 41 | # print("faces ", len(faces)) 42 | # assert in_n == out_n 43 | obj_mesh.export(path_out) 44 | # faces_out = PlyElement.describe(np.array(faces, dtype=[('vertex_indices', 'O')]), 'face') 45 | # print("faces out ", len(PlyData([vertices_in, faces_out]).elements[1].data)) 46 | # PlyData([vertices_in, faces_out]).write(path_out+"_cmp.ply") 47 | 48 | -------------------------------------------------------------------------------- /data_generation/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import quaternion 3 | import trimesh 4 | 5 | def habitat_world_transformations(): 6 | import habitat_sim 7 | # Transforms between the habitat frame H (y-up) and the world frame W (z-up).
8 | T_wh = np.identity(4) 9 | 10 | # https://stackoverflow.com/questions/1171849/finding-quaternion-representing-the-rotation-from-one-vector-to-another 11 | T_wh[0:3, 0:3] = quaternion.as_rotation_matrix(habitat_sim.utils.common.quat_from_two_vectors( 12 | habitat_sim.geo.GRAVITY, np.array([0.0, 0.0, -1.0]))) 13 | 14 | T_hw = np.linalg.inv(T_wh) 15 | 16 | return T_wh, T_hw 17 | 18 | def opencv_to_opengl_camera(transform=None): 19 | if transform is None: 20 | transform = np.eye(4) 21 | return transform @ trimesh.transformations.rotation_matrix( 22 | np.deg2rad(180), [1, 0, 0] 23 | ) 24 | 25 | def opengl_to_opencv_camera(transform=None): 26 | if transform is None: 27 | transform = np.eye(4) 28 | return transform @ trimesh.transformations.rotation_matrix( 29 | np.deg2rad(-180), [1, 0, 0] 30 | ) 31 | 32 | def Twc_to_Thc(T_wc): # opencv-camera to world transformation ---> habitat-caemra to habitat world transformation 33 | T_wh, T_hw = habitat_world_transformations() 34 | T_hc = T_hw @ T_wc @ opengl_to_opencv_camera() 35 | return T_hc 36 | 37 | 38 | def Thc_to_Twc(T_hc): # habitat-caemra to habitat world transformation ---> opencv-camera to world transformation 39 | T_wh, T_hw = habitat_world_transformations() 40 | T_wc = T_wh @ T_hc @ opencv_to_opengl_camera() 41 | return T_wc 42 | 43 | 44 | def combine_pose(t: np.array, q: quaternion.quaternion) -> np.array: 45 | T = np.identity(4) 46 | T[0:3, 3] = t 47 | T[0:3, 0:3] = quaternion.as_rotation_matrix(q) 48 | return T -------------------------------------------------------------------------------- /configs/ScanNet/config_scannet0000_iMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/ScanNet/NICESLAM/scene0000_00", 5 | "format": "ScanNet", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 1, 17 | "do_bg": 0, 18 | "n_models": 1, 19 | "train_device": "cuda:0", 20 | "data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 | "depth_range": [0.0, 6.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 5, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 2400, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 3.0, 37 | "bg_scale": 15.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 50, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 256, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 640, 54 | "h": 480, 55 | "mw": 10, 56 | "mh": 10 57 | }, 58 | "vis": { 59 | "vis_device": "cuda:0", 60 | "n_vis_iter": 500, 61 | "n_bins_fine_vis": 10, 62 | "im_vis_reduce": 10, 63 | "grid_dim": 256, 64 | "live_vis": 1, 65 | "live_voxel_size": 0.005 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /configs/ScanNet/config_scannet0024_iMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/ScanNet/obj-imap/scene0024_00", 5 | "format": "ScanNet", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 
10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 1, 17 | "do_bg": 0, 18 | "n_models": 1, 19 | "train_device": "cuda:0", 20 | "data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 | "depth_range": [0.0, 6.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 5, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 2400, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 2.0, 37 | "bg_scale": 5.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 50, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 256, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 640, 54 | "h": 480, 55 | "mw": 10, 56 | "mh": 10 57 | }, 58 | "vis": { 59 | "vis_device": "cuda:0", 60 | "n_vis_iter": 500, 61 | "n_bins_fine_vis": 10, 62 | "im_vis_reduce": 10, 63 | "grid_dim": 256, 64 | "live_vis": 1, 65 | "live_voxel_size": 0.005 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /configs/ScanNet/config_scannet0000_vMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/ScanNet/NICESLAM/scene0000_00", 5 | "format": "ScanNet", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 0, 17 | "do_bg": 1, 18 | "n_models": 100, 19 | "train_device": "cuda:0", 20 | "data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 | "depth_range": [0.0, 6.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 1, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 120, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 3.0, 37 | "bg_scale": 10.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 25, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 32, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 640, 54 | "h": 480, 55 | "mw": 10, 56 | "mh": 10 57 | }, 58 | "vis": { 59 | "vis_device": "cuda:0", 60 | "n_vis_iter": 10000000, 61 | "n_bins_fine_vis": 10, 62 | "im_vis_reduce": 10, 63 | "grid_dim": 256, 64 | "live_vis": 1, 65 | "live_voxel_size": 0.005 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /configs/ScanNet/config_scannet0024_vMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/ScanNet/obj-imap/scene0024_00", 5 | "format": "ScanNet", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 0, 17 | "do_bg": 1, 18 | "n_models": 100, 19 | "train_device": "cuda:0", 20 | 
"data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 | "depth_range": [0.0, 6.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 1, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 120, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 3.0, 37 | "bg_scale": 10.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 25, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 32, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 640, 54 | "h": 480, 55 | "mw": 10, 56 | "mh": 10 57 | }, 58 | "vis": { 59 | "vis_device": "cuda:0", 60 | "n_vis_iter": 10000000, 61 | "n_bins_fine_vis": 10, 62 | "im_vis_reduce": 10, 63 | "grid_dim": 256, 64 | "live_vis": 1, 65 | "live_voxel_size": 0.005 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /configs/Replica/config_replica_room0_iMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/Replica/vmap/room_0/imap/00", 5 | "format": "Replica", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 1, 17 | "do_bg": 0, 18 | "n_models": 1, 19 | "train_device": "cuda:0", 20 | "data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 | "depth_range": [0.0, 8.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 5, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 4800, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 5.0, 37 | "bg_scale": 5.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 50, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 256, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 1200, 54 | "h": 680, 55 | "fx": 600.0, 56 | "fy": 600.0, 57 | "cx": 599.5, 58 | "cy": 339.5, 59 | "mw": 0, 60 | "mh": 0 61 | }, 62 | "vis": { 63 | "vis_device": "cuda:0", 64 | "n_vis_iter": 500, 65 | "n_bins_fine_vis": 10, 66 | "im_vis_reduce": 10, 67 | "grid_dim": 256, 68 | "live_vis": 1, 69 | "live_voxel_size": 0.005 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /configs/Replica/config_replica_room0_vMAP.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "live": 0, 4 | "path": "/home/xin/data/Replica/vmap/room_0/imap/00", 5 | "format": "Replica", 6 | "keep_alive": 20 7 | }, 8 | "optimizer": { 9 | "args":{ 10 | "lr": 0.001, 11 | "weight_decay": 0.013, 12 | "pose_lr": 0.001 13 | } 14 | }, 15 | "trainer": { 16 | "imap_mode": 0, 17 | "do_bg": 1, 18 | "n_models": 100, 19 | "train_device": "cuda:0", 20 | "data_device": "cuda:0", 21 | "training_strategy": "vmap", 22 | "epochs": 1000000, 23 | "scale": 1000.0 24 | }, 25 | "render": { 26 
| "depth_range": [0.0, 8.0], 27 | "n_bins": 9, 28 | "n_bins_cam2surface": 1, 29 | "n_bins_cam2surface_bg": 5, 30 | "iters_per_frame": 20, 31 | "n_per_optim": 120, 32 | "n_per_optim_bg": 1200 33 | }, 34 | "model": { 35 | "n_unidir_funcs": 5, 36 | "obj_scale": 2.0, 37 | "bg_scale": 5.0, 38 | "color_scaling": 5.0, 39 | "opacity_scaling": 10.0, 40 | "gt_scene": 1, 41 | "surface_eps": 0.1, 42 | "other_eps": 0.05, 43 | "keyframe_buffer_size": 20, 44 | "keyframe_step": 25, 45 | "keyframe_step_bg": 50, 46 | "window_size": 5, 47 | "window_size_bg": 10, 48 | "hidden_layers_block": 1, 49 | "hidden_feature_size": 32, 50 | "hidden_feature_size_bg": 128 51 | }, 52 | "camera": { 53 | "w": 1200, 54 | "h": 680, 55 | "fx": 600.0, 56 | "fy": 600.0, 57 | "cx": 599.5, 58 | "cy": 339.5, 59 | "mw": 0, 60 | "mh": 0 61 | }, 62 | "vis": { 63 | "vis_device": "cuda:0", 64 | "n_vis_iter": 500, 65 | "n_bins_fine_vis": 10, 66 | "im_vis_reduce": 10, 67 | "grid_dim": 256, 68 | "live_vis": 1, 69 | "live_voxel_size": 0.005 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.anaconda3/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | HierarchicalPriors.sublime-project 128 | HierarchicalPriors.sublime-workspace 129 | scripts/res1/ 130 | scripts/results/ 131 | scripts/traj.txt 132 | vids_ims 133 | res_datasets/ 134 | results/ 135 | ScenePriors/train/examples/experiments/ 136 | 137 | *.pkl 138 | *.idea 139 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def init_weights(m, init_fn=torch.nn.init.xavier_normal_): 5 | if type(m) == torch.nn.Linear: 6 | init_fn(m.weight) 7 | 8 | 9 | def fc_block(in_f, out_f): 10 | return torch.nn.Sequential( 11 | torch.nn.Linear(in_f, out_f), 12 | torch.nn.ReLU(out_f) 13 | ) 14 | 15 | 16 | class OccupancyMap(torch.nn.Module): 17 | def __init__( 18 | self, 19 | emb_size1, 20 | emb_size2, 21 | hidden_size=256, 22 | do_color=True, 23 | hidden_layers_block=1 24 | ): 25 | super(OccupancyMap, self).__init__() 26 | self.do_color = do_color 27 | self.embedding_size1 = emb_size1 28 | self.in_layer = fc_block(self.embedding_size1, hidden_size) 29 | 30 | hidden1 = [fc_block(hidden_size, hidden_size) 31 | for _ in range(hidden_layers_block)] 32 | self.mid1 = torch.nn.Sequential(*hidden1) 33 | # self.embedding_size2 = 21*(5+1)+3 - self.embedding_size # 129-66=63 32 34 | self.embedding_size2 = emb_size2 35 | self.cat_layer = fc_block( 36 | hidden_size + self.embedding_size1, hidden_size) 37 | 38 | # self.cat_layer = fc_block( 39 | # hidden_size , hidden_size) 40 | 41 | hidden2 = [fc_block(hidden_size, hidden_size) 42 | for _ in range(hidden_layers_block)] 43 | self.mid2 = torch.nn.Sequential(*hidden2) 44 | 45 | self.out_alpha = torch.nn.Linear(hidden_size, 1) 46 | 47 | if self.do_color: 48 | self.color_linear = fc_block(self.embedding_size2 + hidden_size, hidden_size) 49 | self.out_color = torch.nn.Linear(hidden_size, 3) 50 | 51 | # self.relu = torch.nn.functional.relu 52 | self.sigmoid = torch.sigmoid 53 | 54 | def forward(self, x, 55 | noise_std=None, 56 | do_alpha=True, 57 | do_color=True, 58 | do_cat=True): 59 | fc1 = self.in_layer(x[...,:self.embedding_size1]) 60 | fc2 = self.mid1(fc1) 61 | # fc3 = self.cat_layer(fc2) 62 | if do_cat: 63 | fc2_x = torch.cat((fc2, x[...,:self.embedding_size1]), dim=-1) 64 | fc3 = self.cat_layer(fc2_x) 65 | else: 66 | fc3 = fc2 67 | fc4 = self.mid2(fc3) 68 | 69 | alpha = None 70 | if do_alpha: 71 | raw = self.out_alpha(fc4) # todo ignore noise 72 | if noise_std is not None: 73 | noise = torch.randn(raw.shape, device=x.device) * noise_std 74 | raw = raw + noise 75 | 76 | # alpha = self.relu(raw) * scale # nerf 77 | alpha = raw * 10. 
#self.scale # unisurf 78 | 79 | color = None 80 | if self.do_color and do_color: 81 | fc4_cat = self.color_linear(torch.cat((fc4, x[..., self.embedding_size1:]), dim=-1)) 82 | raw_color = self.out_color(fc4_cat) 83 | color = self.sigmoid(raw_color) 84 | 85 | return alpha, color 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import render_rays 3 | import torch.nn.functional as F 4 | 5 | def step_batch_loss(alpha, color, gt_depth, gt_color, sem_labels, mask_depth, z_vals, 6 | color_scaling=5.0, opacity_scaling=10.0): 7 | """ 8 | apply depth where depth are valid -> mask_depth 9 | apply depth, color loss on this_obj & unkown_obj == (~other_obj) -> mask_obj 10 | apply occupancy/opacity loss on this_obj & other_obj == (~unknown_obj) -> mask_sem 11 | 12 | output: 13 | loss for training 14 | loss_all for per sample, could be used for active sampling, replay buffer 15 | """ 16 | mask_obj = sem_labels != 0 17 | mask_obj = mask_obj.detach() 18 | mask_sem = sem_labels != 2 19 | mask_sem = mask_sem.detach() 20 | 21 | alpha = alpha.squeeze(dim=-1) 22 | color = color.squeeze(dim=-1) 23 | 24 | occupancy = render_rays.occupancy_activation(alpha) 25 | termination = render_rays.occupancy_to_termination(occupancy, is_batch=True) # shape [num_batch, num_ray, points_per_ray] 26 | 27 | render_depth = render_rays.render(termination, z_vals) 28 | diff_sq = (z_vals - render_depth[..., None]) ** 2 29 | var = render_rays.render(termination, diff_sq).detach() # must detach here! 30 | render_color = render_rays.render(termination[..., None], color, dim=-2) 31 | render_opacity = torch.sum(termination, dim=-1) # similar to obj-nerf opacity loss 32 | 33 | # 2D depth loss: only on valid depth & mask 34 | # [mask_depth & mask_obj] 35 | # loss_all = torch.zeros_like(render_depth) 36 | loss_depth_raw = render_rays.render_loss(render_depth, gt_depth, loss="L1", normalise=False) 37 | loss_depth = torch.mul(loss_depth_raw, mask_depth & mask_obj) # keep dim but set invalid element be zero 38 | # loss_all += loss_depth 39 | loss_depth = render_rays.reduce_batch_loss(loss_depth, var=var, avg=True, mask=mask_depth & mask_obj) # apply var as imap 40 | 41 | # 2D color loss: only on obj mask 42 | # [mask_obj] 43 | loss_col_raw = render_rays.render_loss(render_color, gt_color, loss="L1", normalise=False) 44 | loss_col = torch.mul(loss_col_raw.sum(-1), mask_obj) 45 | # loss_all += loss_col / 3. 
* color_scaling 46 | loss_col = render_rays.reduce_batch_loss(loss_col, var=None, avg=True, mask=mask_obj) 47 | 48 | # 2D occupancy/opacity loss: apply except unknown area 49 | # [mask_sem] 50 | # loss_opacity_raw = F.mse_loss(torch.clamp(render_opacity, 0, 1), mask_obj.float().detach()) # encourage other_obj to be empty, while this_obj to be solid 51 | # print("opacity max ", torch.max(render_opacity.max())) 52 | # print("opacity min ", torch.max(render_opacity.min())) 53 | loss_opacity_raw = render_rays.render_loss(render_opacity, mask_obj.float(), loss="L1", normalise=False) 54 | loss_opacity = torch.mul(loss_opacity_raw, mask_sem) # but ignore -1 unkown area e.g., mask edges 55 | # loss_all += loss_opacity * opacity_scaling 56 | loss_opacity = render_rays.reduce_batch_loss(loss_opacity, var=None, avg=True, mask=mask_sem) # todo var 57 | 58 | # loss for bp 59 | l_batch = loss_depth + loss_col * color_scaling + loss_opacity * opacity_scaling 60 | loss = l_batch.sum() 61 | 62 | return loss, None # return loss, loss_all.detach() 63 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vmap 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotlipy=0.7.0=py38h27cfd23_1003 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2022.10.11=h06a4308_0 12 | - certifi=2022.9.24=py38h06a4308_0 13 | - cffi=1.15.1=py38h5eee18b_2 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cryptography=38.0.1=py38h9ce1e76_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - ffmpeg=4.3=hf484d3e_0 18 | - freetype=2.12.1=h4a9f257_0 19 | - giflib=5.2.1=h7b6447c_0 20 | - gmp=6.2.1=h295c915_3 21 | - gnutls=3.6.15=he1e5248_0 22 | - idna=3.4=py38h06a4308_0 23 | - intel-openmp=2021.4.0=h06a4308_3561 24 | - jpeg=9e=h7f8727e_0 25 | - lame=3.100=h7b6447c_0 26 | - lcms2=2.12=h3be6417_0 27 | - ld_impl_linux-64=2.38=h1181459_1 28 | - lerc=3.0=h295c915_0 29 | - libdeflate=1.8=h7f8727e_5 30 | - libffi=3.4.2=h6a678d5_6 31 | - libgcc-ng=11.2.0=h1234567_1 32 | - libgomp=11.2.0=h1234567_1 33 | - libiconv=1.16=h7f8727e_2 34 | - libidn2=2.3.2=h7f8727e_0 35 | - libpng=1.6.37=hbc83047_0 36 | - libstdcxx-ng=11.2.0=h1234567_1 37 | - libtasn1=4.16.0=h27cfd23_0 38 | - libtiff=4.4.0=hecacb30_2 39 | - libunistring=0.9.10=h27cfd23_0 40 | - libwebp=1.2.4=h11a3e52_0 41 | - libwebp-base=1.2.4=h5eee18b_0 42 | - lz4-c=1.9.3=h295c915_1 43 | - mkl=2021.4.0=h06a4308_640 44 | - mkl-service=2.4.0=py38h7f8727e_0 45 | - mkl_fft=1.3.1=py38hd3c417c_0 46 | - mkl_random=1.2.2=py38h51133e4_0 47 | - ncurses=6.3=h5eee18b_3 48 | - nettle=3.7.3=hbbd107a_1 49 | - numpy=1.23.4=py38h14f4228_0 50 | - numpy-base=1.23.4=py38h31eccc5_0 51 | - openh264=2.1.1=h4ff587b_0 52 | - openssl=1.1.1s=h7f8727e_0 53 | - pillow=9.2.0=py38hace64e9_1 54 | - pip=22.2.2=py38h06a4308_0 55 | - pycparser=2.21=pyhd3eb1b0_0 56 | - pyopenssl=22.0.0=pyhd3eb1b0_0 57 | - pysocks=1.7.1=py38h06a4308_0 58 | - python=3.8.15=h7a1cb2a_2 59 | - pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0 60 | - pytorch-mutex=1.0=cuda 61 | - readline=8.2=h5eee18b_0 62 | - requests=2.28.1=py38h06a4308_0 63 | - setuptools=65.5.0=py38h06a4308_0 64 | - six=1.16.0=pyhd3eb1b0_1 65 | - sqlite=3.40.0=h5082296_0 66 | - tk=8.6.12=h1ccaba5_0 67 | - torchaudio=0.12.1=py38_cu113 68 | - torchvision=0.13.1=py38_cu113 69 | - typing_extensions=4.3.0=py38h06a4308_0 70 | - urllib3=1.26.12=py38h06a4308_0 71 | - 
wheel=0.37.1=pyhd3eb1b0_0 72 | - xz=5.2.6=h5eee18b_0 73 | - zlib=1.2.13=h5eee18b_0 74 | - zstd=1.5.2=ha4553b6_0 75 | - pip: 76 | - antlr4-python3-runtime==4.9.3 77 | - bidict==0.22.0 78 | - click==8.1.3 79 | - dash==2.7.0 80 | - dash-core-components==2.0.0 81 | - dash-html-components==2.0.0 82 | - dash-table==5.0.0 83 | - entrypoints==0.4 84 | - flask==2.2.2 85 | - functorch==0.2.0 86 | - fvcore==0.1.5.post20221122 87 | - h5py==3.7.0 88 | - imageio==2.22.4 89 | - importlib-metadata==5.1.0 90 | - iopath==0.1.10 91 | - itsdangerous==2.1.2 92 | - lpips==0.1.4 93 | - nbformat==5.5.0 94 | - omegaconf==2.2.3 95 | - open3d==0.16.0 96 | - pexpect==4.8.0 97 | - plotly==5.11.0 98 | - pycocotools==2.0.6 99 | - pyquaternion==0.9.9 100 | - scikit-image==0.19.3 101 | - tenacity==8.1.0 102 | - termcolor==2.1.1 103 | - timm==0.6.12 104 | - werkzeug==2.2.2 105 | - yacs==0.1.8 106 | - zipp==3.11.0 107 | 108 | -------------------------------------------------------------------------------- /embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def positional_encoding( 5 | tensor, 6 | B_layer=None, 7 | num_encoding_functions=6, 8 | scale=10. 9 | ): 10 | if B_layer is not None: 11 | embedding_gauss = B_layer(tensor / scale) 12 | embedding_gauss = torch.sin(embedding_gauss) 13 | embedding = embedding_gauss 14 | else: 15 | frequency_bands = 2.0 ** torch.linspace( 16 | 0.0, 17 | num_encoding_functions - 1, 18 | num_encoding_functions, 19 | dtype=tensor.dtype, 20 | device=tensor.device, 21 | ) 22 | 23 | n_repeat = num_encoding_functions * 2 + 1 24 | embedding = tensor[..., None, :].repeat(1, 1, n_repeat, 1) / scale 25 | even_idx = np.arange(1, num_encoding_functions + 1) * 2 26 | odd_idx = even_idx - 1 27 | 28 | frequency_bands = frequency_bands[None, None, :, None] 29 | 30 | embedding[:, :, even_idx, :] = torch.cos( 31 | frequency_bands * embedding[:, :, even_idx, :]) 32 | embedding[:, :, odd_idx, :] = torch.sin( 33 | frequency_bands * embedding[:, :, odd_idx, :]) 34 | 35 | n_dim = tensor.shape[-1] 36 | embedding = embedding.view( 37 | embedding.shape[0], embedding.shape[1], n_repeat * n_dim) 38 | # print("embedding ", embedding.shape) 39 | embedding = embedding.squeeze(0) 40 | 41 | return embedding 42 | 43 | class UniDirsEmbed(torch.nn.Module): 44 | def __init__(self, min_deg=0, max_deg=2, scale=2.): 45 | super(UniDirsEmbed, self).__init__() 46 | self.min_deg = min_deg 47 | self.max_deg = max_deg 48 | self.n_freqs = max_deg - min_deg + 1 49 | self.tensor_scale = torch.tensor(scale, requires_grad=False) 50 | 51 | dirs = torch.tensor([ 52 | 0.8506508, 0, 0.5257311, 53 | 0.809017, 0.5, 0.309017, 54 | 0.5257311, 0.8506508, 0, 55 | 1, 0, 0, 56 | 0.809017, 0.5, -0.309017, 57 | 0.8506508, 0, -0.5257311, 58 | 0.309017, 0.809017, -0.5, 59 | 0, 0.5257311, -0.8506508, 60 | 0.5, 0.309017, -0.809017, 61 | 0, 1, 0, 62 | -0.5257311, 0.8506508, 0, 63 | -0.309017, 0.809017, -0.5, 64 | 0, 0.5257311, 0.8506508, 65 | -0.309017, 0.809017, 0.5, 66 | 0.309017, 0.809017, 0.5, 67 | 0.5, 0.309017, 0.809017, 68 | 0.5, -0.309017, 0.809017, 69 | 0, 0, 1, 70 | -0.5, 0.309017, 0.809017, 71 | -0.809017, 0.5, 0.309017, 72 | -0.809017, 0.5, -0.309017 73 | ]).reshape(-1, 3) 74 | 75 | self.B_layer = torch.nn.Linear(3, 21, bias=False) 76 | self.B_layer.weight.data = dirs 77 | 78 | frequency_bands = 2.0 ** torch.linspace(self.min_deg, self.max_deg, self.n_freqs) 79 | self.register_buffer("frequency_bands", frequency_bands, persistent=False) 80 | 
self.register_buffer("scale", self.tensor_scale, persistent=True) 81 | 82 | def forward(self, x): 83 | tensor = x / self.scale # functorch needs buffer, otherwise changed 84 | proj = self.B_layer(tensor) 85 | proj_bands = proj[..., None, :] * self.frequency_bands[None, None, :, None] 86 | xb = proj_bands.view(list(proj.shape[:-1]) + [-1]) 87 | # embedding = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1)) 88 | embedding = torch.sin(xb * np.pi) 89 | embedding = torch.cat([tensor] + [embedding], dim=-1) 90 | # print("emb size ", embedding.shape) 91 | return embedding -------------------------------------------------------------------------------- /metric/eval_3D_scene.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import trimesh 4 | from metrics import accuracy, completion, completion_ratio 5 | import os 6 | 7 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000): 8 | """ 9 | 3D reconstruction metric. 10 | """ 11 | metrics = [[] for _ in range(4)] 12 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N) 13 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) 14 | 15 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N) 16 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 17 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 18 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 19 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05) 20 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01) 21 | 22 | # accuracy_rec *= 100 # convert to cm 23 | # completion_rec *= 100 # convert to cm 24 | # completion_ratio_rec *= 100 # convert to % 25 | # print('accuracy: ', accuracy_rec) 26 | # print('completion: ', completion_rec) 27 | # print('completion ratio: ', completion_ratio_rec) 28 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1) 29 | metrics[0].append(accuracy_rec) 30 | metrics[1].append(completion_rec) 31 | metrics[2].append(completion_ratio_rec_1) 32 | metrics[3].append(completion_ratio_rec) 33 | return metrics 34 | 35 | 36 | if __name__ == "__main__": 37 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 38 | data_dir = "/home/xin/data/vmap/" 39 | # log_dir = "../logs/iMAP/" 40 | log_dir = "../logs/vMAP/" 41 | 42 | for exp in tqdm(exp_name): 43 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat") 44 | exp_dir = os.path.join(log_dir, exp) 45 | mesh_dir = os.path.join(exp_dir, "scene_mesh") 46 | output_path = os.path.join(exp_dir, "eval_mesh") 47 | os.makedirs(output_path, exist_ok=True) 48 | if "vMAP" in exp_dir: 49 | mesh_list = os.listdir(mesh_dir) 50 | if "frame_1999_scene.obj" in mesh_list: 51 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_scene.obj") 52 | else: # compose obj into scene mesh 53 | scene_meshes = [] 54 | for f in mesh_list: 55 | _, f_type = os.path.splitext(f) 56 | if f_type == ".obj" or f_type == ".ply": 57 | obj_mesh = trimesh.load(os.path.join(mesh_dir, f)) 58 | scene_meshes.append(obj_mesh) 59 | scene_mesh = trimesh.util.concatenate(scene_meshes) 60 | scene_mesh.export(os.path.join(mesh_dir, "frame_1999_scene.obj")) 61 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_scene.obj") 62 | elif "iMAP" in exp_dir: # obj0 is the scene mesh 63 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj") 64 | else: 65 | print("Not Implement") 66 | exit(-1) 67 | gt_mesh_files = os.listdir(gt_dir) 68 | 
gt_mesh_file = os.path.join(gt_dir, "../mesh.ply") 69 | mesh_rec = trimesh.load(rec_meshfile) 70 | # mesh_rec.invert() # niceslam mesh face needs invert 71 | metrics_3D = [[] for _ in range(4)] 72 | mesh_gt = trimesh.load(gt_mesh_file) 73 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=200000) # for objs use 10k, for scene use 200k points 74 | metrics_3D[0].append(metrics[0]) # acc 75 | metrics_3D[1].append(metrics[1]) # comp 76 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm 77 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm 78 | metrics_3D = np.array(metrics_3D) 79 | np.save(output_path + '/metrics_3D_scene.npy', metrics_3D) 80 | print("metrics 3D scene \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n ", metrics_3D.mean(axis=1)) 81 | print("-----------------------------------------") 82 | print("finish exp ", exp) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import model 3 | import embedding 4 | import render_rays 5 | import numpy as np 6 | import vis 7 | from tqdm import tqdm 8 | 9 | class Trainer: 10 | def __init__(self, cfg): 11 | self.obj_id = cfg.obj_id 12 | self.device = cfg.training_device 13 | self.hidden_feature_size = cfg.hidden_feature_size #32 for obj # 256 for iMAP, 128 for seperate bg 14 | self.obj_scale = cfg.obj_scale # 10 for bg and iMAP 15 | self.n_unidir_funcs = cfg.n_unidir_funcs 16 | self.emb_size1 = 21*(3+1)+3 17 | self.emb_size2 = 21*(5+1)+3 - self.emb_size1 18 | 19 | self.load_network() 20 | 21 | if self.obj_id == 0: 22 | self.bound_extent = 0.995 23 | else: 24 | self.bound_extent = 0.9 25 | 26 | def load_network(self): 27 | self.fc_occ_map = model.OccupancyMap( 28 | self.emb_size1, 29 | self.emb_size2, 30 | hidden_size=self.hidden_feature_size 31 | ) 32 | self.fc_occ_map.apply(model.init_weights).to(self.device) 33 | self.pe = embedding.UniDirsEmbed(max_deg=self.n_unidir_funcs, scale=self.obj_scale).to(self.device) 34 | 35 | def meshing(self, bound, obj_center, grid_dim=256): 36 | occ_range = [-1., 1.] 
37 | range_dist = occ_range[1] - occ_range[0] 38 | scene_scale_np = bound.extent / (range_dist * self.bound_extent) 39 | scene_scale = torch.from_numpy(scene_scale_np).float().to(self.device) 40 | transform_np = np.eye(4, dtype=np.float32) 41 | transform_np[:3, 3] = bound.center 42 | transform_np[:3, :3] = bound.R 43 | # transform_np = np.linalg.inv(transform_np) # 44 | transform = torch.from_numpy(transform_np).to(self.device) 45 | grid_pc = render_rays.make_3D_grid(occ_range=occ_range, dim=grid_dim, device=self.device, 46 | scale=scene_scale, transform=transform).view(-1, 3) 47 | grid_pc -= obj_center.to(grid_pc.device) 48 | ret = self.eval_points(grid_pc) 49 | if ret is None: 50 | return None 51 | 52 | occ, _ = ret 53 | mesh = vis.marching_cubes(occ.view(grid_dim, grid_dim, grid_dim).cpu().numpy()) 54 | if mesh is None: 55 | print("marching cube failed") 56 | return None 57 | 58 | # Transform to [-1, 1] range 59 | mesh.apply_translation([-0.5, -0.5, -0.5]) 60 | mesh.apply_scale(2) 61 | 62 | # Transform to scene coordinates 63 | mesh.apply_scale(scene_scale_np) 64 | mesh.apply_transform(transform_np) 65 | 66 | vertices_pts = torch.from_numpy(np.array(mesh.vertices)).float().to(self.device) 67 | ret = self.eval_points(vertices_pts) 68 | if ret is None: 69 | return None 70 | _, color = ret 71 | mesh_color = color * 255 72 | vertex_colors = mesh_color.detach().squeeze(0).cpu().numpy().astype(np.uint8) 73 | mesh.visual.vertex_colors = vertex_colors 74 | 75 | return mesh 76 | 77 | def eval_points(self, points, chunk_size=100000): 78 | # 256^3 = 16777216 79 | alpha, color = [], [] 80 | n_chunks = int(np.ceil(points.shape[0] / chunk_size)) 81 | with torch.no_grad(): 82 | for k in tqdm(range(n_chunks)): # 2s/it 1000000 pts 83 | chunk_idx = slice(k * chunk_size, (k + 1) * chunk_size) 84 | embedding_k = self.pe(points[chunk_idx, ...]) 85 | alpha_k, color_k = self.fc_occ_map(embedding_k) 86 | alpha.extend(alpha_k.detach().squeeze()) 87 | color.extend(color_k.detach().squeeze()) 88 | alpha = torch.stack(alpha) 89 | color = torch.stack(color) 90 | 91 | occ = render_rays.occupancy_activation(alpha).detach() 92 | if occ.max() == 0: 93 | print("no occ") 94 | return None 95 | return (occ, color) 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import utils 5 | 6 | class Config: 7 | def __init__(self, config_file): 8 | # setting params 9 | with open(config_file) as json_file: 10 | config = json.load(json_file) 11 | 12 | # training strategy 13 | self.do_bg = bool(config["trainer"]["do_bg"]) 14 | self.training_device = config["trainer"]["train_device"] 15 | self.data_device = config["trainer"]["data_device"] 16 | self.max_n_models = config["trainer"]["n_models"] 17 | self.live_mode = bool(config["dataset"]["live"]) 18 | self.keep_live_time = config["dataset"]["keep_alive"] 19 | self.imap_mode = config["trainer"]["imap_mode"] 20 | self.training_strategy = config["trainer"]["training_strategy"] # "forloop" "vmap" 21 | self.obj_id = -1 22 | 23 | # dataset setting 24 | self.dataset_format = config["dataset"]["format"] 25 | self.dataset_dir = config["dataset"]["path"] 26 | self.depth_scale = 1 / config["trainer"]["scale"] 27 | # camera setting 28 | self.max_depth = config["render"]["depth_range"][1] 29 | self.min_depth = config["render"]["depth_range"][0] 30 | self.mh = 
config["camera"]["mh"] 31 | self.mw = config["camera"]["mw"] 32 | self.height = config["camera"]["h"] 33 | self.width = config["camera"]["w"] 34 | self.H = self.height - 2 * self.mh 35 | self.W = self.width - 2 * self.mw 36 | if "fx" in config["camera"]: 37 | self.fx = config["camera"]["fx"] 38 | self.fy = config["camera"]["fy"] 39 | self.cx = config["camera"]["cx"] - self.mw 40 | self.cy = config["camera"]["cy"] - self.mh 41 | else: # for scannet 42 | intrinsic = utils.load_matrix_from_txt(os.path.join(self.dataset_dir, "intrinsic/intrinsic_depth.txt")) 43 | self.fx = intrinsic[0, 0] 44 | self.fy = intrinsic[1, 1] 45 | self.cx = intrinsic[0, 2] - self.mw 46 | self.cy = intrinsic[1, 2] - self.mh 47 | if "distortion" in config["camera"]: 48 | self.distortion_array = np.array(config["camera"]["distortion"]) 49 | elif "k1" in config["camera"]: 50 | k1 = config["camera"]["k1"] 51 | k2 = config["camera"]["k2"] 52 | k3 = config["camera"]["k3"] 53 | k4 = config["camera"]["k4"] 54 | k5 = config["camera"]["k5"] 55 | k6 = config["camera"]["k6"] 56 | p1 = config["camera"]["p1"] 57 | p2 = config["camera"]["p2"] 58 | self.distortion_array = np.array([k1, k2, p1, p2, k3, k4, k5, k6]) 59 | else: 60 | self.distortion_array = None 61 | 62 | # training setting 63 | self.win_size = config["model"]["window_size"] 64 | self.n_iter_per_frame = config["render"]["iters_per_frame"] 65 | self.n_per_optim = config["render"]["n_per_optim"] 66 | self.n_samples_per_frame = self.n_per_optim // self.win_size 67 | self.win_size_bg = config["model"]["window_size_bg"] 68 | self.n_per_optim_bg = config["render"]["n_per_optim_bg"] 69 | self.n_samples_per_frame_bg = self.n_per_optim_bg // self.win_size_bg 70 | self.keyframe_buffer_size = config["model"]["keyframe_buffer_size"] 71 | self.keyframe_step = config["model"]["keyframe_step"] 72 | self.keyframe_step_bg = config["model"]["keyframe_step_bg"] 73 | self.obj_scale = config["model"]["obj_scale"] 74 | self.bg_scale = config["model"]["bg_scale"] 75 | self.hidden_feature_size = config["model"]["hidden_feature_size"] 76 | self.hidden_feature_size_bg = config["model"]["hidden_feature_size_bg"] 77 | self.n_bins_cam2surface = config["render"]["n_bins_cam2surface"] 78 | self.n_bins_cam2surface_bg = config["render"]["n_bins_cam2surface_bg"] 79 | self.n_bins = config["render"]["n_bins"] 80 | self.n_unidir_funcs = config["model"]["n_unidir_funcs"] 81 | self.surface_eps = config["model"]["surface_eps"] 82 | self.stop_eps = config["model"]["other_eps"] 83 | 84 | # optimizer setting 85 | self.learning_rate = config["optimizer"]["args"]["lr"] 86 | self.weight_decay = config["optimizer"]["args"]["weight_decay"] 87 | 88 | # vis setting 89 | self.vis_device = config["vis"]["vis_device"] 90 | self.n_vis_iter = config["vis"]["n_vis_iter"] 91 | self.live_voxel_size = config["vis"]["live_voxel_size"] 92 | self.grid_dim = config["vis"]["grid_dim"] -------------------------------------------------------------------------------- /render_rays.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def occupancy_activation(alpha, distances=None): 5 | # occ = 1.0 - torch.exp(-alpha * distances) 6 | occ = torch.sigmoid(alpha) # unisurf 7 | 8 | return occ 9 | 10 | def alpha_to_occupancy(depths, dirs, alpha, add_last=False): 11 | interval_distances = depths[..., 1:] - depths[..., :-1] 12 | if add_last: 13 | last_distance = torch.empty( 14 | (depths.shape[0], 1), 15 | device=depths.device, 16 | dtype=depths.dtype).fill_(0.1) 17 | 
interval_distances = torch.cat( 18 | [interval_distances, last_distance], dim=-1) 19 | 20 | dirs_norm = torch.norm(dirs, dim=-1) 21 | interval_distances = interval_distances * dirs_norm[:, None] 22 | occ = occupancy_activation(alpha, interval_distances) 23 | 24 | return occ 25 | 26 | def occupancy_to_termination(occupancy, is_batch=False): 27 | if is_batch: 28 | first = torch.ones(list(occupancy.shape[:2]) + [1], device=occupancy.device) 29 | free_probs = (1. - occupancy + 1e-10)[:, :, :-1] 30 | else: 31 | first = torch.ones([occupancy.shape[0], 1], device=occupancy.device) 32 | free_probs = (1. - occupancy + 1e-10)[:, :-1] 33 | free_probs = torch.cat([first, free_probs], dim=-1) 34 | term_probs = occupancy * torch.cumprod(free_probs, dim=-1) 35 | 36 | # using escape probability 37 | # occupancy = occupancy[:, :-1] 38 | # first = torch.ones([occupancy.shape[0], 1], device=occupancy.device) 39 | # free_probs = (1. - occupancy + 1e-10) 40 | # free_probs = torch.cat([first, free_probs], dim=-1) 41 | # last = torch.ones([occupancy.shape[0], 1], device=occupancy.device) 42 | # occupancy = torch.cat([occupancy, last], dim=-1) 43 | # term_probs = occupancy * torch.cumprod(free_probs, dim=-1) 44 | 45 | return term_probs 46 | 47 | def render(termination, vals, dim=-1): 48 | weighted_vals = termination * vals 49 | render = weighted_vals.sum(dim=dim) 50 | 51 | return render 52 | 53 | def render_loss(render, gt, loss="L1", normalise=False): 54 | residual = render - gt 55 | if loss == "L2": 56 | loss_mat = residual ** 2 57 | elif loss == "L1": 58 | loss_mat = torch.abs(residual) 59 | else: 60 | print("loss type {} not implemented!".format(loss)) 61 | 62 | if normalise: 63 | loss_mat = loss_mat / gt 64 | 65 | return loss_mat 66 | 67 | def reduce_batch_loss(loss_mat, var=None, avg=True, mask=None, loss_type="L1"): 68 | mask_num = torch.sum(mask, dim=-1) 69 | if (mask_num == 0).any(): # no valid sample, return 0 loss 70 | loss = torch.zeros_like(loss_mat) 71 | if avg: 72 | loss = torch.mean(loss, dim=-1) 73 | return loss 74 | if var is not None: 75 | eps = 1e-4 76 | if loss_type == "L2": 77 | information = 1.0 / (var + eps) 78 | elif loss_type == "L1": 79 | information = 1.0 / (torch.sqrt(var) + eps) 80 | 81 | loss_weighted = loss_mat * information 82 | else: 83 | loss_weighted = loss_mat 84 | 85 | if avg: 86 | if mask is not None: 87 | loss = (torch.sum(loss_weighted, dim=-1)/(torch.sum(mask, dim=-1)+1e-10)) 88 | if (loss > 100000).any(): 89 | print("loss explode") 90 | exit(-1) 91 | else: 92 | loss = torch.mean(loss_weighted, dim=-1).sum() 93 | else: 94 | loss = loss_weighted 95 | 96 | return loss 97 | 98 | def make_3D_grid(occ_range=[-1., 1.], dim=256, device="cuda:0", transform=None, scale=None): 99 | t = torch.linspace(occ_range[0], occ_range[1], steps=dim, device=device) 100 | grid = torch.meshgrid(t, t, t) 101 | grid_3d = torch.cat( 102 | (grid[0][..., None], 103 | grid[1][..., None], 104 | grid[2][..., None]), dim=3 105 | ) 106 | 107 | if scale is not None: 108 | grid_3d = grid_3d * scale 109 | if transform is not None: 110 | R1 = transform[None, None, None, 0, :3] 111 | R2 = transform[None, None, None, 1, :3] 112 | R3 = transform[None, None, None, 2, :3] 113 | 114 | grid1 = (R1 * grid_3d).sum(-1, keepdim=True) 115 | grid2 = (R2 * grid_3d).sum(-1, keepdim=True) 116 | grid3 = (R3 * grid_3d).sum(-1, keepdim=True) 117 | grid_3d = torch.cat([grid1, grid2, grid3], dim=-1) 118 | 119 | trans = transform[None, None, None, :3, 3] 120 | grid_3d = grid_3d + trans 121 | 122 | return grid_3d 123 | 
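The helpers above (`occupancy_activation`, `occupancy_to_termination`, `render`) are composed per ray in `loss.py` to produce rendered depth, colour, and opacity. The following minimal sketch (not part of the repository) illustrates that flow; the tensor shapes and random inputs are assumptions for demonstration, and it assumes the repo root is on `PYTHONPATH`:

```python
# Illustrative sketch only: dummy shapes/inputs, mirroring the calls made in loss.py.
import torch
import render_rays

num_obj, num_rays, pts_per_ray = 2, 8, 10                    # assumed batch sizes
alpha = torch.randn(num_obj, num_rays, pts_per_ray)          # raw network output per sample point
color = torch.rand(num_obj, num_rays, pts_per_ray, 3)        # per-sample RGB predictions
z_vals = torch.linspace(0.1, 3.0, pts_per_ray).expand(num_obj, num_rays, pts_per_ray)

occupancy = render_rays.occupancy_activation(alpha)                        # sigmoid (UniSurf-style)
termination = render_rays.occupancy_to_termination(occupancy, is_batch=True)
depth = render_rays.render(termination, z_vals)                            # [num_obj, num_rays]
rgb = render_rays.render(termination[..., None], color, dim=-2)            # [num_obj, num_rays, 3]
opacity = termination.sum(dim=-1)                                          # used by the opacity loss
print(depth.shape, rgb.shape, opacity.shape)
```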
-------------------------------------------------------------------------------- /metric/eval_3D_obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import trimesh 4 | from metrics import accuracy, completion, completion_ratio 5 | import os 6 | import json 7 | 8 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000): 9 | """ 10 | 3D reconstruction metric. 11 | """ 12 | metrics = [[] for _ in range(4)] 13 | transform, extents = trimesh.bounds.oriented_bounds(mesh_gt) 14 | extents = extents / 0.9 # enlarge 0.9 15 | box = trimesh.creation.box(extents=extents, transform=np.linalg.inv(transform)) 16 | mesh_rec = mesh_rec.slice_plane(box.facets_origin, -box.facets_normal) 17 | if mesh_rec.vertices.shape[0] == 0: 18 | print("no mesh found") 19 | return 20 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N) 21 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) 22 | 23 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N) 24 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 25 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 26 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 27 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05) 28 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01) 29 | 30 | # accuracy_rec *= 100 # convert to cm 31 | # completion_rec *= 100 # convert to cm 32 | # completion_ratio_rec *= 100 # convert to % 33 | # print('accuracy: ', accuracy_rec) 34 | # print('completion: ', completion_rec) 35 | # print('completion ratio: ', completion_ratio_rec) 36 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1) 37 | metrics[0].append(accuracy_rec) 38 | metrics[1].append(completion_rec) 39 | metrics[2].append(completion_ratio_rec_1) 40 | metrics[3].append(completion_ratio_rec) 41 | return metrics 42 | 43 | def get_gt_bg_mesh(gt_dir, background_cls_list): 44 | with open(os.path.join(gt_dir, "info_semantic.json")) as f: 45 | label_obj_list = json.load(f)["objects"] 46 | 47 | bg_meshes = [] 48 | for obj in label_obj_list: 49 | if int(obj["class_id"]) in background_cls_list: 50 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(int(obj["id"])) + ".ply") 51 | obj_mesh = trimesh.load(obj_file) 52 | bg_meshes.append(obj_mesh) 53 | 54 | bg_mesh = trimesh.util.concatenate(bg_meshes) 55 | return bg_mesh 56 | 57 | def get_obj_ids(obj_dir): 58 | files = os.listdir(obj_dir) 59 | obj_ids = [] 60 | for f in files: 61 | obj_id = f.split("obj")[1][:-1] 62 | if obj_id == '': 63 | continue 64 | obj_ids.append(int(obj_id)) 65 | return obj_ids 66 | 67 | 68 | if __name__ == "__main__": 69 | background_cls_list = [5, 12, 30, 31, 40, 60, 92, 93, 95, 97, 98, 79] 70 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 71 | data_dir = "/home/xin/data/vmap/" 72 | log_dir = "../logs/iMAP/" 73 | # log_dir = "../logs/vMAP/" 74 | 75 | for exp in tqdm(exp_name): 76 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat") 77 | exp_dir = os.path.join(log_dir, exp) 78 | mesh_dir = os.path.join(exp_dir, "scene_mesh") 79 | output_path = os.path.join(exp_dir, "eval_mesh") 80 | os.makedirs(output_path, exist_ok=True) 81 | metrics_3D = [[] for _ in range(4)] 82 | 83 | # get obj ids 84 | # obj_ids = np.loadtxt() # todo use a pre-defined obj list or use vMAP results 85 | obj_ids = get_obj_ids(mesh_dir.replace("iMAP", "vMAP")) 86 | for obj_id in tqdm(obj_ids): 
87 | if obj_id == 0: # for bg 88 | N = 200000 89 | mesh_gt = get_gt_bg_mesh(gt_dir, background_cls_list) 90 | else: # for obj 91 | N = 10000 92 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(obj_id) + ".ply") 93 | mesh_gt = trimesh.load(obj_file) 94 | 95 | if "vMAP" in exp_dir: 96 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj"+str(obj_id)+".obj") 97 | elif "iMAP" in exp_dir: 98 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj") 99 | else: 100 | print("Not Implement") 101 | exit(-1) 102 | 103 | mesh_rec = trimesh.load(rec_meshfile) 104 | # mesh_rec.invert() # niceslam mesh face needs invert 105 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N) # for objs use 10k, for scene use 200k points 106 | if metrics is None: 107 | continue 108 | np.save(output_path + '/metric_obj{}.npy'.format(obj_id), np.array(metrics)) 109 | metrics_3D[0].append(metrics[0]) # acc 110 | metrics_3D[1].append(metrics[1]) # comp 111 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm 112 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm 113 | metrics_3D = np.array(metrics_3D) 114 | np.save(output_path + '/metrics_3D_obj.npy', metrics_3D) 115 | print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1)) 116 | print("-----------------------------------------") 117 | print("finish exp ", exp) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [comment]: <> (# vMAP: Vectorised Object Mapping for Neural Field SLAM) 2 | 3 | 4 | 5 |
# vMAP: Vectorised Object Mapping for Neural Field SLAM

Xin Kong · Shikun Liu · Marwan Taher · Andrew Davison

Paper | Video | Project Page

![Logo](media/teaser.png)

vMAP builds an object-level map from a real-time RGB-D input stream. Each object is represented by a separate MLP neural field model, all optimised in parallel via vectorised training.
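The "vectorised training" mentioned above is what distinguishes vMAP from running one iMAP-style model per object: the per-object MLPs are stacked with functorch and optimised through a single batched forward/backward pass (see `update_vmap` in `utils.py`, which calls `combine_state_for_ensemble`). The sketch below is only a minimal illustration of that pattern; `TinyMLP`, the tensor shapes and the optimiser settings are placeholders and not the repository's actual object model or training loop.

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

# Stand-in per-object network; the real object field is defined in model.py.
class TinyMLP(nn.Module):
    def __init__(self, in_dim=3, hidden=32, out_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, out_dim))

    def forward(self, x):
        return self.net(x)

num_objects, n_samples = 8, 128
models = [TinyMLP() for _ in range(num_objects)]

# Stack the parameters/buffers of all object models into batched tensors.
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]
optimiser = torch.optim.Adam(params, lr=1e-3)

# One batch of sample points (and dummy targets) per object.
x = torch.rand(num_objects, n_samples, 3)
target = torch.rand(num_objects, n_samples, 4)

# A single vmap-ed call evaluates every object MLP in parallel,
# and one backward pass updates all of them at once.
pred = vmap(fmodel)(params, buffers, x)
loss = ((pred - target) ** 2).mean()
loss.backward()
optimiser.step()
```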
31 | 32 | We provide the implementation of the following neural-field SLAM frameworks: 33 | - **vMAP** [Official Implementation] 34 | - **iMAP** [Simplified and Improved Re-Implementation, with depth guided sampling] 35 | 36 | 37 | 38 | ## Install 39 | First, let's start with a virtual environment with the required dependencies. 40 | ```bash 41 | conda env create -f environment.yml 42 | ``` 43 | 44 | ## Dataset 45 | Please download the following datasets to reproduce our results. 46 | 47 | * [Replica Demo](https://huggingface.co/datasets/kxic/vMAP/resolve/main/demo_replica_room_0.zip) - Replica Room 0 only for faster experimentation. 48 | * [Replica](https://huggingface.co/datasets/kxic/vMAP/resolve/main/vmap.zip) - All Pre-generated Replica sequences. For Replica data generation, please refer to directory `data_generation`. 49 | * [ScanNet](https://github.com/ScanNet/ScanNet) - Official ScanNet sequences. 50 | Each dataset contains a sequence of RGB-D images, as well as their corresponding camera poses, and object instance labels. 51 | To extract data from ScanNet .sens files, run 52 | ```bash 53 | conda activate py2 54 | python2 reader.py --filename ~/data/ScanNet/scannet/scans/scene0024_00/scene0024_00.sens --output_path ~/data/ScanNet/objnerf/ --export_depth_images --export_color_images --export_poses --export_intrinsics 55 | ``` 56 | 57 | ## Config 58 | 59 | Then update the config files in `configs/.json` with your dataset paths, as well as other training hyper-parameters. 60 | ```json 61 | "dataset": { 62 | "path": "path/to/ims/folder/", 63 | } 64 | ``` 65 | 66 | ## Running vMAP / iMAP 67 | The following commands will run vMAP / iMAP in a single-thread setting. 68 | 69 | #### vMAP 70 | ```bash 71 | python ./train.py --config ./configs/Replica/config_replica_room0_vMAP.json --logdir ./logs/vMAP/room0 --save_ckpt True 72 | ``` 73 | #### iMAP 74 | ```bash 75 | python ./train.py --config ./configs/Replica/config_replica_room0_iMAP.json --logdir ./logs/iMAP/room0 --save_ckpt True 76 | ``` 77 | 78 | [comment]: <> (#### Multi thread demo) 79 | 80 | [comment]: <> (```bash) 81 | 82 | [comment]: <> (./parallel_train.py --config "config_file.json" --logdir ./logs) 83 | 84 | [comment]: <> (```) 85 | 86 | ## Evaluation 87 | To evaluate the quality of reconstructed scenes, we provide two different methods, 88 | #### 3D Scene-level Evaluation 89 | The same metrics following the original iMAP, to compare with GT scene meshes by **Accuracy**, **Completion** and **Completion Ratio**. 90 | ```bash 91 | python ./metric/eval_3D_scene.py 92 | ``` 93 | #### 3D Object-level Evaluation 94 | We also provide the object-level metrics by computing the same metrics but averaging across all objects in a scene. 95 | ```bash 96 | python ./metric/eval_3D_obj.py 97 | ``` 98 | 99 | [comment]: <> (### Novel View Synthesis) 100 | 101 | [comment]: <> (##### 2D Novel View Eval) 102 | 103 | [comment]: <> (We rendered a new trajectory in each scene and randomly choose novel view pose from it, evaluating 2D rendering performance) 104 | 105 | [comment]: <> (```bash) 106 | 107 | [comment]: <> (./metric/eval_2D_view.py) 108 | 109 | [comment]: <> (```) 110 | 111 | ## Results 112 | We provide raw results, including 3D meshes, 2D novel view rendering, and evaluated metrics of vMAP and iMAP* for easier comparison. 
113 | 114 | * [Replica](https://huggingface.co/datasets/kxic/vMAP/resolve/main/vMAP_Replica_Results.zip) 115 | 116 | ## Acknowledgement 117 | We would like thank the following open-source repositories that we have build upon for the implementation of this work: [NICE-SLAM](https://github.com/cvg/nice-slam), and [functorch](https://github.com/pytorch/functorch). 118 | 119 | ## Citation 120 | If you found this code/work to be useful in your own research, please considering citing the following: 121 | ```bibtex 122 | @article{kong2023vmap, 123 | title={vMAP: Vectorised Object Mapping for Neural Field SLAM}, 124 | author={Kong, Xin and Liu, Shikun and Taher, Marwan and Davison, Andrew J}, 125 | journal={arXiv preprint arXiv:2302.01838}, 126 | year={2023} 127 | } 128 | ``` 129 | 130 | ```bibtex 131 | @inproceedings{sucar2021imap, 132 | title={iMAP: Implicit mapping and positioning in real-time}, 133 | author={Sucar, Edgar and Liu, Shikun and Ortiz, Joseph and Davison, Andrew J}, 134 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 135 | pages={6229--6238}, 136 | year={2021} 137 | } 138 | ``` 139 | 140 | -------------------------------------------------------------------------------- /data_generation/settings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import habitat_sim 6 | import habitat_sim.agent 7 | 8 | default_sim_settings = { 9 | # settings shared by example.py and benchmark.py 10 | "max_frames": 1000, 11 | "width": 640, 12 | "height": 480, 13 | "default_agent": 0, 14 | "sensor_height": 1.5, 15 | "hfov": 90, 16 | "color_sensor": True, # RGB sensor (default: ON) 17 | "semantic_sensor": False, # semantic sensor (default: OFF) 18 | "depth_sensor": False, # depth sensor (default: OFF) 19 | "ortho_rgba_sensor": False, # Orthographic RGB sensor (default: OFF) 20 | "ortho_depth_sensor": False, # Orthographic depth sensor (default: OFF) 21 | "ortho_semantic_sensor": False, # Orthographic semantic sensor (default: OFF) 22 | "fisheye_rgba_sensor": False, 23 | "fisheye_depth_sensor": False, 24 | "fisheye_semantic_sensor": False, 25 | "equirect_rgba_sensor": False, 26 | "equirect_depth_sensor": False, 27 | "equirect_semantic_sensor": False, 28 | "seed": 1, 29 | "silent": False, # do not print log info (default: OFF) 30 | # settings exclusive to example.py 31 | "save_png": False, # save the pngs to disk (default: OFF) 32 | "print_semantic_scene": False, 33 | "print_semantic_mask_stats": False, 34 | "compute_shortest_path": False, 35 | "compute_action_shortest_path": False, 36 | "scene": "data/scene_datasets/habitat-test-scenes/skokloster-castle.glb", 37 | "test_scene_data_url": "http://dl.fbaipublicfiles.com/habitat/habitat-test-scenes.zip", 38 | "goal_position": [5.047, 0.199, 11.145], 39 | "enable_physics": False, 40 | "enable_gfx_replay_save": False, 41 | "physics_config_file": "./data/default.physics_config.json", 42 | "num_objects": 10, 43 | "test_object_index": 0, 44 | "frustum_culling": True, 45 | } 46 | 47 | # build SimulatorConfiguration 48 | def make_cfg(settings): 49 | sim_cfg = habitat_sim.SimulatorConfiguration() 50 | if "frustum_culling" in settings: 51 | sim_cfg.frustum_culling = settings["frustum_culling"] 52 | else: 53 | sim_cfg.frustum_culling = False 54 | if "enable_physics" in settings: 55 | sim_cfg.enable_physics 
= settings["enable_physics"] 56 | if "physics_config_file" in settings: 57 | sim_cfg.physics_config_file = settings["physics_config_file"] 58 | # if not settings["silent"]: 59 | # print("sim_cfg.physics_config_file = " + sim_cfg.physics_config_file) 60 | if "scene_light_setup" in settings: 61 | sim_cfg.scene_light_setup = settings["scene_light_setup"] 62 | sim_cfg.gpu_device_id = 0 63 | if not hasattr(sim_cfg, "scene_id"): 64 | raise RuntimeError( 65 | "Error: Please upgrade habitat-sim. SimulatorConfig API version mismatch" 66 | ) 67 | sim_cfg.scene_id = settings["scene_file"] 68 | 69 | # define default sensor parameters (see src/esp/Sensor/Sensor.h) 70 | sensor_specs = [] 71 | 72 | def create_camera_spec(**kw_args): 73 | camera_sensor_spec = habitat_sim.CameraSensorSpec() 74 | camera_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 75 | camera_sensor_spec.resolution = [settings["height"], settings["width"]] 76 | camera_sensor_spec.position = [0, settings["sensor_height"], 0] 77 | for k in kw_args: 78 | setattr(camera_sensor_spec, k, kw_args[k]) 79 | return camera_sensor_spec 80 | 81 | if settings["color_sensor"]: 82 | color_sensor_spec = create_camera_spec( 83 | uuid="color_sensor", 84 | # hfov=settings["hfov"], 85 | sensor_type=habitat_sim.SensorType.COLOR, 86 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 87 | ) 88 | sensor_specs.append(color_sensor_spec) 89 | 90 | if settings["depth_sensor"]: 91 | depth_sensor_spec = create_camera_spec( 92 | uuid="depth_sensor", 93 | # hfov=settings["hfov"], 94 | sensor_type=habitat_sim.SensorType.DEPTH, 95 | channels=1, 96 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 97 | ) 98 | sensor_specs.append(depth_sensor_spec) 99 | 100 | if settings["semantic_sensor"]: 101 | semantic_sensor_spec = create_camera_spec( 102 | uuid="semantic_sensor", 103 | # hfov=settings["hfov"], 104 | sensor_type=habitat_sim.SensorType.SEMANTIC, 105 | channels=1, 106 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 107 | ) 108 | sensor_specs.append(semantic_sensor_spec) 109 | 110 | # if settings["ortho_rgba_sensor"]: 111 | # ortho_rgba_sensor_spec = create_camera_spec( 112 | # uuid="ortho_rgba_sensor", 113 | # sensor_type=habitat_sim.SensorType.COLOR, 114 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 115 | # ) 116 | # sensor_specs.append(ortho_rgba_sensor_spec) 117 | # 118 | # if settings["ortho_depth_sensor"]: 119 | # ortho_depth_sensor_spec = create_camera_spec( 120 | # uuid="ortho_depth_sensor", 121 | # sensor_type=habitat_sim.SensorType.DEPTH, 122 | # channels=1, 123 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 124 | # ) 125 | # sensor_specs.append(ortho_depth_sensor_spec) 126 | # 127 | # if settings["ortho_semantic_sensor"]: 128 | # ortho_semantic_sensor_spec = create_camera_spec( 129 | # uuid="ortho_semantic_sensor", 130 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 131 | # channels=1, 132 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 133 | # ) 134 | # sensor_specs.append(ortho_semantic_sensor_spec) 135 | 136 | # TODO Figure out how to implement copying of specs 137 | def create_fisheye_spec(**kw_args): 138 | fisheye_sensor_spec = habitat_sim.FisheyeSensorDoubleSphereSpec() 139 | fisheye_sensor_spec.uuid = "fisheye_sensor" 140 | fisheye_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 141 | fisheye_sensor_spec.sensor_model_type = ( 142 | habitat_sim.FisheyeSensorModelType.DOUBLE_SPHERE 143 | ) 144 | 145 | # The default value (alpha, xi) is set to match the lens "GoPro" found in Table 3 of this paper: 146 
| # Vladyslav Usenko, Nikolaus Demmel and Daniel Cremers: The Double Sphere 147 | # Camera Model, The International Conference on 3D Vision (3DV), 2018 148 | # You can find the intrinsic parameters for the other lenses in the same table as well. 149 | fisheye_sensor_spec.xi = -0.27 150 | fisheye_sensor_spec.alpha = 0.57 151 | fisheye_sensor_spec.focal_length = [364.84, 364.86] 152 | 153 | fisheye_sensor_spec.resolution = [settings["height"], settings["width"]] 154 | # The default principal_point_offset is the middle of the image 155 | fisheye_sensor_spec.principal_point_offset = None 156 | # default: fisheye_sensor_spec.principal_point_offset = [i/2 for i in fisheye_sensor_spec.resolution] 157 | fisheye_sensor_spec.position = [0, settings["sensor_height"], 0] 158 | for k in kw_args: 159 | setattr(fisheye_sensor_spec, k, kw_args[k]) 160 | return fisheye_sensor_spec 161 | 162 | # if settings["fisheye_rgba_sensor"]: 163 | # fisheye_rgba_sensor_spec = create_fisheye_spec(uuid="fisheye_rgba_sensor") 164 | # sensor_specs.append(fisheye_rgba_sensor_spec) 165 | # if settings["fisheye_depth_sensor"]: 166 | # fisheye_depth_sensor_spec = create_fisheye_spec( 167 | # uuid="fisheye_depth_sensor", 168 | # sensor_type=habitat_sim.SensorType.DEPTH, 169 | # channels=1, 170 | # ) 171 | # sensor_specs.append(fisheye_depth_sensor_spec) 172 | # if settings["fisheye_semantic_sensor"]: 173 | # fisheye_semantic_sensor_spec = create_fisheye_spec( 174 | # uuid="fisheye_semantic_sensor", 175 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 176 | # channels=1, 177 | # ) 178 | # sensor_specs.append(fisheye_semantic_sensor_spec) 179 | 180 | def create_equirect_spec(**kw_args): 181 | equirect_sensor_spec = habitat_sim.EquirectangularSensorSpec() 182 | equirect_sensor_spec.uuid = "equirect_rgba_sensor" 183 | equirect_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 184 | equirect_sensor_spec.resolution = [settings["height"], settings["width"]] 185 | equirect_sensor_spec.position = [0, settings["sensor_height"], 0] 186 | for k in kw_args: 187 | setattr(equirect_sensor_spec, k, kw_args[k]) 188 | return equirect_sensor_spec 189 | 190 | # if settings["equirect_rgba_sensor"]: 191 | # equirect_rgba_sensor_spec = create_equirect_spec(uuid="equirect_rgba_sensor") 192 | # sensor_specs.append(equirect_rgba_sensor_spec) 193 | # 194 | # if settings["equirect_depth_sensor"]: 195 | # equirect_depth_sensor_spec = create_equirect_spec( 196 | # uuid="equirect_depth_sensor", 197 | # sensor_type=habitat_sim.SensorType.DEPTH, 198 | # channels=1, 199 | # ) 200 | # sensor_specs.append(equirect_depth_sensor_spec) 201 | # 202 | # if settings["equirect_semantic_sensor"]: 203 | # equirect_semantic_sensor_spec = create_equirect_spec( 204 | # uuid="equirect_semantic_sensor", 205 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 206 | # channels=1, 207 | # ) 208 | # sensor_specs.append(equirect_semantic_sensor_spec) 209 | 210 | # create agent specifications 211 | agent_cfg = habitat_sim.agent.AgentConfiguration() 212 | agent_cfg.sensor_specifications = sensor_specs 213 | agent_cfg.action_space = { 214 | "move_forward": habitat_sim.agent.ActionSpec( 215 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.25) 216 | ), 217 | "turn_left": habitat_sim.agent.ActionSpec( 218 | "turn_left", habitat_sim.agent.ActuationSpec(amount=10.0) 219 | ), 220 | "turn_right": habitat_sim.agent.ActionSpec( 221 | "turn_right", habitat_sim.agent.ActuationSpec(amount=10.0) 222 | ), 223 | } 224 | 225 | # override action space to no-op to test physics 226 | if 
sim_cfg.enable_physics: 227 | agent_cfg.action_space = { 228 | "move_forward": habitat_sim.agent.ActionSpec( 229 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.0) 230 | ) 231 | } 232 | 233 | return habitat_sim.Configuration(sim_cfg, [agent_cfg]) 234 | -------------------------------------------------------------------------------- /data_generation/habitat_renderer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os, sys, argparse 3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | import cv2 6 | import logging 7 | import habitat_sim as hs 8 | import numpy as np 9 | import quaternion 10 | import yaml 11 | import json 12 | from typing import Any, Dict, List, Tuple, Union 13 | from imgviz import label_colormap 14 | from PIL import Image 15 | import matplotlib.pyplot as plt 16 | import transformation 17 | import imgviz 18 | from datetime import datetime 19 | import time 20 | from settings import make_cfg 21 | 22 | # Custom type definitions 23 | Config = Dict[str, Any] 24 | Observation = hs.sensor.Observation 25 | Sim = hs.Simulator 26 | 27 | def init_habitat(config) : 28 | """Initialize the Habitat simulator with sensors and scene file""" 29 | _cfg = make_cfg(config) 30 | sim = Sim(_cfg) 31 | sim_cfg = hs.SimulatorConfiguration() 32 | sim_cfg.gpu_device_id = 0 33 | # Note: all sensors must have the same resolution 34 | camera_resolution = [config["height"], config["width"]] 35 | sensors = { 36 | "color_sensor": { 37 | "sensor_type": hs.SensorType.COLOR, 38 | "resolution": camera_resolution, 39 | "position": [0.0, config["sensor_height"], 0.0], 40 | }, 41 | "depth_sensor": { 42 | "sensor_type": hs.SensorType.DEPTH, 43 | "resolution": camera_resolution, 44 | "position": [0.0, config["sensor_height"], 0.0], 45 | }, 46 | "semantic_sensor": { 47 | "sensor_type": hs.SensorType.SEMANTIC, 48 | "resolution": camera_resolution, 49 | "position": [0.0, config["sensor_height"], 0.0], 50 | }, 51 | } 52 | 53 | sensor_specs = [] 54 | for sensor_uuid, sensor_params in sensors.items(): 55 | if config[sensor_uuid]: 56 | sensor_spec = hs.SensorSpec() 57 | sensor_spec.uuid = sensor_uuid 58 | sensor_spec.sensor_type = sensor_params["sensor_type"] 59 | sensor_spec.resolution = sensor_params["resolution"] 60 | sensor_spec.position = sensor_params["position"] 61 | 62 | sensor_specs.append(sensor_spec) 63 | 64 | # Here you can specify the amount of displacement in a forward action and the turn angle 65 | agent_cfg = hs.agent.AgentConfiguration() 66 | agent_cfg.sensor_specifications = sensor_specs 67 | agent_cfg.action_space = { 68 | "move_forward": hs.agent.ActionSpec( 69 | "move_forward", hs.agent.ActuationSpec(amount=0.25) 70 | ), 71 | "turn_left": hs.agent.ActionSpec( 72 | "turn_left", hs.agent.ActuationSpec(amount=30.0) 73 | ), 74 | "turn_right": hs.agent.ActionSpec( 75 | "turn_right", hs.agent.ActuationSpec(amount=30.0) 76 | ), 77 | } 78 | 79 | hs_cfg = hs.Configuration(sim_cfg, [agent_cfg]) 80 | # sim = Sim(hs_cfg) 81 | 82 | if config["enable_semantics"]: # extract instance to class mapping function 83 | assert os.path.exists(config["instance2class_mapping"]) 84 | with open(config["instance2class_mapping"], "r") as f: 85 | annotations = json.load(f) 86 | instance_id_to_semantic_label_id = np.array(annotations["id_to_label"]) 87 | num_classes = len(annotations["classes"]) 88 | label_colour_map = label_colormap() 89 | config["instance2semantic"] = instance_id_to_semantic_label_id 90 
| config["classes"] = annotations["classes"] 91 | config["objects"] = annotations["objects"] 92 | 93 | config["num_classes"] = num_classes 94 | config["label_colour_map"] = label_colormap() 95 | config["instance_colour_map"] = label_colormap(500) 96 | 97 | 98 | # add camera intrinsic 99 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov']) * np.pi / 180. 100 | # https://aihabitat.org/docs/habitat-api/view-transform-warp.html 101 | # config['K'] = K 102 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]], 103 | # dtype=np.float64) 104 | 105 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov']) 106 | # fx = 1.0 / np.tan(hfov / 2.0) 107 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]], 108 | # dtype=np.float64) 109 | 110 | # Get the intrinsic camera parameters 111 | 112 | 113 | logging.info('Habitat simulator initialized') 114 | 115 | return sim, hs_cfg, config 116 | 117 | def save_renders(save_path, observation, enable_semantic, suffix=""): 118 | save_path_rgb = os.path.join(save_path, "rgb") 119 | save_path_depth = os.path.join(save_path, "depth") 120 | save_path_sem_class = os.path.join(save_path, "semantic_class") 121 | save_path_sem_instance = os.path.join(save_path, "semantic_instance") 122 | 123 | if not os.path.exists(save_path_rgb): 124 | os.makedirs(save_path_rgb) 125 | if not os.path.exists(save_path_depth): 126 | os.makedirs(save_path_depth) 127 | if not os.path.exists(save_path_sem_class): 128 | os.makedirs(save_path_sem_class) 129 | if not os.path.exists(save_path_sem_instance): 130 | os.makedirs(save_path_sem_instance) 131 | 132 | cv2.imwrite(os.path.join(save_path_rgb, "rgb{}.png".format(suffix)), observation["color_sensor"][:,:,::-1]) # change from RGB to BGR for opencv write 133 | cv2.imwrite(os.path.join(save_path_depth, "depth{}.png".format(suffix)), observation["depth_sensor_mm"]) 134 | 135 | if enable_semantic: 136 | cv2.imwrite(os.path.join(save_path_sem_class, "semantic_class{}.png".format(suffix)), observation["semantic_class"]) 137 | cv2.imwrite(os.path.join(save_path_sem_class, "vis_sem_class{}.png".format(suffix)), observation["vis_sem_class"][:,:,::-1]) 138 | 139 | cv2.imwrite(os.path.join(save_path_sem_instance, "semantic_instance{}.png".format(suffix)), observation["semantic_instance"]) 140 | cv2.imwrite(os.path.join(save_path_sem_instance, "vis_sem_instance{}.png".format(suffix)), observation["vis_sem_instance"][:,:,::-1]) 141 | 142 | 143 | def render(sim, config): 144 | """Return the sensor observations and ground truth pose""" 145 | observation = sim.get_sensor_observations() 146 | 147 | # process rgb imagem change from RGBA to RGB 148 | observation['color_sensor'] = observation['color_sensor'][..., 0:3] 149 | rgb_img = observation['color_sensor'] 150 | 151 | # process depth 152 | depth_mm = (observation['depth_sensor'].copy()*1000).astype(np.uint16) # change meters to mm 153 | observation['depth_sensor_mm'] = depth_mm 154 | 155 | # process semantics 156 | if config['enable_semantics']: 157 | 158 | # Assuming the scene has no more than 65534 objects 159 | observation['semantic_instance'] = np.clip(observation['semantic_sensor'].astype(np.uint16), 0, 65535) 160 | # observation['semantic_instance'][observation['semantic_instance']==12]=0 # mask out certain instance 161 | # Convert instance IDs to class IDs 162 | 163 | 164 | # observation['semantic_classes'] = np.zeros(observation['semantic'].shape, dtype=np.uint8) 165 | # TODO make this conversion more efficient 166 | 
semantic_class = config["instance2semantic"][observation['semantic_instance']] 167 | semantic_class[semantic_class < 0] = 0 168 | 169 | vis_sem_class = config["label_colour_map"][semantic_class] 170 | vis_sem_instance = config["instance_colour_map"][observation['semantic_instance']] # may cause error when having more than 255 instances in the scene 171 | 172 | observation['semantic_class'] = semantic_class.astype(np.uint8) 173 | observation["vis_sem_class"] = vis_sem_class.astype(np.uint8) 174 | observation["vis_sem_instance"] = vis_sem_instance.astype(np.uint8) 175 | 176 | # del observation["semantic_sensor"] 177 | 178 | # Get the camera ground truth pose (T_HC) in the habitat frame from the 179 | # position and orientation 180 | t_HC = sim.get_agent(0).get_state().position 181 | q_HC = sim.get_agent(0).get_state().rotation 182 | T_HC = transformation.combine_pose(t_HC, q_HC) 183 | 184 | observation['T_HC'] = T_HC 185 | observation['T_WC'] = transformation.Thc_to_Twc(T_HC) 186 | 187 | return observation 188 | 189 | def set_agent_position(sim, pose): 190 | # Move the agent 191 | R = pose[:3, :3] 192 | orientation_quat = quaternion.from_rotation_matrix(R) 193 | t = pose[:3, 3] 194 | position = t 195 | 196 | orientation = [orientation_quat.x, orientation_quat.y, orientation_quat.z, orientation_quat.w] 197 | agent = sim.get_agent(0) 198 | agent_state = hs.agent.AgentState(position, orientation) 199 | # agent.set_state(agent_state, reset_sensors=False) 200 | agent.set_state(agent_state) 201 | 202 | def main(): 203 | parser = argparse.ArgumentParser(description='Render Colour, Depth, Semantic, Instance labeling from Habitat-Simultation.') 204 | parser.add_argument('--config_file', type=str, 205 | default="./data_generation/replica_render_config_vMAP.yaml", 206 | help='the path to custom config file.') 207 | args = parser.parse_args() 208 | 209 | """Initialize the config dict and Habitat simulator""" 210 | # Read YAML file 211 | with open(args.config_file, 'r') as f: 212 | config = yaml.safe_load(f) 213 | 214 | config["save_path"] = os.path.join(config["save_path"]) 215 | if not os.path.exists(config["save_path"]): 216 | os.makedirs(config["save_path"]) 217 | 218 | T_wc = np.loadtxt(config["pose_file"]).reshape(-1, 4, 4) 219 | Ts_cam2world = T_wc 220 | 221 | print("-----Initialise and Set Habitat-Sim-----") 222 | sim, hs_cfg, config = init_habitat(config) 223 | # Set agent state 224 | sim.initialize_agent(config["default_agent"]) 225 | 226 | """Set agent state""" 227 | print("-----Render Images from Habitat-Sim-----") 228 | with open(os.path.join(config["save_path"], 'render_config.yaml'), 'w') as outfile: 229 | yaml.dump(config, outfile, default_flow_style=False) 230 | start_time = time.time() 231 | total_render_num = Ts_cam2world.shape[0] 232 | for i in range(total_render_num): 233 | if i % 100 == 0 : 234 | print("Rendering Process: {}/{}".format(i, total_render_num)) 235 | set_agent_position(sim, transformation.Twc_to_Thc(Ts_cam2world[i])) 236 | 237 | # replica mode 238 | observation = render(sim, config) 239 | save_renders(config["save_path"], observation, config["enable_semantics"], suffix="_{}".format(i)) 240 | 241 | end_time = time.time() 242 | print("-----Finish Habitat Rendering, Showing Trajectories.-----") 243 | print("Average rendering time per image is {} seconds.".format((end_time-start_time)/Ts_cam2world.shape[0])) 244 | 245 | if __name__ == "__main__": 246 | main() 247 | 248 | 249 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | vMAP SOFTWARE 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College 6 | London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY 7 | ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING 8 | AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE. 9 | BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE 10 | TERMS OF THE AGREEMENT. 11 | 12 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 13 | 14 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully 15 | paid-up, royalty free, non-transferable, non-sub- licensable licence (the 16 | “Licence”) to use the elastic fusion source code, including any modification, 17 | part or derivative (the “Software”). 18 | 19 | Ownership and Licence. Your rights to use and download the Software onto your 20 | computer, and all other copies that You are authorised to make, are specified 21 | in this Agreement. However, we (or our licensors) retain all rights, including 22 | but not limited to all copyright and other intellectual property rights 23 | anywhere in the world, in the Software not expressly granted to You in this 24 | Agreement. 25 | 26 | 2. Permitted use of the Licence: 27 | 28 | (a) You may download and install the Software onto one computer or server for 29 | use in accordance with Clause 2(b) of this Agreement provided that You ensure 30 | that the Software is not accessible by other users unless they have themselves 31 | accepted the terms of this licence agreement. 32 | 33 | (b) You may use the Software solely for non-commercial, internal or academic 34 | research purposes and only in accordance with the terms of this Agreement. You 35 | may not use the Software for commercial purposes, including but not limited to 36 | (1) integration of all or part of the source code or the Software into a 37 | product for sale or licence by or on behalf of You to third parties or (2) use 38 | of the Software or any derivative of it for research to develop software 39 | products for sale or licence to a third party or (3) use of the Software or any 40 | derivative of it for research to develop non-software products for sale or 41 | licence to a third party, or (4) use of the Software to provide any service to 42 | an external organisation for which payment is received. 43 | 44 | Should You wish to use the Software for commercial purposes, You shall 45 | email researchcontracts.engineering@imperial.ac.uk . 46 | 47 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, 48 | provided that each copy is kept in your possession and provided You reproduce 49 | our copyright notice (set out in Schedule 1) on each copy. 50 | 51 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software 52 | and You may not transmit, transfer or sub-license this licence to use the 53 | Software or any of your rights or obligations under this Agreement to another 54 | party. 55 | 56 | (e) Identity of Licensee. The licence granted herein is personal to You. You 57 | shall not permit any third party to access, modify or otherwise use the 58 | Software nor shall You access modify or otherwise use the Software on behalf of 59 | any third party. 
If You wish to obtain a licence for mutiple users or a site 60 | licence for the Software please contact us 61 | at researchcontracts.engineering@imperial.ac.uk . 62 | 63 | (f) Publications and presentations. You may make public, results or data 64 | obtained from, dependent on or arising from research carried out using the 65 | Software, provided that any such presentation or publication identifies the 66 | Software as the source of the results or the data, including the Copyright 67 | Notice given in each element of the Software, and stating that the Software has 68 | been made available for use by You under licence from Imperial College London 69 | and You provide a copy of any such publication to Imperial College London. 70 | 71 | 3. Prohibited Uses. You may not, without written permission from us 72 | at researchcontracts.engineering@imperial.ac.uk : 73 | 74 | (a) Use, copy, modify, merge, or transfer copies of the Software or any 75 | documentation provided by us which relates to the Software except as provided 76 | in this Agreement; 77 | 78 | (b) Use any back-up or archival copies of the Software (or allow anyone else to 79 | use such copies) for any purpose other than to replace the original copy in the 80 | event it is destroyed or becomes defective; or 81 | 82 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner 83 | decode the Software for any reason. 84 | 85 | 4. Warranty Disclaimer 86 | 87 | (a) Disclaimer. The Software has been developed for research purposes only. You 88 | acknowledge that we are providing the Software to You under this licence 89 | agreement free of charge and on condition that the disclaimer set out below 90 | shall apply. We do not represent or warrant that the Software as to: (i) the 91 | quality, accuracy or reliability of the Software; (ii) the suitability of the 92 | Software for any particular use or for use under any specific conditions; and 93 | (iii) whether use of the Software will infringe third-party rights. 94 | 95 | You acknowledge that You have reviewed and evaluated the Software to determine 96 | that it meets your needs and that You assume all responsibility and liability 97 | for determining the suitability of the Software as fit for your particular 98 | purposes and requirements. Subject to Clause 4(b), we exclude and expressly 99 | disclaim all express and implied representations, warranties, conditions and 100 | terms not stated herein (including the implied conditions or warranties of 101 | satisfactory quality, merchantable quality, merchantability and fitness for 102 | purpose). 103 | 104 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or 105 | impose obligations upon us which cannot, in whole or in part, be excluded, 106 | restricted or modified or otherwise do not allow the exclusion of implied 107 | warranties, conditions or terms, in which case the above warranty disclaimer 108 | and exclusion will only apply to You to the extent permitted in the relevant 109 | jurisdiction and does not in any event exclude any implied warranties, 110 | conditions or terms which may not under applicable law be excluded. 111 | 112 | (c) Imperial College London disclaims all responsibility for the use which is 113 | made of the Software and any liability for the outcomes arising from using the 114 | Software. 115 | 116 | 5. 
Limitation of Liability 117 | 118 | (a) You acknowledge that we are providing the Software to You under this 119 | licence agreement free of charge and on condition that the limitation of 120 | liability set out below shall apply. Accordingly, subject to Clause 5(b), we 121 | exclude all liability whether in contract, tort, negligence or otherwise, in 122 | respect of the Software and/or any related documentation provided to You by us 123 | including, but not limited to, liability for loss or corruption of data, loss 124 | of contracts, loss of income, loss of profits, loss of cover and any 125 | consequential or indirect loss or damage of any kind arising out of or in 126 | connection with this licence agreement, however caused. This exclusion shall 127 | apply even if we have been advised of the possibility of such loss or damage. 128 | 129 | (b) You agree to indemnify Imperial College London and hold it harmless from 130 | and against any and all claims, damages and liabilities asserted by third 131 | parties (including claims for negligence) which arise directly or indirectly 132 | from the use of the Software or any derivative of it or the sale of any 133 | products based on the Software. You undertake to make no liability claim 134 | against any employee, student, agent or appointee of Imperial College London, 135 | in connection with this Licence or the Software. 136 | 137 | (c) Nothing in this Agreement shall have the effect of excluding or limiting 138 | our statutory liability. 139 | 140 | (d) Some jurisdictions do not allow these limitations or exclusions either 141 | wholly or in part, and, to that extent, they may not apply to you. Nothing in 142 | this licence agreement will affect your statutory rights or other relevant 143 | statutory provisions which cannot be excluded, restricted or modified, and its 144 | terms and conditions must be read and construed subject to any such statutory 145 | rights and/or provisions. 146 | 147 | 6. Confidentiality. You agree not to disclose any confidential information 148 | provided to You by us pursuant to this Agreement to any third party without our 149 | prior written consent. The obligations in this Clause 6 shall survive the 150 | termination of this Agreement for any reason. 151 | 152 | 7. Termination. 153 | 154 | (a) We may terminate this licence agreement and your right to use the Software 155 | at any time with immediate effect upon written notice to You. 156 | 157 | (b) This licence agreement and your right to use the Software automatically 158 | terminate if You: 159 | 160 | (i) fail to comply with any provisions of this Agreement; or 161 | 162 | (ii) destroy the copies of the Software in your possession, or voluntarily 163 | return the Software to us. 164 | 165 | (c) Upon termination You will destroy all copies of the Software. 166 | 167 | (d) Otherwise, the restrictions on your rights to use the Software will expire 168 | 10 (ten) years after first use of the Software under this licence agreement. 169 | 170 | 8. Miscellaneous Provisions. 171 | 172 | (a) This Agreement will be governed by and construed in accordance with the 173 | substantive laws of England and Wales whose courts shall have exclusive 174 | jurisdiction over all disputes which may arise between us. 175 | 176 | (b) This is the entire agreement between us relating to the Software, and 177 | supersedes any prior purchase order, communications, advertising or 178 | representations concerning the Software. 
179 | 180 | (c) No change or modification of this Agreement will be valid unless it is in 181 | writing, and is signed by us. 182 | 183 | (d) The unenforceability or invalidity of any part of this Agreement will not 184 | affect the enforceability or validity of the remaining parts. 185 | 186 | BSD Elements of the Software 187 | 188 | For BSD elements of the Software, the following terms shall apply: 189 | Copyright as indicated in the header of the individual element of the Software. 190 | All rights reserved. 191 | 192 | Redistribution and use in source and binary forms, with or without 193 | modification, are permitted provided that the following conditions are met: 194 | 195 | 1. Redistributions of source code must retain the above copyright notice, this 196 | list of conditions and the following disclaimer. 197 | 198 | 2. Redistributions in binary form must reproduce the above copyright notice, 199 | this list of conditions and the following disclaimer in the documentation 200 | and/or other materials provided with the distribution. 201 | 202 | 3. Neither the name of the copyright holder nor the names of its contributors 203 | may be used to endorse or promote products derived from this software without 204 | specific prior written permission. 205 | 206 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 207 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 208 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 209 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 210 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 211 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 212 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 213 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 214 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 215 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 216 | 217 | SCHEDULE 1 218 | 219 | The Software 220 | 221 | vMAP is an object-level mapping system with each object represented by a small MLP, that can be efficiently vectorised trained. 222 | It is based on the techniques described in the following publication: 223 | 224 | • Xin Kong, Shikun Liu, Marwan Taher, Andrew J. Davison. vMAP: Vectorised Object Mapping for Neural Field SLAM. ArXiv Preprint, 2023 225 | _________________________ 226 | 227 | Acknowledgments 228 | 229 | If you use the software, you should reference the following paper in any 230 | publication: 231 | 232 | • Xin Kong, Shikun Liu, Marwan Taher, Andrew J. Davison. vMAP: Vectorised Object Mapping for Neural Field SLAM. 
ArXiv Preprint, 2023 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import imgviz 2 | from torch.utils.data import Dataset, DataLoader 3 | import torch 4 | import numpy as np 5 | import cv2 6 | import os 7 | from utils import enlarge_bbox, get_bbox2d, get_bbox2d_batch, box_filter 8 | import glob 9 | from torchvision import transforms 10 | import image_transforms 11 | import open3d 12 | import time 13 | 14 | def next_live_data(track_to_map_IDT, inited): 15 | while True: 16 | if track_to_map_IDT.empty(): 17 | if inited: 18 | return None # no new frame, use kf buffer 19 | else: # blocking until get the first frame 20 | continue 21 | else: 22 | Buffer_data = track_to_map_IDT.get(block=False) 23 | break 24 | 25 | 26 | if Buffer_data is not None: 27 | image, depth, T, obj, bbox_dict, kf_id = Buffer_data 28 | del Buffer_data 29 | T_obj = torch.eye(4) 30 | sample = {"image": image, "depth": depth, "T": T, "T_obj": T_obj, 31 | "obj": obj, "bbox_dict": bbox_dict, "frame_id": kf_id} 32 | 33 | return sample 34 | else: 35 | print("getting nothing?") 36 | exit(-1) 37 | # return None 38 | 39 | def init_loader(cfg, multi_worker=True): 40 | if cfg.dataset_format == "Replica": 41 | dataset = Replica(cfg) 42 | elif cfg.dataset_format == "ScanNet": 43 | dataset = ScanNet(cfg) 44 | else: 45 | print("Dataset format {} not found".format(cfg.dataset_format)) 46 | exit(-1) 47 | 48 | # init dataloader 49 | if multi_worker: 50 | # multi worker loader 51 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, 52 | batch_sampler=None, num_workers=4, collate_fn=None, 53 | pin_memory=True, drop_last=False, timeout=0, 54 | worker_init_fn=None, generator=None, prefetch_factor=2, 55 | persistent_workers=True) 56 | else: 57 | # single worker loader 58 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, 59 | batch_sampler=None, num_workers=0) 60 | 61 | return dataloader 62 | 63 | class Replica(Dataset): 64 | def __init__(self, cfg): 65 | self.imap_mode = cfg.imap_mode 66 | self.root_dir = cfg.dataset_dir 67 | traj_file = os.path.join(self.root_dir, "traj_w_c.txt") 68 | self.Twc = np.loadtxt(traj_file, delimiter=" ").reshape([-1, 4, 4]) 69 | self.depth_transform = transforms.Compose( 70 | [image_transforms.DepthScale(cfg.depth_scale), 71 | image_transforms.DepthFilter(cfg.max_depth)]) 72 | 73 | # background semantic classes: undefined--1, undefined-0 beam-5 blinds-12 curtain-30 ceiling-31 floor-40 pillar-60 vent-92 wall-93 wall-plug-95 window-97 rug-98 74 | self.background_cls_list = [5,12,30,31,40,60,92,93,95,97,98,79] 75 | # Not sure: door-37 handrail-43 lamp-47 pipe-62 rack-66 shower-stall-73 stair-77 switch-79 wall-cabinet-94 picture-59 76 | self.bbox_scale = 0.2 # 1 #1.5 0.9== s=1/9, s=0.2 77 | 78 | def __len__(self): 79 | return len(os.listdir(os.path.join(self.root_dir, "depth"))) 80 | 81 | def __getitem__(self, idx): 82 | bbox_dict = {} 83 | rgb_file = os.path.join(self.root_dir, "rgb", "rgb_" + str(idx) + ".png") 84 | depth_file = os.path.join(self.root_dir, "depth", "depth_" + str(idx) + ".png") 85 | inst_file = os.path.join(self.root_dir, "semantic_instance", "semantic_instance_" + str(idx) + ".png") 86 | obj_file = os.path.join(self.root_dir, "semantic_class", "semantic_class_" + str(idx) + ".png") 87 | depth = cv2.imread(depth_file, -1).astype(np.float32).transpose(1,0) 88 | image = cv2.imread(rgb_file).astype(np.uint8) 89 | image 
= cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(1,0,2) 90 | obj = cv2.imread(obj_file, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) # uint16 -> int32 91 | inst = cv2.imread(inst_file, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) # uint16 -> int32 92 | 93 | bbox_scale = self.bbox_scale 94 | 95 | if self.imap_mode: 96 | obj = np.zeros_like(obj) 97 | else: 98 | obj_ = np.zeros_like(obj) 99 | inst_list = [] 100 | batch_masks = [] 101 | for inst_id in np.unique(inst): 102 | inst_mask = inst == inst_id 103 | # if np.sum(inst_mask) <= 2000: # too small 20 400 104 | # continue 105 | sem_cls = np.unique(obj[inst_mask]) # sem label, only interested obj 106 | assert sem_cls.shape[0] != 0 107 | if sem_cls in self.background_cls_list: 108 | continue 109 | obj_mask = inst == inst_id 110 | batch_masks.append(obj_mask) 111 | inst_list.append(inst_id) 112 | if len(batch_masks) > 0: 113 | batch_masks = torch.from_numpy(np.stack(batch_masks)) 114 | cmins, cmaxs, rmins, rmaxs = get_bbox2d_batch(batch_masks) 115 | 116 | for i in range(batch_masks.shape[0]): 117 | w = rmaxs[i] - rmins[i] 118 | h = cmaxs[i] - cmins[i] 119 | if w <= 10 or h <= 10: # too small todo 120 | continue 121 | bbox_enlarged = enlarge_bbox([rmins[i], cmins[i], rmaxs[i], cmaxs[i]], scale=bbox_scale, 122 | w=obj.shape[1], h=obj.shape[0]) 123 | # inst_list.append(inst_id) 124 | inst_id = inst_list[i] 125 | obj_[batch_masks[i]] = 1 126 | # bbox_dict.update({int(inst_id): torch.from_numpy(np.array(bbox_enlarged).reshape(-1))}) # batch format 127 | bbox_dict.update({inst_id: torch.from_numpy(np.array( 128 | [bbox_enlarged[1], bbox_enlarged[3], bbox_enlarged[0], bbox_enlarged[2]]))}) # bbox order 129 | 130 | inst[obj_ == 0] = 0 # for background 131 | obj = inst 132 | 133 | bbox_dict.update({0: torch.from_numpy(np.array([int(0), int(obj.shape[0]), 0, int(obj.shape[1])]))}) # bbox order 134 | 135 | T = self.Twc[idx] # could change to ORB-SLAM pose or else 136 | T_obj = np.eye(4) # obj pose, if dynamic 137 | sample = {"image": image, "depth": depth, "T": T, "T_obj": T_obj, 138 | "obj": obj, "bbox_dict": bbox_dict, "frame_id": idx} 139 | 140 | if image is None or depth is None: 141 | print(rgb_file) 142 | print(depth_file) 143 | raise ValueError 144 | 145 | if self.depth_transform: 146 | sample["depth"] = self.depth_transform(sample["depth"]) 147 | 148 | return sample 149 | 150 | class ScanNet(Dataset): 151 | def __init__(self, cfg): 152 | self.imap_mode = cfg.imap_mode 153 | self.root_dir = cfg.dataset_dir 154 | self.color_paths = sorted(glob.glob(os.path.join( 155 | self.root_dir, 'color', '*.jpg')), key=lambda x: int(os.path.basename(x)[:-4])) 156 | self.depth_paths = sorted(glob.glob(os.path.join( 157 | self.root_dir, 'depth', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) 158 | self.inst_paths = sorted(glob.glob(os.path.join( 159 | self.root_dir, 'instance-filt', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) # instance-filt 160 | self.sem_paths = sorted(glob.glob(os.path.join( 161 | self.root_dir, 'label-filt', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) # label-filt 162 | self.load_poses(os.path.join(self.root_dir, 'pose')) 163 | self.n_img = len(self.color_paths) 164 | self.depth_transform = transforms.Compose( 165 | [image_transforms.DepthScale(cfg.depth_scale), 166 | image_transforms.DepthFilter(cfg.max_depth)]) 167 | # self.rgb_transform = rgb_transform 168 | self.W = cfg.W 169 | self.H = cfg.H 170 | self.fx = cfg.fx 171 | self.fy = cfg.fy 172 | self.cx = cfg.cx 173 | self.cy = cfg.cy 174 
| self.edge = cfg.mw 175 | self.intrinsic_open3d = open3d.camera.PinholeCameraIntrinsic( 176 | width=self.W, 177 | height=self.H, 178 | fx=self.fx, 179 | fy=self.fy, 180 | cx=self.cx, 181 | cy=self.cy, 182 | ) 183 | 184 | self.min_pixels = 1500 185 | # from scannetv2-labels.combined.tsv 186 | #1-wall, 3-floor, 16-window, 41-ceiling, 232-light switch 0-unknown? 21-pillar 161-doorframe, shower walls-128, curtain-21, windowsill-141 187 | self.background_cls_list = [-1, 0, 1, 3, 16, 41, 232, 21, 161, 128, 21] 188 | self.bbox_scale = 0.2 189 | self.inst_dict = {} 190 | 191 | def load_poses(self, path): 192 | self.poses = [] 193 | pose_paths = sorted(glob.glob(os.path.join(path, '*.txt')), 194 | key=lambda x: int(os.path.basename(x)[:-4])) 195 | for pose_path in pose_paths: 196 | with open(pose_path, "r") as f: 197 | lines = f.readlines() 198 | ls = [] 199 | for line in lines: 200 | l = list(map(float, line.split(' '))) 201 | ls.append(l) 202 | c2w = np.array(ls).reshape(4, 4) 203 | self.poses.append(c2w) 204 | 205 | def __len__(self): 206 | return self.n_img 207 | 208 | def __getitem__(self, index): 209 | bbox_scale = self.bbox_scale 210 | color_path = self.color_paths[index] 211 | depth_path = self.depth_paths[index] 212 | inst_path = self.inst_paths[index] 213 | sem_path = self.sem_paths[index] 214 | color_data = cv2.imread(color_path).astype(np.uint8) 215 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 216 | depth_data = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32) 217 | depth_data = np.nan_to_num(depth_data, nan=0.) 218 | T = None 219 | if self.poses is not None: 220 | T = self.poses[index] 221 | if np.any(np.isinf(T)): 222 | if index + 1 == self.__len__(): 223 | print("pose inf!") 224 | return None 225 | return self.__getitem__(index + 1) 226 | 227 | H, W = depth_data.shape 228 | color_data = cv2.resize(color_data, (W, H), interpolation=cv2.INTER_LINEAR) 229 | if self.edge: 230 | edge = self.edge # crop image edge, there are invalid value on the edge of the color image 231 | color_data = color_data[edge:-edge, edge:-edge] 232 | depth_data = depth_data[edge:-edge, edge:-edge] 233 | if self.depth_transform: 234 | depth_data = self.depth_transform(depth_data) 235 | bbox_dict = {} 236 | if self.imap_mode: 237 | inst_data = np.zeros_like(depth_data).astype(np.int32) 238 | else: 239 | inst_data = cv2.imread(inst_path, cv2.IMREAD_UNCHANGED) 240 | inst_data = cv2.resize(inst_data, (W, H), interpolation=cv2.INTER_NEAREST).astype(np.int32) 241 | sem_data = cv2.imread(sem_path, cv2.IMREAD_UNCHANGED)#.astype(np.int32) 242 | sem_data = cv2.resize(sem_data, (W, H), interpolation=cv2.INTER_NEAREST) 243 | if self.edge: 244 | edge = self.edge 245 | inst_data = inst_data[edge:-edge, edge:-edge] 246 | sem_data = sem_data[edge:-edge, edge:-edge] 247 | inst_data += 1 # shift from 0->1 , 0 is for background 248 | 249 | # box filter 250 | track_start = time.time() 251 | masks = [] 252 | classes = [] 253 | # convert to list of arrays 254 | obj_ids = np.unique(inst_data) 255 | for obj_id in obj_ids: 256 | mask = inst_data == obj_id 257 | sem_cls = np.unique(sem_data[mask]) 258 | if sem_cls in self.background_cls_list: 259 | inst_data[mask] = 0 # set to background 260 | continue 261 | masks.append(mask) 262 | classes.append(obj_id) 263 | T_CW = np.linalg.inv(T) 264 | inst_data = box_filter(masks, classes, depth_data, self.inst_dict, self.intrinsic_open3d, T_CW, min_pixels=self.min_pixels) 265 | 266 | merged_obj_ids = np.unique(inst_data) 267 | for obj_id in merged_obj_ids: 268 | mask = 
inst_data == obj_id 269 | bbox2d = get_bbox2d(mask, bbox_scale=bbox_scale) 270 | if bbox2d is None: 271 | inst_data[mask] = 0 # set to bg 272 | else: 273 | min_x, min_y, max_x, max_y = bbox2d 274 | bbox_dict.update({int(obj_id): torch.from_numpy(np.array([min_x, max_x, min_y, max_y]).reshape(-1))}) # batch format 275 | bbox_time = time.time() 276 | # print("bbox time ", bbox_time - filter_time) 277 | cv2.imshow("inst", imgviz.label2rgb(inst_data)) 278 | cv2.waitKey(1) 279 | print("frame {} track time {}".format(index, bbox_time-track_start)) 280 | 281 | bbox_dict.update({0: torch.from_numpy(np.array([int(0), int(inst_data.shape[1]), 0, int(inst_data.shape[0])]))}) # bbox order 282 | # wrap data to frame dict 283 | T_obj = np.identity(4) 284 | sample = {"image": color_data.transpose(1,0,2), "depth": depth_data.transpose(1,0), "T": T, "T_obj": T_obj} 285 | if color_data is None or depth_data is None: 286 | print(color_path) 287 | print(depth_path) 288 | raise ValueError 289 | 290 | sample.update({"obj": inst_data.transpose(1,0)}) 291 | sample.update({"bbox_dict": bbox_dict}) 292 | return sample 293 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imgviz 3 | import numpy as np 4 | import torch 5 | from functorch import combine_state_for_ensemble 6 | import open3d 7 | import queue 8 | import copy 9 | import torch.utils.dlpack 10 | 11 | class BoundingBox(): 12 | def __init__(self): 13 | super(BoundingBox, self).__init__() 14 | self.extent = None 15 | self.R = None 16 | self.center = None 17 | self.points3d = None # (8,3) 18 | 19 | def bbox_open3d2bbox(bbox_o3d): 20 | bbox = BoundingBox() 21 | bbox.extent = bbox_o3d.extent 22 | bbox.R = bbox_o3d.R 23 | bbox.center = bbox_o3d.center 24 | return bbox 25 | 26 | def bbox_bbox2open3d(bbox): 27 | bbox_o3d = open3d.geometry.OrientedBoundingBox(bbox.center, bbox.R, bbox.extent) 28 | return bbox_o3d 29 | 30 | def update_vmap(models, optimiser): 31 | fmodel, params, buffers = combine_state_for_ensemble(models) 32 | [p.requires_grad_() for p in params] 33 | optimiser.add_param_group({"params": params}) # imap b l 34 | return (fmodel, params, buffers) 35 | 36 | def enlarge_bbox(bbox, scale, w, h): 37 | assert scale >= 0 38 | # print(bbox) 39 | min_x, min_y, max_x, max_y = bbox 40 | margin_x = int(0.5 * scale * (max_x - min_x)) 41 | margin_y = int(0.5 * scale * (max_y - min_y)) 42 | if margin_y == 0 or margin_x == 0: 43 | return None 44 | # assert margin_x != 0 45 | # assert margin_y != 0 46 | min_x -= margin_x 47 | max_x += margin_x 48 | min_y -= margin_y 49 | max_y += margin_y 50 | 51 | min_x = np.clip(min_x, 0, w-1) 52 | min_y = np.clip(min_y, 0, h-1) 53 | max_x = np.clip(max_x, 0, w-1) 54 | max_y = np.clip(max_y, 0, h-1) 55 | 56 | bbox_enlarged = [int(min_x), int(min_y), int(max_x), int(max_y)] 57 | return bbox_enlarged 58 | 59 | def get_bbox2d(obj_mask, bbox_scale=1.0): 60 | contours, hierarchy = cv2.findContours(obj_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[ 61 | -2:] 62 | # # Find the index of the largest contour 63 | # areas = [cv2.contourArea(c) for c in contours] 64 | # max_index = np.argmax(areas) 65 | # cnt = contours[max_index] 66 | # Concatenate all contours 67 | if len(contours) == 0: 68 | return None 69 | cnt = np.concatenate(contours) 70 | x, y, w, h = cv2.boundingRect(cnt) # todo if multiple contours, choose the outmost one? 
71 | # x, y, w, h = cv2.boundingRect(contours) 72 | bbox_enlarged = enlarge_bbox([x, y, x + w, y + h], scale=bbox_scale, w=obj_mask.shape[1], h=obj_mask.shape[0]) 73 | return bbox_enlarged 74 | 75 | def get_bbox2d_batch(img): 76 | b,h,w = img.shape[:3] 77 | rows = torch.any(img, axis=2) 78 | cols = torch.any(img, axis=1) 79 | rmins = torch.argmax(rows.float(), dim=1) 80 | rmaxs = h - torch.argmax(rows.float().flip(dims=[1]), dim=1) 81 | cmins = torch.argmax(cols.float(), dim=1) 82 | cmaxs = w - torch.argmax(cols.float().flip(dims=[1]), dim=1) 83 | 84 | return rmins, rmaxs, cmins, cmaxs 85 | 86 | def get_latest_queue(q): 87 | message = None 88 | while(True): 89 | try: 90 | message_latest = q.get(block=False) 91 | if message is not None: 92 | del message 93 | message = message_latest 94 | 95 | except queue.Empty: 96 | break 97 | 98 | return message 99 | 100 | # for association/tracking 101 | class InstData: 102 | def __init__(self): 103 | super(InstData, self).__init__() 104 | self.bbox3D = None 105 | self.inst_id = None # instance 106 | self.class_id = None # semantic 107 | self.pc_sample = None 108 | self.merge_cnt = 0 # merge times counting 109 | self.cmp_cnt = 0 110 | 111 | 112 | def box_filter(masks, classes, depth, inst_dict, intrinsic_open3d, T_CW, min_pixels=500, voxel_size=0.01): 113 | bbox3d_scale = 1.0 # 1.05 114 | inst_data = np.zeros_like(depth, dtype=np.int) 115 | for i in range(len(masks)): 116 | diff_mask = None 117 | inst_mask = masks[i] 118 | inst_id = classes[i] 119 | if inst_id == 0: 120 | continue 121 | inst_depth = np.copy(depth) 122 | inst_depth[~inst_mask] = 0. # inst_mask 123 | # proj_time = time.time() 124 | inst_pc = unproject_pointcloud(inst_depth, intrinsic_open3d, T_CW) 125 | # print("proj time ", time.time()-proj_time) 126 | if len(inst_pc.points) <= 10: # too small 127 | inst_data[inst_mask] = 0 # set to background 128 | continue 129 | if inst_id in inst_dict.keys(): 130 | candidate_inst = inst_dict[inst_id] 131 | # iou_time = time.time() 132 | IoU, indices = check_inside_ratio(inst_pc, candidate_inst.bbox3D) 133 | # print("iou time ", time.time()-iou_time) 134 | # if indices empty 135 | candidate_inst.cmp_cnt += 1 136 | if len(indices) >= 1: 137 | candidate_inst.pc += inst_pc.select_by_index(indices) # only merge pcs inside scale*bbox 138 | # todo check indices follow valid depth 139 | valid_depth_mask = np.zeros_like(inst_depth, dtype=np.bool) 140 | valid_pc_mask = valid_depth_mask[inst_depth!=0] 141 | valid_pc_mask[indices] = True 142 | valid_depth_mask[inst_depth != 0] = valid_pc_mask 143 | valid_mask = valid_depth_mask 144 | diff_mask = np.zeros_like(inst_mask) 145 | # uv_opencv, _ = cv2.projectPoints(np.array(inst_pc.select_by_index(indices).points), T_CW[:3, :3], 146 | # T_CW[:3, 3], intrinsic_open3d.intrinsic_matrix[:3, :3], None) 147 | # uv = np.round(uv_opencv).squeeze().astype(int) 148 | # u = uv[:, 0].reshape(-1, 1) 149 | # v = uv[:, 1].reshape(-1, 1) 150 | # vu = np.concatenate([v, u], axis=-1) 151 | # valid_mask = np.zeros_like(inst_mask) 152 | # valid_mask[tuple(vu.T)] = True 153 | # # cv2.imshow("valid", (inst_depth!=0).astype(np.uint8)*255) 154 | # # cv2.waitKey(1) 155 | diff_mask[(inst_depth != 0) & (~valid_mask)] = True 156 | # cv2.imshow("diff_mask", diff_mask.astype(np.uint8) * 255) 157 | # cv2.waitKey(1) 158 | else: # merge all for scannet 159 | # print("too few pcs obj ", inst_id) 160 | inst_data[inst_mask] = -1 161 | continue 162 | # downsample_time = time.time() 163 | # adapt_voxel_size = 
np.maximum(np.max(candidate_inst.bbox3D.extent)/100, 0.1) 164 | candidate_inst.pc = candidate_inst.pc.voxel_down_sample(voxel_size) # adapt_voxel_size 165 | # candidate_inst.pc = candidate_inst.pc.farthest_point_down_sample(500) 166 | # candidate_inst.pc = candidate_inst.pc.random_down_sample(np.minimum(len(candidate_inst.pc.points)/500.,1)) 167 | # print("downsample time ", time.time() - downsample_time) # 0.03s even 168 | # bbox_time = time.time() 169 | try: 170 | candidate_inst.bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(candidate_inst.pc.points) 171 | except RuntimeError: 172 | # print("too few pcs obj ", inst_id) 173 | inst_data[inst_mask] = -1 174 | continue 175 | # enlarge 176 | candidate_inst.bbox3D.scale(bbox3d_scale, candidate_inst.bbox3D.get_center()) 177 | else: # new inst 178 | # init new inst and new sem 179 | new_inst = InstData() 180 | new_inst.inst_id = inst_id 181 | smaller_mask = cv2.erode(inst_mask.astype(np.uint8), np.ones((5, 5)), iterations=3).astype(bool) 182 | if np.sum(smaller_mask) < min_pixels: 183 | # print("too few pcs obj ", inst_id) 184 | inst_data[inst_mask] = 0 185 | continue 186 | inst_depth_small = depth.copy() 187 | inst_depth_small[~smaller_mask] = 0 188 | inst_pc_small = unproject_pointcloud(inst_depth_small, intrinsic_open3d, T_CW) 189 | new_inst.pc = inst_pc_small 190 | new_inst.pc = new_inst.pc.voxel_down_sample(voxel_size) 191 | try: 192 | inst_bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(new_inst.pc.points) 193 | except RuntimeError: 194 | # print("too few pcs obj ", inst_id) 195 | inst_data[inst_mask] = 0 196 | continue 197 | # scale up 198 | inst_bbox3D.scale(bbox3d_scale, inst_bbox3D.get_center()) 199 | new_inst.bbox3D = inst_bbox3D 200 | # update inst_dict 201 | inst_dict.update({inst_id: new_inst}) # init new sem 202 | 203 | # update inst_data 204 | inst_data[inst_mask] = inst_id 205 | if diff_mask is not None: 206 | inst_data[diff_mask] = -1 # unsure area 207 | 208 | return inst_data 209 | 210 | def load_matrix_from_txt(path, shape=(4, 4)): 211 | with open(path) as f: 212 | txt = f.readlines() 213 | txt = ''.join(txt).replace('\n', ' ') 214 | matrix = [float(v) for v in txt.split()] 215 | return np.array(matrix).reshape(shape) 216 | 217 | def check_mask_order(obj_masks, depth_np, obj_ids): 218 | print(len(obj_masks)) 219 | print(len(obj_ids)) 220 | 221 | assert len(obj_masks) == len(obj_ids) 222 | depth = torch.from_numpy(depth_np) 223 | obj_masked_modified = copy.deepcopy(obj_masks[:]) 224 | for i in range(len(obj_masks) - 1): 225 | 226 | mask1 = obj_masks[i].float() 227 | mask1_ = obj_masked_modified[i].float() 228 | for j in range(i + 1, len(obj_masks)): 229 | mask2 = obj_masks[j].float() 230 | mask2_ = obj_masked_modified[j].float() 231 | # case 1: if they don't intersect we don't touch them 232 | if ((mask1 + mask2) == 2).sum() == 0: 233 | continue 234 | # case 2: the entire object 1 is inside of object 2, we say object 1 is in front of object 2: 235 | elif (((mask1 + mask2) == 2).float() - mask1).sum() == 0: 236 | mask2_ -= mask1_ 237 | # case 3: the entire object 2 is inside of object 1, we say object 2 is in front of object 1: 238 | elif (((mask1 + mask2) == 2).float() - mask2).sum() == 0: 239 | mask1_ -= mask2_ 240 | # case 4: use depth to check object order: 241 | else: 242 | # object 1 is closer 243 | if (depth * mask1).sum() / mask1.sum() > (depth * mask2).sum() / mask2.sum(): 244 | mask2_ -= ((mask1 + mask2) == 2).float() 245 | # object 2 is closer 246 | if (depth * mask1).sum() / 
mask1.sum() < (depth * mask2).sum() / mask2.sum(): 247 | mask1_ -= ((mask1 + mask2) == 2).float() 248 | 249 | final_mask = torch.zeros_like(depth, dtype=torch.int) 250 | # instance_labels = {} 251 | for i in range(len(obj_masked_modified)): 252 | final_mask = final_mask.masked_fill(obj_masked_modified[i] > 0, obj_ids[i]) 253 | # instance_labels[i] = obj_ids[i].item() 254 | return final_mask.cpu().numpy() 255 | 256 | 257 | def unproject_pointcloud(depth, intrinsic_open3d, T_CW): 258 | # depth, mask, intrinsic, extrinsic -> point clouds 259 | pc_sample = open3d.geometry.PointCloud.create_from_depth_image(depth=open3d.geometry.Image(depth), 260 | intrinsic=intrinsic_open3d, 261 | extrinsic=T_CW, 262 | depth_scale=1.0, 263 | project_valid_depth_only=True) 264 | return pc_sample 265 | 266 | def check_inside_ratio(pc, bbox3D): 267 | # pc, bbox3d -> inside ratio 268 | indices = bbox3D.get_point_indices_within_bounding_box(pc.points) 269 | assert len(pc.points) > 0 270 | ratio = len(indices) / len(pc.points) 271 | # print("ratio ", ratio) 272 | return ratio, indices 273 | 274 | def track_instance(masks, classes, depth, inst_list, sem_dict, intrinsic_open3d, T_CW, IoU_thresh=0.5, voxel_size=0.1, 275 | min_pixels=2000, erode=True, clip_features=None, class_names=None): 276 | device = masks.device 277 | inst_data_dict = {} 278 | inst_data_dict.update({0: torch.zeros(depth.shape, dtype=torch.int, device=device)}) 279 | inst_ids = [] 280 | bbox3d_scale = 1.0 # todo 1.0 281 | min_extent = 0.05 282 | depth = torch.from_numpy(depth).to(device) 283 | for i in range(len(masks)): 284 | inst_data = torch.zeros(depth.shape, dtype=torch.int, device=device) 285 | smaller_mask = cv2.erode(masks[i].detach().cpu().numpy().astype(np.uint8), np.ones((5, 5)), iterations=3).astype(bool) 286 | inst_depth_small = depth.detach().cpu().numpy() 287 | inst_depth_small[~smaller_mask] = 0 288 | inst_pc_small = unproject_pointcloud(inst_depth_small, intrinsic_open3d, T_CW) 289 | diff_mask = None 290 | if np.sum(smaller_mask) <= min_pixels: # too small 20 400 # todo use sem to set background 291 | inst_data[masks[i]] = 0 # set to background 292 | continue 293 | inst_pc_voxel = inst_pc_small.voxel_down_sample(voxel_size) 294 | if len(inst_pc_voxel.points) <= 10: # too small 20 400 # todo use sem to set background 295 | inst_data[masks[i]] = 0 # set to background 296 | continue 297 | is_merged = False 298 | inst_id = None 299 | inst_mask = masks[i] #smaller_mask #masks[i] # todo only 300 | inst_class = classes[i] 301 | inst_depth = depth.detach().cpu().numpy() 302 | inst_depth[~masks[i].detach().cpu().numpy()] = 0. # inst_mask 303 | inst_pc = unproject_pointcloud(inst_depth, intrinsic_open3d, T_CW) 304 | sem_inst_list = [] 305 | if clip_features is not None: # check similar sems based on clip feature distance 306 | sem_thr = 200 #300. for table #320. # 260. 
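# Candidate instances for merging are gathered by semantic similarity: an existing
# semantic class is accepted as a candidate when the L1 distance between the CLIP
# text embeddings of the two class names falls below sem_thr (an empirically chosen
# threshold; see the values tried above). Without CLIP features, only instances
# with exactly the same class id are compared (else branch below).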
307 | for sem_exist in sem_dict.keys(): 308 | if torch.abs(clip_features[class_names[inst_class]] - clip_features[class_names[sem_exist]]).sum() < sem_thr: 309 | sem_inst_list.extend(sem_dict[sem_exist]) 310 | else: # no clip features, only do strictly sem check 311 | if inst_class in sem_dict.keys(): 312 | sem_inst_list.extend(sem_dict[inst_class]) 313 | 314 | for candidate_inst in sem_inst_list: 315 | # if True: # only consider 3D bbox, merge them if they are spatial together 316 | IoU, indices = check_inside_ratio(inst_pc, candidate_inst.bbox3D) 317 | candidate_inst.cmp_cnt += 1 318 | if IoU > IoU_thresh: 319 | # merge inst to candidate 320 | is_merged = True 321 | candidate_inst.merge_cnt += 1 322 | candidate_inst.pc += inst_pc.select_by_index(indices) 323 | # inst_uv = inst_pc.select_by_index(indices).project_to_depth_image(masks[i].shape[1], masks[i].shape[0], intrinsic_open3d, T_CW, depth_scale=1.0, depth_max=10.0) 324 | # # inst_uv = torch.utils.dlpack.from_dlpack(uv_opencv.as_tensor().to_dlpack()) 325 | # valid_mask = inst_uv.squeeze() > 0. # shape --> H, W 326 | # diff_mask = (inst_depth > 0.) & (~valid_mask) 327 | diff_mask = torch.zeros_like(inst_mask) 328 | uv_opencv, _ = cv2.projectPoints(np.array(inst_pc.select_by_index(indices).points), T_CW[:3, :3], 329 | T_CW[:3, 3], intrinsic_open3d.intrinsic_matrix[:3,:3], None) 330 | uv = np.round(uv_opencv).squeeze().astype(int) 331 | u = uv[:, 0].reshape(-1, 1) 332 | v = uv[:, 1].reshape(-1, 1) 333 | vu = np.concatenate([v, u], axis=-1) 334 | valid_mask = np.zeros(inst_mask.shape, dtype=np.bool) 335 | valid_mask[tuple(vu.T)] = True 336 | diff_mask[(inst_depth!=0) & (~valid_mask)] = True 337 | # downsample pcs 338 | candidate_inst.pc = candidate_inst.pc.voxel_down_sample(voxel_size) 339 | # candidate_inst.pc.random_down_sample(np.minimum(500//len(candidate_inst.pc.points),1)) 340 | candidate_inst.bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(candidate_inst.pc.points) 341 | # enlarge 342 | candidate_inst.bbox3D.scale(bbox3d_scale, candidate_inst.bbox3D.get_center()) 343 | candidate_inst.bbox3D.extent = np.maximum(candidate_inst.bbox3D.extent, min_extent) # at least bigger than min_extent 344 | inst_id = candidate_inst.inst_id 345 | break 346 | # if candidate_inst.cmp_cnt >= 20 and candidate_inst.merge_cnt <= 5: 347 | # sem_inst_list.remove(candidate_inst) 348 | 349 | if not is_merged: 350 | # init new inst and new sem 351 | new_inst = InstData() 352 | new_inst.inst_id = len(inst_list) + 1 353 | new_inst.class_id = inst_class 354 | 355 | new_inst.pc = inst_pc_small 356 | new_inst.pc = new_inst.pc.voxel_down_sample(voxel_size) 357 | inst_bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(new_inst.pc.points) 358 | # scale up 359 | inst_bbox3D.scale(bbox3d_scale, inst_bbox3D.get_center()) 360 | inst_bbox3D.extent = np.maximum(inst_bbox3D.extent, min_extent) 361 | new_inst.bbox3D = inst_bbox3D 362 | inst_list.append(new_inst) 363 | inst_id = new_inst.inst_id 364 | # update sem_dict 365 | if inst_class in sem_dict.keys(): 366 | sem_dict[inst_class].append(new_inst) # append new inst to exist sem 367 | else: 368 | sem_dict.update({inst_class: [new_inst]}) # init new sem 369 | # update inst_data 370 | inst_data[inst_mask] = inst_id 371 | if diff_mask is not None: 372 | inst_data[diff_mask] = -1 # unsure area 373 | if inst_id not in inst_ids: 374 | inst_data_dict.update({inst_id: inst_data}) 375 | else: 376 | continue 377 | # idx = inst_ids.index(inst_id) 378 | # inst_data_list[idx] = inst_data_list[idx] & 
torch.from_numpy(inst_data) # merge them? todo 379 | # return inst_data 380 | mask_bg = torch.stack(list(inst_data_dict.values())).sum(0) != 0 381 | inst_data_dict.update({0: mask_bg.int()}) 382 | return inst_data_dict 383 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import loss 3 | from vmap import * 4 | import utils 5 | import open3d 6 | import dataset 7 | import vis 8 | from functorch import vmap 9 | import argparse 10 | from cfg import Config 11 | import shutil 12 | 13 | if __name__ == "__main__": 14 | ############################################# 15 | # init config 16 | torch.backends.cudnn.enabled = True 17 | torch.backends.cudnn.benchmark = True 18 | 19 | # setting params 20 | parser = argparse.ArgumentParser(description='Model training for single GPU') 21 | parser.add_argument('--logdir', default='./logs/debug', 22 | type=str) 23 | parser.add_argument('--config', 24 | default='./configs/Replica/config_replica_room0_vMAP.json', 25 | type=str) 26 | parser.add_argument('--save_ckpt', 27 | default=False, 28 | type=bool) 29 | args = parser.parse_args() 30 | 31 | log_dir = args.logdir 32 | config_file = args.config 33 | save_ckpt = args.save_ckpt 34 | os.makedirs(log_dir, exist_ok=True) # saving logs 35 | shutil.copy(config_file, log_dir) 36 | cfg = Config(config_file) # config params 37 | n_sample_per_step = cfg.n_per_optim 38 | n_sample_per_step_bg = cfg.n_per_optim_bg 39 | 40 | # param for vis 41 | vis3d = open3d.visualization.Visualizer() 42 | vis3d.create_window(window_name="3D mesh vis", 43 | width=cfg.W, 44 | height=cfg.H, 45 | left=600, top=50) 46 | view_ctl = vis3d.get_view_control() 47 | view_ctl.set_constant_z_far(10.) 
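The intrinsics read from the config just below are shared by two consumers: the `cameraInfo` ray-direction cache defined in `vmap.py` and the Open3D `PinholeCameraIntrinsic` used for depth unprojection and object bounds. Both assume the standard pinhole model; the sketch below is not part of the repository and uses illustrative values for `fx, fy, cx, cy` and the pixel/depth, but shows the back-projection that model implies.

```python
# Minimal pinhole back-projection sketch; all numbers are illustrative.
import numpy as np

def backproject(u, v, z, fx, fy, cx, cy):
    """Return the 3D camera-frame point seen at pixel (u, v) with depth z (metres)."""
    x = (u - cx) / fx * z
    y = (v - cy) / fy * z
    return np.array([x, y, z])

# a pixel at the principal point always lands on the optical axis
print(backproject(u=600.0, v=340.0, z=2.0, fx=600.0, fy=600.0, cx=600.0, cy=340.0))
# -> [0. 0. 2.]
```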
48 | 49 | # set camera 50 | cam_info = cameraInfo(cfg) 51 | intrinsic_open3d = open3d.camera.PinholeCameraIntrinsic( 52 | width=cfg.W, 53 | height=cfg.H, 54 | fx=cfg.fx, 55 | fy=cfg.fy, 56 | cx=cfg.cx, 57 | cy=cfg.cy) 58 | 59 | # init obj_dict 60 | obj_dict = {} # only objs 61 | vis_dict = {} # including bg 62 | 63 | # init for training 64 | AMP = False 65 | if AMP: 66 | scaler = torch.cuda.amp.GradScaler() # amp https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/ 67 | optimiser = torch.optim.AdamW([torch.autograd.Variable(torch.tensor(0))], lr=cfg.learning_rate, weight_decay=cfg.weight_decay) 68 | 69 | ############################################# 70 | # init data stream 71 | if not cfg.live_mode: 72 | # load dataset 73 | dataloader = dataset.init_loader(cfg) 74 | dataloader_iterator = iter(dataloader) 75 | dataset_len = len(dataloader) 76 | else: 77 | dataset_len = 1000000 78 | # # init ros node 79 | # torch.multiprocessing.set_start_method('spawn') # spawn 80 | # import ros_nodes 81 | # track_to_map_Buffer = torch.multiprocessing.Queue(maxsize=5) 82 | # # track_to_vis_T_WC = torch.multiprocessing.Queue(maxsize=1) 83 | # kfs_que = torch.multiprocessing.Queue(maxsize=5) # to store one more buffer 84 | # track_p = torch.multiprocessing.Process(target=ros_nodes.Tracking, 85 | # args=( 86 | # (cfg), (track_to_map_Buffer), (None), 87 | # (kfs_que), (True),)) 88 | # track_p.start() 89 | 90 | 91 | # init vmap 92 | fc_models, pe_models = [], [] 93 | scene_bg = None 94 | 95 | for frame_id in tqdm(range(dataset_len)): 96 | print("*********************************************") 97 | # get new frame data 98 | with performance_measure(f"getting next data"): 99 | if not cfg.live_mode: 100 | # get data from dataloader 101 | sample = next(dataloader_iterator) 102 | else: 103 | pass 104 | 105 | if sample is not None: # new frame 106 | last_frame_time = time.time() 107 | with performance_measure(f"Appending data"): 108 | rgb = sample["image"].to(cfg.data_device) 109 | depth = sample["depth"].to(cfg.data_device) 110 | twc = sample["T"].to(cfg.data_device) 111 | bbox_dict = sample["bbox_dict"] 112 | if "frame_id" in sample.keys(): 113 | live_frame_id = sample["frame_id"] 114 | else: 115 | live_frame_id = frame_id 116 | if not cfg.live_mode: 117 | inst = sample["obj"].to(cfg.data_device) 118 | obj_ids = torch.unique(inst) 119 | else: 120 | inst_data_dict = sample["obj"] 121 | obj_ids = inst_data_dict.keys() 122 | # append new frame info to objs in current view 123 | for obj_id in obj_ids: 124 | if obj_id == -1: # unsured area 125 | continue 126 | obj_id = int(obj_id) 127 | # convert inst mask to state 128 | if not cfg.live_mode: 129 | state = torch.zeros_like(inst, dtype=torch.uint8, device=cfg.data_device) 130 | state[inst == obj_id] = 1 131 | state[inst == -1] = 2 132 | else: 133 | inst_mask = inst_data_dict[obj_id].permute(1,0) 134 | label_list = torch.unique(inst_mask).tolist() 135 | state = torch.zeros_like(inst_mask, dtype=torch.uint8, device=cfg.data_device) 136 | state[inst_mask == obj_id] = 1 137 | state[inst_mask == -1] = 2 138 | bbox = bbox_dict[obj_id] 139 | if obj_id in vis_dict.keys(): 140 | scene_obj = vis_dict[obj_id] 141 | scene_obj.append_keyframe(rgb, depth, state, bbox, twc, live_frame_id) 142 | else: # init scene_obj 143 | if len(obj_dict.keys()) >= cfg.max_n_models: 144 | print("models full!!!! 
current num ", len(obj_dict.keys())) 145 | continue 146 | print("init new obj ", obj_id) 147 | if cfg.do_bg and obj_id == 0: # todo param 148 | scene_bg = sceneObject(cfg, obj_id, rgb, depth, state, bbox, twc, live_frame_id) 149 | # scene_bg.init_obj_center(intrinsic_open3d, depth, state, twc) 150 | optimiser.add_param_group({"params": scene_bg.trainer.fc_occ_map.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay}) 151 | optimiser.add_param_group({"params": scene_bg.trainer.pe.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay}) 152 | vis_dict.update({obj_id: scene_bg}) 153 | else: 154 | scene_obj = sceneObject(cfg, obj_id, rgb, depth, state, bbox, twc, live_frame_id) 155 | # scene_obj.init_obj_center(intrinsic_open3d, depth, state, twc) 156 | obj_dict.update({obj_id: scene_obj}) 157 | vis_dict.update({obj_id: scene_obj}) 158 | # params = [scene_obj.trainer.fc_occ_map.parameters(), scene_obj.trainer.pe.parameters()] 159 | optimiser.add_param_group({"params": scene_obj.trainer.fc_occ_map.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay}) 160 | optimiser.add_param_group({"params": scene_obj.trainer.pe.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay}) 161 | if cfg.training_strategy == "vmap": 162 | update_vmap_model = True 163 | fc_models.append(obj_dict[obj_id].trainer.fc_occ_map) 164 | pe_models.append(obj_dict[obj_id].trainer.pe) 165 | 166 | # ################################### 167 | # # measure trainable params in total 168 | # total_params = 0 169 | # obj_k = obj_dict[obj_id] 170 | # for p in obj_k.trainer.fc_occ_map.parameters(): 171 | # if p.requires_grad: 172 | # total_params += p.numel() 173 | # for p in obj_k.trainer.pe.parameters(): 174 | # if p.requires_grad: 175 | # total_params += p.numel() 176 | # print("total param ", total_params) 177 | 178 | # dynamically add vmap 179 | with performance_measure(f"add vmap"): 180 | if cfg.training_strategy == "vmap" and update_vmap_model == True: 181 | fc_model, fc_param, fc_buffer = utils.update_vmap(fc_models, optimiser) 182 | pe_model, pe_param, pe_buffer = utils.update_vmap(pe_models, optimiser) 183 | update_vmap_model = False 184 | 185 | 186 | ################################################################## 187 | # training data preperation, get training data for all objs 188 | Batch_N_gt_depth = [] 189 | Batch_N_gt_rgb = [] 190 | Batch_N_depth_mask = [] 191 | Batch_N_obj_mask = [] 192 | Batch_N_input_pcs = [] 193 | Batch_N_sampled_z = [] 194 | 195 | with performance_measure(f"Sampling over {len(obj_dict.keys())} objects,"): 196 | if cfg.do_bg and scene_bg is not None: 197 | gt_rgb, gt_depth, valid_depth_mask, obj_mask, input_pcs, sampled_z \ 198 | = scene_bg.get_training_samples(cfg.n_iter_per_frame * cfg.win_size_bg, cfg.n_samples_per_frame_bg, 199 | cam_info.rays_dir_cache) 200 | bg_gt_depth = gt_depth.reshape([gt_depth.shape[0] * gt_depth.shape[1]]) 201 | bg_gt_rgb = gt_rgb.reshape([gt_rgb.shape[0] * gt_rgb.shape[1], gt_rgb.shape[2]]) 202 | bg_valid_depth_mask = valid_depth_mask 203 | bg_obj_mask = obj_mask 204 | bg_input_pcs = input_pcs.reshape( 205 | [input_pcs.shape[0] * input_pcs.shape[1], input_pcs.shape[2], input_pcs.shape[3]]) 206 | bg_sampled_z = sampled_z.reshape([sampled_z.shape[0] * sampled_z.shape[1], sampled_z.shape[2]]) 207 | 208 | for obj_id, obj_k in obj_dict.items(): 209 | gt_rgb, gt_depth, valid_depth_mask, obj_mask, input_pcs, sampled_z \ 210 | = obj_k.get_training_samples(cfg.n_iter_per_frame * cfg.win_size, 
cfg.n_samples_per_frame, 211 | cam_info.rays_dir_cache) 212 | # merge first two dims, sample_per_frame*num_per_frame 213 | Batch_N_gt_depth.append(gt_depth.reshape([gt_depth.shape[0] * gt_depth.shape[1]])) 214 | Batch_N_gt_rgb.append(gt_rgb.reshape([gt_rgb.shape[0] * gt_rgb.shape[1], gt_rgb.shape[2]])) 215 | Batch_N_depth_mask.append(valid_depth_mask) 216 | Batch_N_obj_mask.append(obj_mask) 217 | Batch_N_input_pcs.append(input_pcs.reshape([input_pcs.shape[0] * input_pcs.shape[1], input_pcs.shape[2], input_pcs.shape[3]])) 218 | Batch_N_sampled_z.append(sampled_z.reshape([sampled_z.shape[0] * sampled_z.shape[1], sampled_z.shape[2]])) 219 | 220 | # # vis sampled points in open3D 221 | # # sampled pcs 222 | # pc = open3d.geometry.PointCloud() 223 | # pc.points = open3d.utility.Vector3dVector(input_pcs.cpu().numpy().reshape(-1,3)) 224 | # open3d.visualization.draw_geometries([pc]) 225 | # rgb_np = rgb.cpu().numpy().astype(np.uint8).transpose(1,0,2) 226 | # # print("rgb ", rgb_np.shape) 227 | # # print(rgb_np) 228 | # # cv2.imshow("rgb", rgb_np) 229 | # # cv2.waitKey(1) 230 | # depth_np = depth.cpu().numpy().astype(np.float32).transpose(1,0) 231 | # twc_np = twc.cpu().numpy() 232 | # rgbd = open3d.geometry.RGBDImage.create_from_color_and_depth( 233 | # open3d.geometry.Image(rgb_np), 234 | # open3d.geometry.Image(depth_np), 235 | # depth_trunc=max_depth, 236 | # depth_scale=1, 237 | # convert_rgb_to_intensity=False, 238 | # ) 239 | # T_CW = np.linalg.inv(twc_np) 240 | # # input image pc 241 | # input_pc = open3d.geometry.PointCloud.create_from_rgbd_image( 242 | # image=rgbd, 243 | # intrinsic=intrinsic_open3d, 244 | # extrinsic=T_CW) 245 | # input_pc.points = open3d.utility.Vector3dVector(np.array(input_pc.points) - obj_k.obj_center.cpu().numpy()) 246 | # open3d.visualization.draw_geometries([pc, input_pc]) 247 | 248 | 249 | #################################################### 250 | # training 251 | assert len(Batch_N_input_pcs) > 0 252 | # move data to GPU (n_obj, n_iter_per_frame, win_size*num_per_frame, 3) 253 | with performance_measure(f"stacking and moving to gpu: "): 254 | 255 | Batch_N_input_pcs = torch.stack(Batch_N_input_pcs).to(cfg.training_device) 256 | Batch_N_gt_depth = torch.stack(Batch_N_gt_depth).to(cfg.training_device) 257 | Batch_N_gt_rgb = torch.stack(Batch_N_gt_rgb).to(cfg.training_device) / 255. # todo 258 | Batch_N_depth_mask = torch.stack(Batch_N_depth_mask).to(cfg.training_device) 259 | Batch_N_obj_mask = torch.stack(Batch_N_obj_mask).to(cfg.training_device) 260 | Batch_N_sampled_z = torch.stack(Batch_N_sampled_z).to(cfg.training_device) 261 | if cfg.do_bg: 262 | bg_input_pcs = bg_input_pcs.to(cfg.training_device) 263 | bg_gt_depth = bg_gt_depth.to(cfg.training_device) 264 | bg_gt_rgb = bg_gt_rgb.to(cfg.training_device) / 255. 265 | bg_valid_depth_mask = bg_valid_depth_mask.to(cfg.training_device) 266 | bg_obj_mask = bg_obj_mask.to(cfg.training_device) 267 | bg_sampled_z = bg_sampled_z.to(cfg.training_device) 268 | 269 | with performance_measure(f"Training over {len(obj_dict.keys())} objects,"): 270 | for iter_step in range(cfg.n_iter_per_frame): 271 | data_idx = slice(iter_step*n_sample_per_step, (iter_step+1)*n_sample_per_step) 272 | batch_input_pcs = Batch_N_input_pcs[:, data_idx, ...] 273 | batch_gt_depth = Batch_N_gt_depth[:, data_idx, ...] 274 | batch_gt_rgb = Batch_N_gt_rgb[:, data_idx, ...] 275 | batch_depth_mask = Batch_N_depth_mask[:, data_idx, ...] 276 | batch_obj_mask = Batch_N_obj_mask[:, data_idx, ...] 
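# the same per-step window (n_sample_per_step rays per object) indexes every stacked tensor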
277 | batch_sampled_z = Batch_N_sampled_z[:, data_idx, ...] 278 | if cfg.training_strategy == "forloop": 279 | # for loop training 280 | batch_alpha = [] 281 | batch_color = [] 282 | for k, obj_id in enumerate(obj_dict.keys()): 283 | obj_k = obj_dict[obj_id] 284 | embedding_k = obj_k.trainer.pe(batch_input_pcs[k]) 285 | alpha_k, color_k = obj_k.trainer.fc_occ_map(embedding_k) 286 | batch_alpha.append(alpha_k) 287 | batch_color.append(color_k) 288 | 289 | batch_alpha = torch.stack(batch_alpha) 290 | batch_color = torch.stack(batch_color) 291 | elif cfg.training_strategy == "vmap": 292 | # batched training 293 | batch_embedding = vmap(pe_model)(pe_param, pe_buffer, batch_input_pcs) 294 | batch_alpha, batch_color = vmap(fc_model)(fc_param, fc_buffer, batch_embedding) 295 | # print("batch alpha ", batch_alpha.shape) 296 | else: 297 | print("training strategy {} is not implemented ".format(cfg.training_strategy)) 298 | exit(-1) 299 | 300 | 301 | # step loss 302 | # with performance_measure(f"Batch LOSS"): 303 | batch_loss, _ = loss.step_batch_loss(batch_alpha, batch_color, 304 | batch_gt_depth.detach(), batch_gt_rgb.detach(), 305 | batch_obj_mask.detach(), batch_depth_mask.detach(), 306 | batch_sampled_z.detach()) 307 | 308 | if cfg.do_bg: 309 | bg_data_idx = slice(iter_step * n_sample_per_step_bg, (iter_step + 1) * n_sample_per_step_bg) 310 | bg_embedding = scene_bg.trainer.pe(bg_input_pcs[bg_data_idx, ...]) 311 | bg_alpha, bg_color = scene_bg.trainer.fc_occ_map(bg_embedding) 312 | bg_loss, _ = loss.step_batch_loss(bg_alpha[None, ...], bg_color[None, ...], 313 | bg_gt_depth[None, bg_data_idx, ...].detach(), bg_gt_rgb[None, bg_data_idx].detach(), 314 | bg_obj_mask[None, bg_data_idx, ...].detach(), bg_valid_depth_mask[None, bg_data_idx, ...].detach(), 315 | bg_sampled_z[None, bg_data_idx, ...].detach()) 316 | batch_loss += bg_loss 317 | 318 | # with performance_measure(f"Backward"): 319 | if AMP: 320 | scaler.scale(batch_loss).backward() 321 | scaler.step(optimiser) 322 | scaler.update() 323 | else: 324 | batch_loss.backward() 325 | optimiser.step() 326 | optimiser.zero_grad(set_to_none=True) 327 | # print("loss ", batch_loss.item()) 328 | 329 | # update each origin model params 330 | # todo find a better way # https://github.com/pytorch/functorch/issues/280 331 | with performance_measure(f"updating vmap param"): 332 | if cfg.training_strategy == "vmap": 333 | with torch.no_grad(): 334 | for model_id, (obj_id, obj_k) in enumerate(obj_dict.items()): 335 | for i, param in enumerate(obj_k.trainer.fc_occ_map.parameters()): 336 | param.copy_(fc_param[i][model_id]) 337 | for i, param in enumerate(obj_k.trainer.pe.parameters()): 338 | param.copy_(pe_param[i][model_id]) 339 | 340 | 341 | #################################################################### 342 | # live vis mesh 343 | if (((frame_id % cfg.n_vis_iter) == 0 or frame_id == dataset_len-1) or 344 | (cfg.live_mode and time.time()-last_frame_time>cfg.keep_live_time)) and frame_id >= 10: 345 | vis3d.clear_geometries() 346 | for obj_id, obj_k in vis_dict.items(): 347 | bound = obj_k.get_bound(intrinsic_open3d) 348 | if bound is None: 349 | print("get bound failed obj ", obj_id) 350 | continue 351 | adaptive_grid_dim = int(np.minimum(np.max(bound.extent)//cfg.live_voxel_size+1, cfg.grid_dim)) 352 | mesh = obj_k.trainer.meshing(bound, obj_k.obj_center, grid_dim=adaptive_grid_dim) 353 | if mesh is None: 354 | print("meshing failed obj ", obj_id) 355 | continue 356 | 357 | # save to dir 358 | obj_mesh_output = os.path.join(log_dir, "scene_mesh") 359 | 
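# one .obj file per object is written under <log_dir>/scene_mesh at every vis step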
os.makedirs(obj_mesh_output, exist_ok=True) 360 | mesh.export(os.path.join(obj_mesh_output, "frame_{}_obj{}.obj".format(frame_id, str(obj_id)))) 361 | 362 | # live vis 363 | open3d_mesh = vis.trimesh_to_open3d(mesh) 364 | vis3d.add_geometry(open3d_mesh) 365 | vis3d.add_geometry(bound) 366 | # update vis3d 367 | vis3d.poll_events() 368 | vis3d.update_renderer() 369 | 370 | if False: # follow cam 371 | cam = view_ctl.convert_to_pinhole_camera_parameters() 372 | T_CW_np = np.linalg.inv(twc.cpu().numpy()) 373 | cam.extrinsic = T_CW_np 374 | view_ctl.convert_from_pinhole_camera_parameters(cam) 375 | vis3d.poll_events() 376 | vis3d.update_renderer() 377 | 378 | with performance_measure("saving ckpt"): 379 | if save_ckpt and ((((frame_id % cfg.n_vis_iter) == 0 or frame_id == dataset_len - 1) or 380 | (cfg.live_mode and time.time() - last_frame_time > cfg.keep_live_time)) and frame_id >= 10): 381 | for obj_id, obj_k in vis_dict.items(): 382 | ckpt_dir = os.path.join(log_dir, "ckpt", str(obj_id)) 383 | os.makedirs(ckpt_dir, exist_ok=True) 384 | bound = obj_k.get_bound(intrinsic_open3d) # update bound 385 | obj_k.save_checkpoints(ckpt_dir, frame_id) 386 | # save current cam pose 387 | cam_dir = os.path.join(log_dir, "cam_pose") 388 | os.makedirs(cam_dir, exist_ok=True) 389 | torch.save({"twc": twc,}, os.path.join(cam_dir, "twc_frame_{}".format(frame_id) + ".pth")) 390 | 391 | 392 | -------------------------------------------------------------------------------- /vmap.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from time import perf_counter_ns 5 | from tqdm import tqdm 6 | import trainer 7 | import open3d 8 | import trimesh 9 | import scipy 10 | from bidict import bidict 11 | import copy 12 | import os 13 | 14 | import utils 15 | 16 | 17 | class performance_measure: 18 | 19 | def __init__(self, name) -> None: 20 | self.name = name 21 | 22 | def __enter__(self): 23 | self.start_time = perf_counter_ns() 24 | 25 | def __exit__(self, type, value, tb): 26 | self.end_time = perf_counter_ns() 27 | self.exec_time = self.end_time - self.start_time 28 | 29 | print(f"{self.name} excution time: {(self.exec_time)/1000000:.2f} ms") 30 | 31 | def origin_dirs_W(T_WC, dirs_C): 32 | 33 | assert T_WC.shape[0] == dirs_C.shape[0] 34 | assert T_WC.shape[1:] == (4, 4) 35 | assert dirs_C.shape[2] == 3 36 | 37 | dirs_W = (T_WC[:, None, :3, :3] @ dirs_C[..., None]).squeeze() 38 | 39 | origins = T_WC[:, :3, -1] 40 | 41 | return origins, dirs_W 42 | 43 | 44 | # @torch.jit.script 45 | def stratified_bins(min_depth, max_depth, n_bins, n_rays, type=torch.float32, device = "cuda:0"): 46 | # type: (Tensor, Tensor, int, int) -> Tensor 47 | 48 | bin_limits_scale = torch.linspace(0, 1, n_bins+1, dtype=type, device=device) 49 | 50 | if not torch.is_tensor(min_depth): 51 | min_depth = torch.ones(n_rays, dtype=type, device=device) * min_depth 52 | 53 | if not torch.is_tensor(max_depth): 54 | max_depth = torch.ones(n_rays, dtype=type, device=device) * max_depth 55 | 56 | depth_range = max_depth - min_depth 57 | 58 | lower_limits_scale = depth_range[..., None] * bin_limits_scale + min_depth[..., None] 59 | lower_limits_scale = lower_limits_scale[:, :-1] 60 | 61 | assert lower_limits_scale.shape == (n_rays, n_bins) 62 | 63 | bin_length_scale = depth_range / n_bins 64 | increments_scale = torch.rand( 65 | n_rays, n_bins, device=device, 66 | dtype=torch.float32) * bin_length_scale[..., None] 67 | 68 | z_vals_scale = lower_limits_scale + 
increments_scale 69 | 70 | assert z_vals_scale.shape == (n_rays, n_bins) 71 | 72 | return z_vals_scale 73 | 74 | # @torch.jit.script 75 | def normal_bins_sampling(depth, n_bins, n_rays, delta, device = "cuda:0"): 76 | # type: (Tensor, int, int, float) -> Tensor 77 | 78 | # device = "cpu" 79 | # bins = torch.normal(0.0, delta / 3., size=[n_rays, n_bins], devi 80 | # self.keyframes_batch = torch.empty(self.n_keyframes,ce=device).sort().values 81 | bins = torch.empty(n_rays, n_bins, dtype=torch.float32, device=device).normal_(mean=0.,std=delta / 3.).sort().values 82 | bins = torch.clip(bins, -delta, delta) 83 | z_vals = depth[:, None] + bins 84 | 85 | assert z_vals.shape == (n_rays, n_bins) 86 | 87 | return z_vals 88 | 89 | 90 | class sceneObject: 91 | """ 92 | object instance mapping, 93 | updating keyframes, get training samples, optimizing MLP map 94 | """ 95 | 96 | def __init__(self, cfg, obj_id, rgb:torch.tensor, depth:torch.tensor, mask:torch.tensor, bbox_2d:torch.tensor, t_wc:torch.tensor, live_frame_id) -> None: 97 | self.do_bg = cfg.do_bg 98 | self.obj_id = obj_id 99 | self.data_device = cfg.data_device 100 | self.training_device = cfg.training_device 101 | 102 | assert rgb.shape[:2] == depth.shape 103 | assert rgb.shape[:2] == mask.shape 104 | assert bbox_2d.shape == (4,) 105 | assert t_wc.shape == (4, 4,) 106 | 107 | if self.do_bg and self.obj_id == 0: # do seperate bg 108 | self.obj_scale = cfg.bg_scale 109 | self.hidden_feature_size = cfg.hidden_feature_size_bg 110 | self.n_bins_cam2surface = cfg.n_bins_cam2surface_bg 111 | self.keyframe_step = cfg.keyframe_step_bg 112 | else: 113 | self.obj_scale = cfg.obj_scale 114 | self.hidden_feature_size = cfg.hidden_feature_size 115 | self.n_bins_cam2surface = cfg.n_bins_cam2surface 116 | self.keyframe_step = cfg.keyframe_step 117 | 118 | self.frames_width = rgb.shape[0] 119 | self.frames_height = rgb.shape[1] 120 | 121 | self.min_bound = cfg.min_depth 122 | self.max_bound = cfg.max_depth 123 | self.n_bins = cfg.n_bins 124 | self.n_unidir_funcs = cfg.n_unidir_funcs 125 | 126 | self.surface_eps = cfg.surface_eps 127 | self.stop_eps = cfg.stop_eps 128 | 129 | self.n_keyframes = 1 # Number of keyframes 130 | self.kf_pointer = None 131 | self.keyframe_buffer_size = cfg.keyframe_buffer_size 132 | self.kf_id_dict = bidict({live_frame_id:0}) 133 | self.kf_buffer_full = False 134 | self.frame_cnt = 0 # number of frames taken in 135 | self.lastest_kf_queue = [] 136 | 137 | self.bbox = torch.empty( # obj bounding bounding box in the frame 138 | self.keyframe_buffer_size, 139 | 4, 140 | device=self.data_device) # [u low, u high, v low, v high] 141 | self.bbox[0] = bbox_2d 142 | 143 | # RGB + pixel state batch 144 | self.rgb_idx = slice(0, 3) 145 | self.state_idx = slice(3, 4) 146 | self.rgbs_batch = torch.empty(self.keyframe_buffer_size, 147 | self.frames_width, 148 | self.frames_height, 149 | 4, 150 | dtype=torch.uint8, 151 | device=self.data_device) 152 | 153 | # Pixel states: 154 | self.other_obj = 0 # pixel doesn't belong to obj 155 | self.this_obj = 1 # pixel belong to obj 156 | self.unknown_obj = 2 # pixel state is unknown 157 | 158 | # Initialize first frame rgb and pixel state 159 | self.rgbs_batch[0, :, :, self.rgb_idx] = rgb 160 | self.rgbs_batch[0, :, :, self.state_idx] = mask[..., None] 161 | 162 | self.depth_batch = torch.empty(self.keyframe_buffer_size, 163 | self.frames_width, 164 | self.frames_height, 165 | dtype=torch.float32, 166 | device=self.data_device) 167 | 168 | # Initialize first frame's depth 169 | self.depth_batch[0] = 
depth 170 | self.t_wc_batch = torch.empty( 171 | self.keyframe_buffer_size, 4, 4, 172 | dtype=torch.float32, 173 | device=self.data_device) # world to camera transform 174 | 175 | # Initialize first frame's world2cam transform 176 | self.t_wc_batch[0] = t_wc 177 | 178 | # neural field map 179 | trainer_cfg = copy.deepcopy(cfg) 180 | trainer_cfg.obj_id = self.obj_id 181 | trainer_cfg.hidden_feature_size = self.hidden_feature_size 182 | trainer_cfg.obj_scale = self.obj_scale 183 | self.trainer = trainer.Trainer(trainer_cfg) 184 | 185 | # 3D boundary 186 | self.bbox3d = None 187 | self.pc = [] 188 | 189 | # init obj local frame 190 | # self.obj_center = self.init_obj_center(intrinsic, depth, mask, t_wc) 191 | self.obj_center = torch.tensor(0.0) # shouldn't make any difference because of frequency embedding 192 | 193 | 194 | def init_obj_center(self, intrinsic_open3d, depth, mask, t_wc): 195 | obj_depth = depth.cpu().clone() 196 | obj_depth[mask!=self.this_obj] = 0 197 | T_CW = np.linalg.inv(t_wc.cpu().numpy()) 198 | pc_obj_init = open3d.geometry.PointCloud.create_from_depth_image( 199 | depth=open3d.geometry.Image(np.asarray(obj_depth.permute(1,0).numpy(), order="C")), 200 | intrinsic=intrinsic_open3d, 201 | extrinsic=T_CW, 202 | depth_trunc=self.max_bound, 203 | depth_scale=1.0) 204 | obj_center = torch.from_numpy(np.mean(pc_obj_init.points, axis=0)).float() 205 | return obj_center 206 | 207 | # @profile 208 | def append_keyframe(self, rgb:torch.tensor, depth:torch.tensor, mask:torch.tensor, bbox_2d:torch.tensor, t_wc:torch.tensor, frame_id:np.uint8=1): 209 | assert rgb.shape[:2] == depth.shape 210 | assert rgb.shape[:2] == mask.shape 211 | assert bbox_2d.shape == (4,) 212 | assert t_wc.shape == (4, 4,) 213 | assert self.n_keyframes <= self.keyframe_buffer_size - 1 214 | assert rgb.dtype == torch.uint8 215 | assert mask.dtype == torch.uint8 216 | assert depth.dtype == torch.float32 217 | 218 | # every kf_step choose one kf 219 | is_kf = (self.frame_cnt % self.keyframe_step == 0) or self.n_keyframes == 1 220 | # print("---------------------") 221 | # print("self.kf_id_dict ", self.kf_id_dict) 222 | # print("live frame id ", frame_id) 223 | # print("n_frames ", self.n_keyframes) 224 | if self.n_keyframes == self.keyframe_buffer_size - 1: # kf buffer full, need to prune 225 | self.kf_buffer_full = True 226 | if self.kf_pointer is None: 227 | self.kf_pointer = self.n_keyframes 228 | 229 | self.rgbs_batch[self.kf_pointer, :, :, self.rgb_idx] = rgb 230 | self.rgbs_batch[self.kf_pointer, :, :, self.state_idx] = mask[..., None] 231 | self.depth_batch[self.kf_pointer, ...] = depth 232 | self.t_wc_batch[self.kf_pointer, ...] = t_wc 233 | self.bbox[self.kf_pointer, ...] = bbox_2d 234 | self.kf_id_dict.inv[self.kf_pointer] = frame_id 235 | 236 | if is_kf: 237 | self.lastest_kf_queue.append(self.kf_pointer) 238 | pruned_frame_id, pruned_kf_id = self.prune_keyframe() 239 | self.kf_pointer = pruned_kf_id 240 | print("pruned kf id ", self.kf_pointer) 241 | 242 | else: 243 | if not is_kf: # not kf, replace 244 | self.rgbs_batch[self.n_keyframes-1, :, :, self.rgb_idx] = rgb 245 | self.rgbs_batch[self.n_keyframes-1, :, :, self.state_idx] = mask[..., None] 246 | self.depth_batch[self.n_keyframes-1, ...] = depth 247 | self.t_wc_batch[self.n_keyframes-1, ...] = t_wc 248 | self.bbox[self.n_keyframes-1, ...] 
= bbox_2d 249 | self.kf_id_dict.inv[self.n_keyframes-1] = frame_id 250 | else: # is kf, add new kf 251 | self.kf_id_dict[frame_id] = self.n_keyframes 252 | self.rgbs_batch[self.n_keyframes, :, :, self.rgb_idx] = rgb 253 | self.rgbs_batch[self.n_keyframes, :, :, self.state_idx] = mask[..., None] 254 | self.depth_batch[self.n_keyframes, ...] = depth 255 | self.t_wc_batch[self.n_keyframes, ...] = t_wc 256 | self.bbox[self.n_keyframes, ...] = bbox_2d 257 | self.lastest_kf_queue.append(self.n_keyframes) 258 | self.n_keyframes += 1 259 | 260 | # print("self.kf_id_dic ", self.kf_id_dict) 261 | self.frame_cnt += 1 262 | if len(self.lastest_kf_queue) > 2: # keep latest two frames 263 | self.lastest_kf_queue = self.lastest_kf_queue[-2:] 264 | 265 | def prune_keyframe(self): 266 | # simple strategy to prune, randomly choose 267 | key, value = random.choice(list(self.kf_id_dict.items())[:-2]) # do not prune latest two frames 268 | return key, value 269 | 270 | def get_bound(self, intrinsic_open3d): 271 | # get 3D boundary from posed depth img todo update sparse pc when append frame 272 | pcs = open3d.geometry.PointCloud() 273 | for kf_id in range(self.n_keyframes): 274 | mask = self.rgbs_batch[kf_id, :, :, self.state_idx].squeeze() == self.this_obj 275 | depth = self.depth_batch[kf_id].cpu().clone() 276 | twc = self.t_wc_batch[kf_id].cpu().numpy() 277 | depth[~mask] = 0 278 | depth = depth.permute(1,0).numpy().astype(np.float32) 279 | T_CW = np.linalg.inv(twc) 280 | pc = open3d.geometry.PointCloud.create_from_depth_image(depth=open3d.geometry.Image(np.asarray(depth, order="C")), intrinsic=intrinsic_open3d, extrinsic=T_CW) 281 | # self.pc += pc 282 | pcs += pc 283 | 284 | # # get minimal oriented 3d bbox 285 | # try: 286 | # bbox3d = open3d.geometry.OrientedBoundingBox.create_from_points(pcs.points) 287 | # except RuntimeError: 288 | # print("too few pcs obj ") 289 | # return None 290 | # trimesh has a better minimal bbox implementation than open3d 291 | try: 292 | transform, extents = trimesh.bounds.oriented_bounds(np.array(pcs.points)) # pc 293 | transform = np.linalg.inv(transform) 294 | except scipy.spatial._qhull.QhullError: 295 | print("too few pcs obj ") 296 | return None 297 | 298 | for i in range(extents.shape[0]): 299 | extents[i] = np.maximum(extents[i], 0.10) # at least rendering 10cm 300 | bbox = utils.BoundingBox() 301 | bbox.center = transform[:3, 3] 302 | bbox.R = transform[:3, :3] 303 | bbox.extent = extents 304 | bbox3d = open3d.geometry.OrientedBoundingBox(bbox.center, bbox.R, bbox.extent) 305 | 306 | min_extent = 0.05 307 | bbox3d.extent = np.maximum(min_extent, bbox3d.extent) 308 | bbox3d.color = (255,0,0) 309 | self.bbox3d = utils.bbox_open3d2bbox(bbox_o3d=bbox3d) 310 | # self.pc = [] 311 | print("obj ", self.obj_id) 312 | print("bound ", bbox3d) 313 | print("kf id dict ", self.kf_id_dict) 314 | # open3d.visualization.draw_geometries([bbox3d, pcs]) 315 | return bbox3d 316 | 317 | 318 | 319 | def get_training_samples(self, n_frames, n_samples, cached_rays_dir): 320 | # Sample pixels 321 | if self.n_keyframes > 2: # make sure latest 2 frames are sampled todo if kf pruned, this is not the latest frame 322 | keyframe_ids = torch.randint(low=0, 323 | high=self.n_keyframes, 324 | size=(n_frames - 2,), 325 | dtype=torch.long, 326 | device=self.data_device) 327 | # if self.kf_buffer_full: 328 | # latest_frame_ids = list(self.kf_id_dict.values())[-2:] 329 | latest_frame_ids = self.lastest_kf_queue[-2:] 330 | keyframe_ids = torch.cat([keyframe_ids, 331 | torch.tensor(latest_frame_ids, 
device=keyframe_ids.device)]) 332 | # print("latest_frame_ids", latest_frame_ids) 333 | # else: # sample last 2 frames 334 | # keyframe_ids = torch.cat([keyframe_ids, 335 | # torch.tensor([self.n_keyframes-2, self.n_keyframes-1], device=keyframe_ids.device)]) 336 | else: 337 | keyframe_ids = torch.randint(low=0, 338 | high=self.n_keyframes, 339 | size=(n_frames,), 340 | dtype=torch.long, 341 | device=self.data_device) 342 | keyframe_ids = torch.unsqueeze(keyframe_ids, dim=-1) 343 | idx_w = torch.rand(n_frames, n_samples, device=self.data_device) 344 | idx_h = torch.rand(n_frames, n_samples, device=self.data_device) 345 | 346 | # resizing idx_w and idx_h to be in the bbox range 347 | idx_w = idx_w * (self.bbox[keyframe_ids, 1] - self.bbox[keyframe_ids, 0]) + self.bbox[keyframe_ids, 0] 348 | idx_h = idx_h * (self.bbox[keyframe_ids, 3] - self.bbox[keyframe_ids, 2]) + self.bbox[keyframe_ids, 2] 349 | 350 | idx_w = idx_w.long() 351 | idx_h = idx_h.long() 352 | 353 | sampled_rgbs = self.rgbs_batch[keyframe_ids, idx_w, idx_h] 354 | sampled_depth = self.depth_batch[keyframe_ids, idx_w, idx_h] 355 | 356 | # Get ray directions for sampled pixels 357 | sampled_ray_dirs = cached_rays_dir[idx_w, idx_h] 358 | 359 | # Get sampled keyframe poses 360 | sampled_twc = self.t_wc_batch[keyframe_ids[:, 0], :, :] 361 | 362 | origins, dirs_w = origin_dirs_W(sampled_twc, sampled_ray_dirs) 363 | 364 | return self.sample_3d_points(sampled_rgbs, sampled_depth, origins, dirs_w) 365 | 366 | def sample_3d_points(self, sampled_rgbs, sampled_depth, origins, dirs_w): 367 | """ 368 | 3D sampling strategy 369 | 370 | * For pixels with invalid depth: 371 | - N+M from minimum bound to max (stratified) 372 | 373 | * For pixels with valid depth: 374 | # Pixel belongs to this object 375 | - N from cam to surface (stratified) 376 | - M around surface (stratified/normal) 377 | # Pixel belongs that don't belong to this object 378 | - N from cam to surface (stratified) 379 | - M around surface (stratified) 380 | # Pixel with unknown state 381 | - Do nothing! 
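    Example (illustrative numbers, not taken from any config): with
    n_bins_cam2surface = 3 and n_bins = 6, a valid-depth pixel of this object
    gets 3 stratified samples in [min_bound, d - eps] plus 6 samples from a
    truncated normal centred on its measured depth d, i.e. 9 depth values per
    ray; a pixel of another object instead gets its last 6 samples stratified
    over [d - eps, d + stop_eps]; a pixel with invalid depth gets all 9 samples
    stratified over [min_bound, max depth in the batch].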
382 | """ 383 | 384 | n_bins_cam2surface = self.n_bins_cam2surface 385 | n_bins = self.n_bins 386 | eps = self.surface_eps 387 | other_objs_max_eps = self.stop_eps #0.05 # todo 0.02 388 | # print("max depth ", torch.max(sampled_depth)) 389 | sampled_z = torch.zeros( 390 | sampled_rgbs.shape[0] * sampled_rgbs.shape[1], 391 | n_bins_cam2surface + n_bins, 392 | dtype=self.depth_batch.dtype, 393 | device=self.data_device) # shape (N*n_rays, n_bins_cam2surface + n_bins) 394 | 395 | invalid_depth_mask = (sampled_depth <= self.min_bound).view(-1) 396 | # max_bound = self.max_bound 397 | max_bound = torch.max(sampled_depth) 398 | # sampling for points with invalid depth 399 | invalid_depth_count = invalid_depth_mask.count_nonzero() 400 | if invalid_depth_count: 401 | sampled_z[invalid_depth_mask, :] = stratified_bins( 402 | self.min_bound, max_bound, 403 | n_bins_cam2surface + n_bins, invalid_depth_count, 404 | device=self.data_device) 405 | 406 | # sampling for valid depth rays 407 | valid_depth_mask = ~invalid_depth_mask 408 | valid_depth_count = valid_depth_mask.count_nonzero() 409 | 410 | 411 | if valid_depth_count: 412 | # Sample between min bound and depth for all pixels with valid depth 413 | sampled_z[valid_depth_mask, :n_bins_cam2surface] = stratified_bins( 414 | self.min_bound, sampled_depth.view(-1)[valid_depth_mask]-eps, 415 | n_bins_cam2surface, valid_depth_count, device=self.data_device) 416 | 417 | # sampling around depth for this object 418 | obj_mask = (sampled_rgbs[..., -1] == self.this_obj).view(-1) & valid_depth_mask # todo obj_mask 419 | assert sampled_z.shape[0] == obj_mask.shape[0] 420 | obj_count = obj_mask.count_nonzero() 421 | 422 | if obj_count: 423 | sampling_method = "normal" # stratified or normal 424 | if sampling_method == "stratified": 425 | sampled_z[obj_mask, n_bins_cam2surface:] = stratified_bins( 426 | sampled_depth.view(-1)[obj_mask] - eps, sampled_depth.view(-1)[obj_mask] + eps, 427 | n_bins, obj_count, device=self.data_device) 428 | 429 | elif sampling_method == "normal": 430 | sampled_z[obj_mask, n_bins_cam2surface:] = normal_bins_sampling( 431 | sampled_depth.view(-1)[obj_mask], 432 | n_bins, 433 | obj_count, 434 | delta=eps, 435 | device=self.data_device) 436 | 437 | else: 438 | raise NotImplementedError( 439 | f"sampling method not implemented {sampling_method}, \ 440 | stratified and normal sampling only currently implemented." 
441 | ) 442 | 443 | # sampling around depth of other objects 444 | other_obj_mask = (sampled_rgbs[..., -1] != self.this_obj).view(-1) & valid_depth_mask 445 | other_objs_count = other_obj_mask.count_nonzero() 446 | if other_objs_count: 447 | sampled_z[other_obj_mask, n_bins_cam2surface:] = stratified_bins( 448 | sampled_depth.view(-1)[other_obj_mask] - eps, 449 | sampled_depth.view(-1)[other_obj_mask] + other_objs_max_eps, 450 | n_bins, other_objs_count, device=self.data_device) 451 | 452 | sampled_z = sampled_z.view(sampled_rgbs.shape[0], 453 | sampled_rgbs.shape[1], 454 | -1) # view as (n_rays, n_samples, 10) 455 | input_pcs = origins[..., None, None, :] + (dirs_w[:, :, None, :] * 456 | sampled_z[..., None]) 457 | input_pcs -= self.obj_center 458 | obj_labels = sampled_rgbs[..., -1].view(-1) 459 | return sampled_rgbs[..., :3], sampled_depth, valid_depth_mask, obj_labels, input_pcs, sampled_z 460 | 461 | def save_checkpoints(self, path, epoch): 462 | obj_id = self.obj_id 463 | chechpoint_load_file = (path + "/obj_" + str(obj_id) + "_frame_" + str(epoch) + ".pth") 464 | 465 | torch.save( 466 | { 467 | "epoch": epoch, 468 | "FC_state_dict": self.trainer.fc_occ_map.state_dict(), 469 | "PE_state_dict": self.trainer.pe.state_dict(), 470 | "obj_id": self.obj_id, 471 | "bbox": self.bbox3d, 472 | "obj_scale": self.trainer.obj_scale 473 | }, 474 | chechpoint_load_file, 475 | ) 476 | # optimiser? 477 | 478 | def load_checkpoints(self, ckpt_file): 479 | checkpoint_load_file = (ckpt_file) 480 | if not os.path.exists(checkpoint_load_file): 481 | print("ckpt not exist ", checkpoint_load_file) 482 | return 483 | checkpoint = torch.load(checkpoint_load_file) 484 | self.trainer.fc_occ_map.load_state_dict(checkpoint["FC_state_dict"]) 485 | self.trainer.pe.load_state_dict(checkpoint["PE_state_dict"]) 486 | self.obj_id = checkpoint["obj_id"] 487 | self.bbox3d = checkpoint["bbox"] 488 | self.trainer.obj_scale = checkpoint["obj_scale"] 489 | 490 | self.trainer.fc_occ_map.to(self.training_device) 491 | self.trainer.pe.to(self.training_device) 492 | 493 | 494 | class cameraInfo: 495 | 496 | def __init__(self, cfg) -> None: 497 | self.device = cfg.data_device 498 | self.width = cfg.W # Frame width 499 | self.height = cfg.H # Frame height 500 | 501 | self.fx = cfg.fx 502 | self.fy = cfg.fy 503 | self.cx = cfg.cx 504 | self.cy = cfg.cy 505 | 506 | self.rays_dir_cache = self.get_rays_dirs() 507 | 508 | def get_rays_dirs(self, depth_type="z"): 509 | idx_w = torch.arange(end=self.width, device=self.device) 510 | idx_h = torch.arange(end=self.height, device=self.device) 511 | 512 | dirs = torch.ones((self.width, self.height, 3), device=self.device) 513 | 514 | dirs[:, :, 0] = ((idx_w - self.cx) / self.fx)[:, None] 515 | dirs[:, :, 1] = ((idx_h - self.cy) / self.fy) 516 | 517 | if depth_type == "euclidean": 518 | raise Exception( 519 | "Get camera rays directions with euclidean depth not yet implemented" 520 | ) 521 | norm = torch.norm(dirs, dim=-1) 522 | dirs = dirs * (1. / norm)[:, :, :, None] 523 | 524 | return dirs 525 | 526 | --------------------------------------------------------------------------------
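As a closing reference, the two depth samplers defined near the top of `vmap.py` can be exercised on their own. The sketch below uses CPU tensors and made-up ray/bin counts and eps (none of these come from a config), and assumes the repository root is importable with its dependencies installed; it mirrors how `sceneObject.sample_3d_points` combines the samplers for pixels belonging to the object: coarse stratified bins from a near bound up to just in front of the measured depth, then a narrow truncated-normal band around the surface.

```python
# Illustrative use of the samplers from vmap.py; all numbers are made up.
import torch
from vmap import stratified_bins, normal_bins_sampling

n_rays, n_bins_cam2surface, n_bins, eps = 4, 3, 6, 0.1
depth = torch.tensor([1.2, 2.0, 0.8, 1.5])  # per-ray surface depth in metres

# coarse samples between a near bound and just in front of the surface
coarse = stratified_bins(0.05, depth - eps, n_bins_cam2surface, n_rays, device="cpu")

# fine samples in a +/- eps band around the surface (sorted, clipped normal draws)
fine = normal_bins_sampling(depth, n_bins, n_rays, delta=eps, device="cpu")

z_vals = torch.cat([coarse, fine], dim=-1)
print(z_vals.shape)  # torch.Size([4, 9]) == (n_rays, n_bins_cam2surface + n_bins)
```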