├── media
│   └── teaser.png
├── data_generation
│   ├── replica_render_config_vMAP.yaml
│   ├── README.md
│   ├── extract_inst_obj.py
│   ├── transformation.py
│   ├── settings.py
│   └── habitat_renderer.py
├── vis.py
├── metric
│   ├── metrics.py
│   ├── eval_3D_scene.py
│   └── eval_3D_obj.py
├── image_transforms.py
├── configs
│   ├── ScanNet
│   │   ├── config_scannet0000_iMAP.json
│   │   ├── config_scannet0024_iMAP.json
│   │   ├── config_scannet0000_vMAP.json
│   │   └── config_scannet0024_vMAP.json
│   └── Replica
│       ├── config_replica_room0_iMAP.json
│       └── config_replica_room0_vMAP.json
├── .gitignore
├── model.py
├── loss.py
├── environment.yml
├── embedding.py
├── trainer.py
├── cfg.py
├── render_rays.py
├── README.md
├── LICENSE
├── dataset.py
├── utils.py
├── train.py
└── vmap.py
/media/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kxhit/vMAP/HEAD/media/teaser.png
--------------------------------------------------------------------------------
/data_generation/replica_render_config_vMAP.yaml:
--------------------------------------------------------------------------------
1 | # Agent settings
2 | default_agent: 0
3 | gpu_id: 0
4 | width: 1200 #1280
5 | height: 680 #960
6 | sensor_height: 0
7 |
8 | color_sensor: true # RGB sensor
9 | semantic_sensor: true # Semantic sensor
10 | depth_sensor: true # Depth sensor
11 | enable_semantics: true
12 |
13 | # room_0
14 | scene_file: "/home/xin/data/vmap/room_0/habitat/mesh_semantic.ply"
15 | instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json"
16 | save_path: "/home/xin/data/vmap/room_0/vmap/00/"
17 | pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt"
18 | ## HDR texture
19 | ## issue https://github.com/facebookresearch/Replica-Dataset/issues/41#issuecomment-566251467
20 | #scene_file: "/home/xin/data/vmap/room_0/mesh.ply"
21 | #instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json"
22 | #save_path: "/home/xin/data/vmap/room_0/vmap/00/"
23 | #pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt"
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
1 | import skimage.measure
2 | import trimesh
3 | import open3d as o3d
4 | import numpy as np
5 |
6 | def marching_cubes(occupancy, level=0.5):
7 | try:
8 |         vertices, faces, vertex_normals, _ = skimage.measure.marching_cubes(  # marching_cubes_lewiner in older scikit-image
9 | occupancy, level=level, gradient_direction='ascent')
10 | except (RuntimeError, ValueError):
11 | return None
12 |
13 | dim = occupancy.shape[0]
14 | vertices = vertices / (dim - 1)
15 | mesh = trimesh.Trimesh(vertices=vertices,
16 | vertex_normals=vertex_normals,
17 | faces=faces)
18 |
19 | return mesh
20 |
21 | def trimesh_to_open3d(src):
22 | dst = o3d.geometry.TriangleMesh()
23 | dst.vertices = o3d.utility.Vector3dVector(src.vertices)
24 | dst.triangles = o3d.utility.Vector3iVector(src.faces)
25 |     vertex_colors = src.visual.vertex_colors[:, :3].astype(np.float64) / 255.0  # np.float was removed in NumPy >= 1.24
26 | dst.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
27 | dst.compute_vertex_normals()
28 |
29 | return dst
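
# Usage sketch added for illustration (not part of the upstream file): build a toy
# occupancy grid, extract a mesh with marching_cubes(), and hand it to Open3D via
# trimesh_to_open3d(). Any real occupancy volume from the trainer works the same way.
if __name__ == "__main__":
    grid = np.linspace(-1.0, 1.0, 64)
    xx, yy, zz = np.meshgrid(grid, grid, grid, indexing="ij")
    occ = (np.sqrt(xx ** 2 + yy ** 2 + zz ** 2) < 0.5).astype(np.float32)
    toy_mesh = marching_cubes(occ, level=0.5)
    if toy_mesh is not None:
        # trimesh_to_open3d() reads per-vertex colours, so assign a uniform grey first.
        toy_mesh.visual.vertex_colors = np.tile(
            np.array([200, 200, 200, 255], dtype=np.uint8), (len(toy_mesh.vertices), 1))
        o3d.visualization.draw_geometries([trimesh_to_open3d(toy_mesh)])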
--------------------------------------------------------------------------------
/metric/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import cKDTree as KDTree
3 |
4 | def completion_ratio(gt_points, rec_points, dist_th=0.01):
5 | gen_points_kd_tree = KDTree(rec_points)
6 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points)
7 |     completion = np.mean((one_distances < dist_th).astype(np.float64))  # np.float was removed in NumPy >= 1.24
8 | return completion
9 |
10 |
11 | def accuracy(gt_points, rec_points):
12 | gt_points_kd_tree = KDTree(gt_points)
13 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points)
14 | gen_to_gt_chamfer = np.mean(two_distances)
15 | return gen_to_gt_chamfer
16 |
17 |
18 | def completion(gt_points, rec_points):
19 |     rec_points_kd_tree = KDTree(rec_points)
20 |     one_distances, one_vertex_ids = rec_points_kd_tree.query(gt_points)
21 |     gt_to_gen_chamfer = np.mean(one_distances)
22 |     return gt_to_gen_chamfer
23 |
24 |
25 | def chamfer(gt_points, rec_points):
26 | # one direction
27 | gen_points_kd_tree = KDTree(rec_points)
28 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points)
29 | gt_to_gen_chamfer = np.mean(one_distances)
30 |
31 | # other direction
32 | gt_points_kd_tree = KDTree(gt_points)
33 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points)
34 | gen_to_gt_chamfer = np.mean(two_distances)
35 |
36 | return (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2.
37 |
38 |
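# Usage sketch added for illustration (mirrors how metric/eval_3D_scene.py uses
# these functions): sample point clouds from a ground-truth and a reconstructed
# mesh with trimesh, then evaluate the distances (in mesh units). The mesh paths
# below are placeholders.
if __name__ == "__main__":
    import trimesh
    mesh_gt = trimesh.load("path/to/gt_mesh.ply")
    mesh_rec = trimesh.load("path/to/rec_mesh.ply")
    gt_points, _ = trimesh.sample.sample_surface(mesh_gt, 200000)
    rec_points, _ = trimesh.sample.sample_surface(mesh_rec, 200000)
    print("accuracy              ", accuracy(gt_points, rec_points))
    print("completion            ", completion(gt_points, rec_points))
    print("completion ratio @5cm ", completion_ratio(gt_points, rec_points, dist_th=0.05))
    print("chamfer               ", chamfer(gt_points, rec_points))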
--------------------------------------------------------------------------------
/image_transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | class BGRtoRGB(object):
6 | """bgr format to rgb"""
7 |
8 | def __call__(self, image):
9 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
10 | return image
11 |
12 |
13 | class DepthScale(object):
14 | """scale depth to meters"""
15 |
16 | def __init__(self, scale):
17 | self.scale = scale
18 |
19 | def __call__(self, depth):
20 | depth = depth.astype(np.float32)
21 | return depth * self.scale
22 |
23 |
24 | class DepthFilter(object):
25 | """scale depth to meters"""
26 |
27 | def __init__(self, max_depth):
28 | self.max_depth = max_depth
29 |
30 | def __call__(self, depth):
31 | far_mask = depth > self.max_depth
32 | depth[far_mask] = 0.
33 | return depth
34 |
35 |
36 | class Undistort(object):
37 | """scale depth to meters"""
38 |
39 | def __init__(self,
40 | w, h,
41 | fx, fy, cx, cy,
42 | k1, k2, k3, k4, k5, k6,
43 | p1, p2,
44 | interpolation):
45 | self.interpolation = interpolation
46 | K = np.array([[fx, 0., cx],
47 | [0., fy, cy],
48 | [0., 0., 1.]])
49 |
50 | self.map1x, self.map1y = cv2.initUndistortRectifyMap(
51 | K,
52 | np.array([k1, k2, p1, p2, k3, k4, k5, k6]),
53 | np.eye(3),
54 | K,
55 | (w, h),
56 | cv2.CV_32FC1)
57 |
58 | def __call__(self, im):
59 | im = cv2.remap(im, self.map1x, self.map1y, self.interpolation)
60 | return im
61 |
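# Usage sketch added for illustration (dataset.py is assumed to chain these
# callables in a similar way): convert a BGR frame to RGB, scale 16-bit depth in
# millimetres to metres, and zero out readings beyond the configured max depth.
if __name__ == "__main__":
    bgr = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
    raw_depth = (np.random.rand(480, 640) * 10000).astype(np.uint16)
    rgb = BGRtoRGB()(bgr)
    depth = DepthScale(1.0 / 1000.0)(raw_depth)   # matches the configs' "scale": 1000.0
    depth = DepthFilter(max_depth=8.0)(depth)     # e.g. the Replica depth_range max of 8 m
    print(rgb.shape, depth.min(), depth.max())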
--------------------------------------------------------------------------------
/data_generation/README.md:
--------------------------------------------------------------------------------
1 | ## Replica Data Generation
2 |
3 | ### Download Replica Dataset
4 | Download 3D models and info files from [Replica](https://github.com/facebookresearch/Replica-Dataset)
5 |
6 | ### 3D Object Mesh Extraction
7 | Change the input path in `./data_generation/extract_inst_obj.py` and run
8 | ```bash
9 | python ./data_generation/extract_inst_obj.py
10 | ```
11 |
12 | ### Camera Trajectory Generation
13 | Please refer to [Semantic-NeRF](https://github.com/Harry-Zhi/semantic_nerf/issues/25#issuecomment-1340595427) for more details. The random trajectory generation only works for single-room scenes; multi-room scenes additionally require collision checking. Contributions are welcome.
14 |
15 | ### Rendering 2D Images
16 | Given a camera trajectory T_wc (set `pose_file` in the config), we use [Habitat-Sim](https://github.com/facebookresearch/habitat-sim) to render RGB, depth, semantic, and instance images.
17 |
18 | #### Install Habitat-Sim 0.2.1
19 | We recommend using conda to install habitat-sim 0.2.1.
20 | ```bash
21 | conda create -n habitat python=3.8.12 cmake=3.14.0
22 | conda activate habitat
23 | conda install habitat-sim=0.2.1 withbullet -c conda-forge -c aihabitat
24 | conda install numba=0.54.1
25 | ```
26 |
27 | #### Run rendering with configs
28 | ```bash
29 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml
30 | ```
31 | Note that to obtain HDR images, render with `mesh.ply` instead of `mesh_semantic.ply` (change the paths in the config), then copy the resulting `rgb` folder over the previously rendered, over-exposed RGB images.
32 | ```bash
33 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml
34 | ```
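
The renderer consumes the camera trajectory given by `pose_file`. Below is a minimal sketch for loading such a trajectory for inspection, assuming each line stores a flattened 4x4 camera-to-world matrix T_wc (the format used by the Semantic-NeRF Replica sequences):
```python
import numpy as np

# Hypothetical path; point this at the pose_file set in the rendering config.
poses = np.loadtxt("/path/to/traj_w_c.txt").reshape(-1, 4, 4)
print("number of frames:", poses.shape[0])
print("first camera-to-world pose:\n", poses[0])
```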
--------------------------------------------------------------------------------
/data_generation/extract_inst_obj.py:
--------------------------------------------------------------------------------
1 | # reference https://github.com/facebookresearch/Replica-Dataset/issues/17#issuecomment-538757418
2 |
3 | from plyfile import *
4 | import numpy as np
5 | import trimesh
6 |
7 |
8 | # path_in = 'path/to/mesh_semantic.ply'
9 | path_in = '/home/xin/data/vmap/room_0_debug/habitat/mesh_semantic.ply'
10 |
11 | print("Reading input...")
12 | mesh = trimesh.load(path_in)
13 | # mesh.show()
14 | file_in = PlyData.read(path_in)
15 | vertices_in = file_in.elements[0]
16 | faces_in = file_in.elements[1]
17 |
18 | print("Filtering data...")
19 | objects = {}
20 | sub_mesh_indices = {}
21 | for i, f in enumerate(faces_in):
22 | object_id = f[1]
23 | if not object_id in objects:
24 | objects[object_id] = []
25 | sub_mesh_indices[object_id] = []
26 | objects[object_id].append((f[0],))
27 | sub_mesh_indices[object_id].append(i)
28 | sub_mesh_indices[object_id].append(i+faces_in.data.shape[0])
29 |
30 |
31 | print("Writing data...")
32 | for object_id, faces in objects.items():
33 | path_out = path_in + f"_{object_id}.ply"
34 | # print("sub_mesh_indices[object_id] ", sub_mesh_indices[object_id])
35 | obj_mesh = mesh.submesh([sub_mesh_indices[object_id]], append=True)
36 | in_n = len(sub_mesh_indices[object_id])
37 | out_n = obj_mesh.faces.shape[0]
38 | # print("obj id ", object_id)
39 | # print("in_n ", in_n)
40 | # print("out_n ", out_n)
41 | # print("faces ", len(faces))
42 | # assert in_n == out_n
43 | obj_mesh.export(path_out)
44 | # faces_out = PlyElement.describe(np.array(faces, dtype=[('vertex_indices', 'O')]), 'face')
45 | # print("faces out ", len(PlyData([vertices_in, faces_out]).elements[1].data))
46 | # PlyData([vertices_in, faces_out]).write(path_out+"_cmp.ply")
47 |
48 |
--------------------------------------------------------------------------------
/data_generation/transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import quaternion
3 | import trimesh
4 |
5 | def habitat_world_transformations():
6 | import habitat_sim
7 | # Transforms between the habitat frame H (y-up) and the world frame W (z-up).
8 | T_wh = np.identity(4)
9 |
10 | # https://stackoverflow.com/questions/1171849/finding-quaternion-representing-the-rotation-from-one-vector-to-another
11 | T_wh[0:3, 0:3] = quaternion.as_rotation_matrix(habitat_sim.utils.common.quat_from_two_vectors(
12 | habitat_sim.geo.GRAVITY, np.array([0.0, 0.0, -1.0])))
13 |
14 | T_hw = np.linalg.inv(T_wh)
15 |
16 | return T_wh, T_hw
17 |
18 | def opencv_to_opengl_camera(transform=None):
19 | if transform is None:
20 | transform = np.eye(4)
21 | return transform @ trimesh.transformations.rotation_matrix(
22 | np.deg2rad(180), [1, 0, 0]
23 | )
24 |
25 | def opengl_to_opencv_camera(transform=None):
26 | if transform is None:
27 | transform = np.eye(4)
28 | return transform @ trimesh.transformations.rotation_matrix(
29 | np.deg2rad(-180), [1, 0, 0]
30 | )
31 |
32 | def Twc_to_Thc(T_wc):  # OpenCV camera-to-world transformation ---> Habitat camera-to-Habitat-world transformation
33 | T_wh, T_hw = habitat_world_transformations()
34 | T_hc = T_hw @ T_wc @ opengl_to_opencv_camera()
35 | return T_hc
36 |
37 |
38 | def Thc_to_Twc(T_hc):  # Habitat camera-to-Habitat-world transformation ---> OpenCV camera-to-world transformation
39 | T_wh, T_hw = habitat_world_transformations()
40 | T_wc = T_wh @ T_hc @ opencv_to_opengl_camera()
41 | return T_wc
42 |
43 |
44 | def combine_pose(t: np.array, q: quaternion.quaternion) -> np.array:
45 | T = np.identity(4)
46 | T[0:3, 3] = t
47 | T[0:3, 0:3] = quaternion.as_rotation_matrix(q)
48 | return T
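
# Usage sketch added for illustration: build a camera-to-world pose from a
# translation and a quaternion, map it into the Habitat convention and back.
# Requires habitat_sim to be importable (see data_generation/README.md).
if __name__ == "__main__":
    t = np.array([1.0, 2.0, 0.5])
    q = quaternion.from_rotation_matrix(np.eye(3))
    T_wc = combine_pose(t, q)
    T_hc = Twc_to_Thc(T_wc)
    T_wc_back = Thc_to_Twc(T_hc)
    print("round-trip error:", np.abs(T_wc - T_wc_back).max())  # should be ~0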
--------------------------------------------------------------------------------
/configs/ScanNet/config_scannet0000_iMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/ScanNet/NICESLAM/scene0000_00",
5 | "format": "ScanNet",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 1,
17 | "do_bg": 0,
18 | "n_models": 1,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 6.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 5,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 2400,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 3.0,
37 | "bg_scale": 15.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 50,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 256,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 640,
54 | "h": 480,
55 | "mw": 10,
56 | "mh": 10
57 | },
58 | "vis": {
59 | "vis_device": "cuda:0",
60 | "n_vis_iter": 500,
61 | "n_bins_fine_vis": 10,
62 | "im_vis_reduce": 10,
63 | "grid_dim": 256,
64 | "live_vis": 1,
65 | "live_voxel_size": 0.005
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/configs/ScanNet/config_scannet0024_iMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/ScanNet/obj-imap/scene0024_00",
5 | "format": "ScanNet",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 1,
17 | "do_bg": 0,
18 | "n_models": 1,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 6.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 5,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 2400,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 2.0,
37 | "bg_scale": 5.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 50,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 256,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 640,
54 | "h": 480,
55 | "mw": 10,
56 | "mh": 10
57 | },
58 | "vis": {
59 | "vis_device": "cuda:0",
60 | "n_vis_iter": 500,
61 | "n_bins_fine_vis": 10,
62 | "im_vis_reduce": 10,
63 | "grid_dim": 256,
64 | "live_vis": 1,
65 | "live_voxel_size": 0.005
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/configs/ScanNet/config_scannet0000_vMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/ScanNet/NICESLAM/scene0000_00",
5 | "format": "ScanNet",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 0,
17 | "do_bg": 1,
18 | "n_models": 100,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 6.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 1,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 120,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 3.0,
37 | "bg_scale": 10.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 25,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 32,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 640,
54 | "h": 480,
55 | "mw": 10,
56 | "mh": 10
57 | },
58 | "vis": {
59 | "vis_device": "cuda:0",
60 | "n_vis_iter": 10000000,
61 | "n_bins_fine_vis": 10,
62 | "im_vis_reduce": 10,
63 | "grid_dim": 256,
64 | "live_vis": 1,
65 | "live_voxel_size": 0.005
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/configs/ScanNet/config_scannet0024_vMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/ScanNet/obj-imap/scene0024_00",
5 | "format": "ScanNet",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 0,
17 | "do_bg": 1,
18 | "n_models": 100,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 6.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 1,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 120,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 3.0,
37 | "bg_scale": 10.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 25,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 32,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 640,
54 | "h": 480,
55 | "mw": 10,
56 | "mh": 10
57 | },
58 | "vis": {
59 | "vis_device": "cuda:0",
60 | "n_vis_iter": 10000000,
61 | "n_bins_fine_vis": 10,
62 | "im_vis_reduce": 10,
63 | "grid_dim": 256,
64 | "live_vis": 1,
65 | "live_voxel_size": 0.005
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/configs/Replica/config_replica_room0_iMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/Replica/vmap/room_0/imap/00",
5 | "format": "Replica",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 1,
17 | "do_bg": 0,
18 | "n_models": 1,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 8.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 5,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 4800,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 5.0,
37 | "bg_scale": 5.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 50,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 256,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 1200,
54 | "h": 680,
55 | "fx": 600.0,
56 | "fy": 600.0,
57 | "cx": 599.5,
58 | "cy": 339.5,
59 | "mw": 0,
60 | "mh": 0
61 | },
62 | "vis": {
63 | "vis_device": "cuda:0",
64 | "n_vis_iter": 500,
65 | "n_bins_fine_vis": 10,
66 | "im_vis_reduce": 10,
67 | "grid_dim": 256,
68 | "live_vis": 1,
69 | "live_voxel_size": 0.005
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/configs/Replica/config_replica_room0_vMAP.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "live": 0,
4 | "path": "/home/xin/data/Replica/vmap/room_0/imap/00",
5 | "format": "Replica",
6 | "keep_alive": 20
7 | },
8 | "optimizer": {
9 | "args":{
10 | "lr": 0.001,
11 | "weight_decay": 0.013,
12 | "pose_lr": 0.001
13 | }
14 | },
15 | "trainer": {
16 | "imap_mode": 0,
17 | "do_bg": 1,
18 | "n_models": 100,
19 | "train_device": "cuda:0",
20 | "data_device": "cuda:0",
21 | "training_strategy": "vmap",
22 | "epochs": 1000000,
23 | "scale": 1000.0
24 | },
25 | "render": {
26 | "depth_range": [0.0, 8.0],
27 | "n_bins": 9,
28 | "n_bins_cam2surface": 1,
29 | "n_bins_cam2surface_bg": 5,
30 | "iters_per_frame": 20,
31 | "n_per_optim": 120,
32 | "n_per_optim_bg": 1200
33 | },
34 | "model": {
35 | "n_unidir_funcs": 5,
36 | "obj_scale": 2.0,
37 | "bg_scale": 5.0,
38 | "color_scaling": 5.0,
39 | "opacity_scaling": 10.0,
40 | "gt_scene": 1,
41 | "surface_eps": 0.1,
42 | "other_eps": 0.05,
43 | "keyframe_buffer_size": 20,
44 | "keyframe_step": 25,
45 | "keyframe_step_bg": 50,
46 | "window_size": 5,
47 | "window_size_bg": 10,
48 | "hidden_layers_block": 1,
49 | "hidden_feature_size": 32,
50 | "hidden_feature_size_bg": 128
51 | },
52 | "camera": {
53 | "w": 1200,
54 | "h": 680,
55 | "fx": 600.0,
56 | "fy": 600.0,
57 | "cx": 599.5,
58 | "cy": 339.5,
59 | "mw": 0,
60 | "mh": 0
61 | },
62 | "vis": {
63 | "vis_device": "cuda:0",
64 | "n_vis_iter": 500,
65 | "n_bins_fine_vis": 10,
66 | "im_vis_reduce": 10,
67 | "grid_dim": 256,
68 | "live_vis": 1,
69 | "live_voxel_size": 0.005
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.anaconda3/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | HierarchicalPriors.sublime-project
128 | HierarchicalPriors.sublime-workspace
129 | scripts/res1/
130 | scripts/results/
131 | scripts/traj.txt
132 | vids_ims
133 | res_datasets/
134 | results/
135 | ScenePriors/train/examples/experiments/
136 |
137 | *.pkl
138 | *.idea
139 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def init_weights(m, init_fn=torch.nn.init.xavier_normal_):
5 | if type(m) == torch.nn.Linear:
6 | init_fn(m.weight)
7 |
8 |
9 | def fc_block(in_f, out_f):
10 |     return torch.nn.Sequential(
11 |         torch.nn.Linear(in_f, out_f),
12 |         torch.nn.ReLU(inplace=True)  # nn.ReLU takes `inplace`, not a feature size; this keeps the original in-place behaviour
13 |     )
14 |
15 |
16 | class OccupancyMap(torch.nn.Module):
17 | def __init__(
18 | self,
19 | emb_size1,
20 | emb_size2,
21 | hidden_size=256,
22 | do_color=True,
23 | hidden_layers_block=1
24 | ):
25 | super(OccupancyMap, self).__init__()
26 | self.do_color = do_color
27 | self.embedding_size1 = emb_size1
28 | self.in_layer = fc_block(self.embedding_size1, hidden_size)
29 |
30 | hidden1 = [fc_block(hidden_size, hidden_size)
31 | for _ in range(hidden_layers_block)]
32 | self.mid1 = torch.nn.Sequential(*hidden1)
33 | # self.embedding_size2 = 21*(5+1)+3 - self.embedding_size # 129-66=63 32
34 | self.embedding_size2 = emb_size2
35 | self.cat_layer = fc_block(
36 | hidden_size + self.embedding_size1, hidden_size)
37 |
38 | # self.cat_layer = fc_block(
39 | # hidden_size , hidden_size)
40 |
41 | hidden2 = [fc_block(hidden_size, hidden_size)
42 | for _ in range(hidden_layers_block)]
43 | self.mid2 = torch.nn.Sequential(*hidden2)
44 |
45 | self.out_alpha = torch.nn.Linear(hidden_size, 1)
46 |
47 | if self.do_color:
48 | self.color_linear = fc_block(self.embedding_size2 + hidden_size, hidden_size)
49 | self.out_color = torch.nn.Linear(hidden_size, 3)
50 |
51 | # self.relu = torch.nn.functional.relu
52 | self.sigmoid = torch.sigmoid
53 |
54 | def forward(self, x,
55 | noise_std=None,
56 | do_alpha=True,
57 | do_color=True,
58 | do_cat=True):
59 | fc1 = self.in_layer(x[...,:self.embedding_size1])
60 | fc2 = self.mid1(fc1)
61 | # fc3 = self.cat_layer(fc2)
62 | if do_cat:
63 | fc2_x = torch.cat((fc2, x[...,:self.embedding_size1]), dim=-1)
64 | fc3 = self.cat_layer(fc2_x)
65 | else:
66 | fc3 = fc2
67 | fc4 = self.mid2(fc3)
68 |
69 | alpha = None
70 | if do_alpha:
71 | raw = self.out_alpha(fc4) # todo ignore noise
72 | if noise_std is not None:
73 | noise = torch.randn(raw.shape, device=x.device) * noise_std
74 | raw = raw + noise
75 |
76 | # alpha = self.relu(raw) * scale # nerf
77 | alpha = raw * 10. #self.scale # unisurf
78 |
79 | color = None
80 | if self.do_color and do_color:
81 | fc4_cat = self.color_linear(torch.cat((fc4, x[..., self.embedding_size1:]), dim=-1))
82 | raw_color = self.out_color(fc4_cat)
83 | color = self.sigmoid(raw_color)
84 |
85 | return alpha, color
86 |
87 |
88 |
89 |
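# Usage sketch added for illustration (mirrors trainer.Trainer.load_network): the
# network input is the concatenated positional embedding, with the first emb_size1
# channels feeding the occupancy branch and the remaining emb_size2 channels the
# colour branch. The embedding sizes below are the values set in trainer.py.
if __name__ == "__main__":
    emb_size1 = 21 * (3 + 1) + 3
    emb_size2 = 21 * (5 + 1) + 3 - emb_size1
    net = OccupancyMap(emb_size1, emb_size2, hidden_size=32)
    x = torch.randn(4, 10, emb_size1 + emb_size2)  # [rays, samples per ray, embedding]
    alpha, color = net(x)
    print(alpha.shape, color.shape)  # (4, 10, 1), (4, 10, 3)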
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import render_rays
3 | import torch.nn.functional as F
4 |
5 | def step_batch_loss(alpha, color, gt_depth, gt_color, sem_labels, mask_depth, z_vals,
6 | color_scaling=5.0, opacity_scaling=10.0):
7 | """
8 |     apply depth loss where depth is valid -> mask_depth
9 |     apply depth, color loss on this_obj & unknown_obj == (~other_obj) -> mask_obj
10 |     apply occupancy/opacity loss on this_obj & other_obj == (~unknown_obj) -> mask_sem
11 |
12 | output:
13 | loss for training
14 | loss_all for per sample, could be used for active sampling, replay buffer
15 | """
16 | mask_obj = sem_labels != 0
17 | mask_obj = mask_obj.detach()
18 | mask_sem = sem_labels != 2
19 | mask_sem = mask_sem.detach()
20 |
21 | alpha = alpha.squeeze(dim=-1)
22 | color = color.squeeze(dim=-1)
23 |
24 | occupancy = render_rays.occupancy_activation(alpha)
25 | termination = render_rays.occupancy_to_termination(occupancy, is_batch=True) # shape [num_batch, num_ray, points_per_ray]
26 |
27 | render_depth = render_rays.render(termination, z_vals)
28 | diff_sq = (z_vals - render_depth[..., None]) ** 2
29 | var = render_rays.render(termination, diff_sq).detach() # must detach here!
30 | render_color = render_rays.render(termination[..., None], color, dim=-2)
31 | render_opacity = torch.sum(termination, dim=-1) # similar to obj-nerf opacity loss
32 |
33 | # 2D depth loss: only on valid depth & mask
34 | # [mask_depth & mask_obj]
35 | # loss_all = torch.zeros_like(render_depth)
36 | loss_depth_raw = render_rays.render_loss(render_depth, gt_depth, loss="L1", normalise=False)
37 | loss_depth = torch.mul(loss_depth_raw, mask_depth & mask_obj) # keep dim but set invalid element be zero
38 | # loss_all += loss_depth
39 | loss_depth = render_rays.reduce_batch_loss(loss_depth, var=var, avg=True, mask=mask_depth & mask_obj) # apply var as imap
40 |
41 | # 2D color loss: only on obj mask
42 | # [mask_obj]
43 | loss_col_raw = render_rays.render_loss(render_color, gt_color, loss="L1", normalise=False)
44 | loss_col = torch.mul(loss_col_raw.sum(-1), mask_obj)
45 | # loss_all += loss_col / 3. * color_scaling
46 | loss_col = render_rays.reduce_batch_loss(loss_col, var=None, avg=True, mask=mask_obj)
47 |
48 | # 2D occupancy/opacity loss: apply except unknown area
49 | # [mask_sem]
50 | # loss_opacity_raw = F.mse_loss(torch.clamp(render_opacity, 0, 1), mask_obj.float().detach()) # encourage other_obj to be empty, while this_obj to be solid
51 | # print("opacity max ", torch.max(render_opacity.max()))
52 | # print("opacity min ", torch.max(render_opacity.min()))
53 | loss_opacity_raw = render_rays.render_loss(render_opacity, mask_obj.float(), loss="L1", normalise=False)
54 |     loss_opacity = torch.mul(loss_opacity_raw, mask_sem)  # but ignore the unknown area, e.g., mask edges
55 | # loss_all += loss_opacity * opacity_scaling
56 | loss_opacity = render_rays.reduce_batch_loss(loss_opacity, var=None, avg=True, mask=mask_sem) # todo var
57 |
58 | # loss for bp
59 | l_batch = loss_depth + loss_col * color_scaling + loss_opacity * opacity_scaling
60 | loss = l_batch.sum()
61 |
62 | return loss, None # return loss, loss_all.detach()
63 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: vmap
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - brotlipy=0.7.0=py38h27cfd23_1003
10 | - bzip2=1.0.8=h7b6447c_0
11 | - ca-certificates=2022.10.11=h06a4308_0
12 | - certifi=2022.9.24=py38h06a4308_0
13 | - cffi=1.15.1=py38h5eee18b_2
14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
15 | - cryptography=38.0.1=py38h9ce1e76_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - ffmpeg=4.3=hf484d3e_0
18 | - freetype=2.12.1=h4a9f257_0
19 | - giflib=5.2.1=h7b6447c_0
20 | - gmp=6.2.1=h295c915_3
21 | - gnutls=3.6.15=he1e5248_0
22 | - idna=3.4=py38h06a4308_0
23 | - intel-openmp=2021.4.0=h06a4308_3561
24 | - jpeg=9e=h7f8727e_0
25 | - lame=3.100=h7b6447c_0
26 | - lcms2=2.12=h3be6417_0
27 | - ld_impl_linux-64=2.38=h1181459_1
28 | - lerc=3.0=h295c915_0
29 | - libdeflate=1.8=h7f8727e_5
30 | - libffi=3.4.2=h6a678d5_6
31 | - libgcc-ng=11.2.0=h1234567_1
32 | - libgomp=11.2.0=h1234567_1
33 | - libiconv=1.16=h7f8727e_2
34 | - libidn2=2.3.2=h7f8727e_0
35 | - libpng=1.6.37=hbc83047_0
36 | - libstdcxx-ng=11.2.0=h1234567_1
37 | - libtasn1=4.16.0=h27cfd23_0
38 | - libtiff=4.4.0=hecacb30_2
39 | - libunistring=0.9.10=h27cfd23_0
40 | - libwebp=1.2.4=h11a3e52_0
41 | - libwebp-base=1.2.4=h5eee18b_0
42 | - lz4-c=1.9.3=h295c915_1
43 | - mkl=2021.4.0=h06a4308_640
44 | - mkl-service=2.4.0=py38h7f8727e_0
45 | - mkl_fft=1.3.1=py38hd3c417c_0
46 | - mkl_random=1.2.2=py38h51133e4_0
47 | - ncurses=6.3=h5eee18b_3
48 | - nettle=3.7.3=hbbd107a_1
49 | - numpy=1.23.4=py38h14f4228_0
50 | - numpy-base=1.23.4=py38h31eccc5_0
51 | - openh264=2.1.1=h4ff587b_0
52 | - openssl=1.1.1s=h7f8727e_0
53 | - pillow=9.2.0=py38hace64e9_1
54 | - pip=22.2.2=py38h06a4308_0
55 | - pycparser=2.21=pyhd3eb1b0_0
56 | - pyopenssl=22.0.0=pyhd3eb1b0_0
57 | - pysocks=1.7.1=py38h06a4308_0
58 | - python=3.8.15=h7a1cb2a_2
59 | - pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0
60 | - pytorch-mutex=1.0=cuda
61 | - readline=8.2=h5eee18b_0
62 | - requests=2.28.1=py38h06a4308_0
63 | - setuptools=65.5.0=py38h06a4308_0
64 | - six=1.16.0=pyhd3eb1b0_1
65 | - sqlite=3.40.0=h5082296_0
66 | - tk=8.6.12=h1ccaba5_0
67 | - torchaudio=0.12.1=py38_cu113
68 | - torchvision=0.13.1=py38_cu113
69 | - typing_extensions=4.3.0=py38h06a4308_0
70 | - urllib3=1.26.12=py38h06a4308_0
71 | - wheel=0.37.1=pyhd3eb1b0_0
72 | - xz=5.2.6=h5eee18b_0
73 | - zlib=1.2.13=h5eee18b_0
74 | - zstd=1.5.2=ha4553b6_0
75 | - pip:
76 | - antlr4-python3-runtime==4.9.3
77 | - bidict==0.22.0
78 | - click==8.1.3
79 | - dash==2.7.0
80 | - dash-core-components==2.0.0
81 | - dash-html-components==2.0.0
82 | - dash-table==5.0.0
83 | - entrypoints==0.4
84 | - flask==2.2.2
85 | - functorch==0.2.0
86 | - fvcore==0.1.5.post20221122
87 | - h5py==3.7.0
88 | - imageio==2.22.4
89 | - importlib-metadata==5.1.0
90 | - iopath==0.1.10
91 | - itsdangerous==2.1.2
92 | - lpips==0.1.4
93 | - nbformat==5.5.0
94 | - omegaconf==2.2.3
95 | - open3d==0.16.0
96 | - pexpect==4.8.0
97 | - plotly==5.11.0
98 | - pycocotools==2.0.6
99 | - pyquaternion==0.9.9
100 | - scikit-image==0.19.3
101 | - tenacity==8.1.0
102 | - termcolor==2.1.1
103 | - timm==0.6.12
104 | - werkzeug==2.2.2
105 | - yacs==0.1.8
106 | - zipp==3.11.0
107 |
108 |
--------------------------------------------------------------------------------
/embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def positional_encoding(
5 | tensor,
6 | B_layer=None,
7 | num_encoding_functions=6,
8 | scale=10.
9 | ):
10 | if B_layer is not None:
11 | embedding_gauss = B_layer(tensor / scale)
12 | embedding_gauss = torch.sin(embedding_gauss)
13 | embedding = embedding_gauss
14 | else:
15 | frequency_bands = 2.0 ** torch.linspace(
16 | 0.0,
17 | num_encoding_functions - 1,
18 | num_encoding_functions,
19 | dtype=tensor.dtype,
20 | device=tensor.device,
21 | )
22 |
23 | n_repeat = num_encoding_functions * 2 + 1
24 | embedding = tensor[..., None, :].repeat(1, 1, n_repeat, 1) / scale
25 | even_idx = np.arange(1, num_encoding_functions + 1) * 2
26 | odd_idx = even_idx - 1
27 |
28 | frequency_bands = frequency_bands[None, None, :, None]
29 |
30 | embedding[:, :, even_idx, :] = torch.cos(
31 | frequency_bands * embedding[:, :, even_idx, :])
32 | embedding[:, :, odd_idx, :] = torch.sin(
33 | frequency_bands * embedding[:, :, odd_idx, :])
34 |
35 | n_dim = tensor.shape[-1]
36 | embedding = embedding.view(
37 | embedding.shape[0], embedding.shape[1], n_repeat * n_dim)
38 | # print("embedding ", embedding.shape)
39 | embedding = embedding.squeeze(0)
40 |
41 | return embedding
42 |
43 | class UniDirsEmbed(torch.nn.Module):
44 | def __init__(self, min_deg=0, max_deg=2, scale=2.):
45 | super(UniDirsEmbed, self).__init__()
46 | self.min_deg = min_deg
47 | self.max_deg = max_deg
48 | self.n_freqs = max_deg - min_deg + 1
49 | self.tensor_scale = torch.tensor(scale, requires_grad=False)
50 |
51 | dirs = torch.tensor([
52 | 0.8506508, 0, 0.5257311,
53 | 0.809017, 0.5, 0.309017,
54 | 0.5257311, 0.8506508, 0,
55 | 1, 0, 0,
56 | 0.809017, 0.5, -0.309017,
57 | 0.8506508, 0, -0.5257311,
58 | 0.309017, 0.809017, -0.5,
59 | 0, 0.5257311, -0.8506508,
60 | 0.5, 0.309017, -0.809017,
61 | 0, 1, 0,
62 | -0.5257311, 0.8506508, 0,
63 | -0.309017, 0.809017, -0.5,
64 | 0, 0.5257311, 0.8506508,
65 | -0.309017, 0.809017, 0.5,
66 | 0.309017, 0.809017, 0.5,
67 | 0.5, 0.309017, 0.809017,
68 | 0.5, -0.309017, 0.809017,
69 | 0, 0, 1,
70 | -0.5, 0.309017, 0.809017,
71 | -0.809017, 0.5, 0.309017,
72 | -0.809017, 0.5, -0.309017
73 | ]).reshape(-1, 3)
74 |
75 | self.B_layer = torch.nn.Linear(3, 21, bias=False)
76 | self.B_layer.weight.data = dirs
77 |
78 | frequency_bands = 2.0 ** torch.linspace(self.min_deg, self.max_deg, self.n_freqs)
79 | self.register_buffer("frequency_bands", frequency_bands, persistent=False)
80 | self.register_buffer("scale", self.tensor_scale, persistent=True)
81 |
82 | def forward(self, x):
83 | tensor = x / self.scale # functorch needs buffer, otherwise changed
84 | proj = self.B_layer(tensor)
85 | proj_bands = proj[..., None, :] * self.frequency_bands[None, None, :, None]
86 | xb = proj_bands.view(list(proj.shape[:-1]) + [-1])
87 | # embedding = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
88 | embedding = torch.sin(xb * np.pi)
89 | embedding = torch.cat([tensor] + [embedding], dim=-1)
90 | # print("emb size ", embedding.shape)
91 | return embedding
--------------------------------------------------------------------------------
/metric/eval_3D_scene.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import trimesh
4 | from metrics import accuracy, completion, completion_ratio
5 | import os
6 |
7 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000):
8 | """
9 | 3D reconstruction metric.
10 | """
11 | metrics = [[] for _ in range(4)]
12 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N)
13 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])
14 |
15 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N)
16 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])
17 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices)
18 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices)
19 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05)
20 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01)
21 |
22 | # accuracy_rec *= 100 # convert to cm
23 | # completion_rec *= 100 # convert to cm
24 | # completion_ratio_rec *= 100 # convert to %
25 | # print('accuracy: ', accuracy_rec)
26 | # print('completion: ', completion_rec)
27 | # print('completion ratio: ', completion_ratio_rec)
28 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1)
29 | metrics[0].append(accuracy_rec)
30 | metrics[1].append(completion_rec)
31 | metrics[2].append(completion_ratio_rec_1)
32 | metrics[3].append(completion_ratio_rec)
33 | return metrics
34 |
35 |
36 | if __name__ == "__main__":
37 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
38 | data_dir = "/home/xin/data/vmap/"
39 | # log_dir = "../logs/iMAP/"
40 | log_dir = "../logs/vMAP/"
41 |
42 | for exp in tqdm(exp_name):
43 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat")
44 | exp_dir = os.path.join(log_dir, exp)
45 | mesh_dir = os.path.join(exp_dir, "scene_mesh")
46 | output_path = os.path.join(exp_dir, "eval_mesh")
47 | os.makedirs(output_path, exist_ok=True)
48 | if "vMAP" in exp_dir:
49 | mesh_list = os.listdir(mesh_dir)
50 | if "frame_1999_scene.obj" in mesh_list:
51 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_scene.obj")
52 | else: # compose obj into scene mesh
53 | scene_meshes = []
54 | for f in mesh_list:
55 | _, f_type = os.path.splitext(f)
56 | if f_type == ".obj" or f_type == ".ply":
57 | obj_mesh = trimesh.load(os.path.join(mesh_dir, f))
58 | scene_meshes.append(obj_mesh)
59 | scene_mesh = trimesh.util.concatenate(scene_meshes)
60 | scene_mesh.export(os.path.join(mesh_dir, "frame_1999_scene.obj"))
61 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_scene.obj")
62 | elif "iMAP" in exp_dir: # obj0 is the scene mesh
63 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj")
64 | else:
65 | print("Not Implement")
66 | exit(-1)
67 | gt_mesh_files = os.listdir(gt_dir)
68 | gt_mesh_file = os.path.join(gt_dir, "../mesh.ply")
69 | mesh_rec = trimesh.load(rec_meshfile)
70 | # mesh_rec.invert() # niceslam mesh face needs invert
71 | metrics_3D = [[] for _ in range(4)]
72 | mesh_gt = trimesh.load(gt_mesh_file)
73 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=200000) # for objs use 10k, for scene use 200k points
74 | metrics_3D[0].append(metrics[0]) # acc
75 | metrics_3D[1].append(metrics[1]) # comp
76 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm
77 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm
78 | metrics_3D = np.array(metrics_3D)
79 | np.save(output_path + '/metrics_3D_scene.npy', metrics_3D)
80 | print("metrics 3D scene \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n ", metrics_3D.mean(axis=1))
81 | print("-----------------------------------------")
82 | print("finish exp ", exp)
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import model
3 | import embedding
4 | import render_rays
5 | import numpy as np
6 | import vis
7 | from tqdm import tqdm
8 |
9 | class Trainer:
10 | def __init__(self, cfg):
11 | self.obj_id = cfg.obj_id
12 | self.device = cfg.training_device
13 |         self.hidden_feature_size = cfg.hidden_feature_size  # 32 for obj, 256 for iMAP, 128 for separate bg
14 | self.obj_scale = cfg.obj_scale # 10 for bg and iMAP
15 | self.n_unidir_funcs = cfg.n_unidir_funcs
16 | self.emb_size1 = 21*(3+1)+3
17 | self.emb_size2 = 21*(5+1)+3 - self.emb_size1
18 |
19 | self.load_network()
20 |
21 | if self.obj_id == 0:
22 | self.bound_extent = 0.995
23 | else:
24 | self.bound_extent = 0.9
25 |
26 | def load_network(self):
27 | self.fc_occ_map = model.OccupancyMap(
28 | self.emb_size1,
29 | self.emb_size2,
30 | hidden_size=self.hidden_feature_size
31 | )
32 | self.fc_occ_map.apply(model.init_weights).to(self.device)
33 | self.pe = embedding.UniDirsEmbed(max_deg=self.n_unidir_funcs, scale=self.obj_scale).to(self.device)
34 |
35 | def meshing(self, bound, obj_center, grid_dim=256):
36 | occ_range = [-1., 1.]
37 | range_dist = occ_range[1] - occ_range[0]
38 | scene_scale_np = bound.extent / (range_dist * self.bound_extent)
39 | scene_scale = torch.from_numpy(scene_scale_np).float().to(self.device)
40 | transform_np = np.eye(4, dtype=np.float32)
41 | transform_np[:3, 3] = bound.center
42 | transform_np[:3, :3] = bound.R
43 | # transform_np = np.linalg.inv(transform_np) #
44 | transform = torch.from_numpy(transform_np).to(self.device)
45 | grid_pc = render_rays.make_3D_grid(occ_range=occ_range, dim=grid_dim, device=self.device,
46 | scale=scene_scale, transform=transform).view(-1, 3)
47 | grid_pc -= obj_center.to(grid_pc.device)
48 | ret = self.eval_points(grid_pc)
49 | if ret is None:
50 | return None
51 |
52 | occ, _ = ret
53 | mesh = vis.marching_cubes(occ.view(grid_dim, grid_dim, grid_dim).cpu().numpy())
54 | if mesh is None:
55 | print("marching cube failed")
56 | return None
57 |
58 | # Transform to [-1, 1] range
59 | mesh.apply_translation([-0.5, -0.5, -0.5])
60 | mesh.apply_scale(2)
61 |
62 | # Transform to scene coordinates
63 | mesh.apply_scale(scene_scale_np)
64 | mesh.apply_transform(transform_np)
65 |
66 | vertices_pts = torch.from_numpy(np.array(mesh.vertices)).float().to(self.device)
67 | ret = self.eval_points(vertices_pts)
68 | if ret is None:
69 | return None
70 | _, color = ret
71 | mesh_color = color * 255
72 | vertex_colors = mesh_color.detach().squeeze(0).cpu().numpy().astype(np.uint8)
73 | mesh.visual.vertex_colors = vertex_colors
74 |
75 | return mesh
76 |
77 | def eval_points(self, points, chunk_size=100000):
78 | # 256^3 = 16777216
79 | alpha, color = [], []
80 | n_chunks = int(np.ceil(points.shape[0] / chunk_size))
81 | with torch.no_grad():
82 | for k in tqdm(range(n_chunks)): # 2s/it 1000000 pts
83 | chunk_idx = slice(k * chunk_size, (k + 1) * chunk_size)
84 | embedding_k = self.pe(points[chunk_idx, ...])
85 | alpha_k, color_k = self.fc_occ_map(embedding_k)
86 | alpha.extend(alpha_k.detach().squeeze())
87 | color.extend(color_k.detach().squeeze())
88 | alpha = torch.stack(alpha)
89 | color = torch.stack(color)
90 |
91 | occ = render_rays.occupancy_activation(alpha).detach()
92 | if occ.max() == 0:
93 | print("no occ")
94 | return None
95 | return (occ, color)
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import utils
5 |
6 | class Config:
7 | def __init__(self, config_file):
8 | # setting params
9 | with open(config_file) as json_file:
10 | config = json.load(json_file)
11 |
12 | # training strategy
13 | self.do_bg = bool(config["trainer"]["do_bg"])
14 | self.training_device = config["trainer"]["train_device"]
15 | self.data_device = config["trainer"]["data_device"]
16 | self.max_n_models = config["trainer"]["n_models"]
17 | self.live_mode = bool(config["dataset"]["live"])
18 | self.keep_live_time = config["dataset"]["keep_alive"]
19 | self.imap_mode = config["trainer"]["imap_mode"]
20 | self.training_strategy = config["trainer"]["training_strategy"] # "forloop" "vmap"
21 | self.obj_id = -1
22 |
23 | # dataset setting
24 | self.dataset_format = config["dataset"]["format"]
25 | self.dataset_dir = config["dataset"]["path"]
26 | self.depth_scale = 1 / config["trainer"]["scale"]
27 | # camera setting
28 | self.max_depth = config["render"]["depth_range"][1]
29 | self.min_depth = config["render"]["depth_range"][0]
30 | self.mh = config["camera"]["mh"]
31 | self.mw = config["camera"]["mw"]
32 | self.height = config["camera"]["h"]
33 | self.width = config["camera"]["w"]
34 | self.H = self.height - 2 * self.mh
35 | self.W = self.width - 2 * self.mw
36 | if "fx" in config["camera"]:
37 | self.fx = config["camera"]["fx"]
38 | self.fy = config["camera"]["fy"]
39 | self.cx = config["camera"]["cx"] - self.mw
40 | self.cy = config["camera"]["cy"] - self.mh
41 | else: # for scannet
42 | intrinsic = utils.load_matrix_from_txt(os.path.join(self.dataset_dir, "intrinsic/intrinsic_depth.txt"))
43 | self.fx = intrinsic[0, 0]
44 | self.fy = intrinsic[1, 1]
45 | self.cx = intrinsic[0, 2] - self.mw
46 | self.cy = intrinsic[1, 2] - self.mh
47 | if "distortion" in config["camera"]:
48 | self.distortion_array = np.array(config["camera"]["distortion"])
49 | elif "k1" in config["camera"]:
50 | k1 = config["camera"]["k1"]
51 | k2 = config["camera"]["k2"]
52 | k3 = config["camera"]["k3"]
53 | k4 = config["camera"]["k4"]
54 | k5 = config["camera"]["k5"]
55 | k6 = config["camera"]["k6"]
56 | p1 = config["camera"]["p1"]
57 | p2 = config["camera"]["p2"]
58 | self.distortion_array = np.array([k1, k2, p1, p2, k3, k4, k5, k6])
59 | else:
60 | self.distortion_array = None
61 |
62 | # training setting
63 | self.win_size = config["model"]["window_size"]
64 | self.n_iter_per_frame = config["render"]["iters_per_frame"]
65 | self.n_per_optim = config["render"]["n_per_optim"]
66 | self.n_samples_per_frame = self.n_per_optim // self.win_size
67 | self.win_size_bg = config["model"]["window_size_bg"]
68 | self.n_per_optim_bg = config["render"]["n_per_optim_bg"]
69 | self.n_samples_per_frame_bg = self.n_per_optim_bg // self.win_size_bg
70 | self.keyframe_buffer_size = config["model"]["keyframe_buffer_size"]
71 | self.keyframe_step = config["model"]["keyframe_step"]
72 | self.keyframe_step_bg = config["model"]["keyframe_step_bg"]
73 | self.obj_scale = config["model"]["obj_scale"]
74 | self.bg_scale = config["model"]["bg_scale"]
75 | self.hidden_feature_size = config["model"]["hidden_feature_size"]
76 | self.hidden_feature_size_bg = config["model"]["hidden_feature_size_bg"]
77 | self.n_bins_cam2surface = config["render"]["n_bins_cam2surface"]
78 | self.n_bins_cam2surface_bg = config["render"]["n_bins_cam2surface_bg"]
79 | self.n_bins = config["render"]["n_bins"]
80 | self.n_unidir_funcs = config["model"]["n_unidir_funcs"]
81 | self.surface_eps = config["model"]["surface_eps"]
82 | self.stop_eps = config["model"]["other_eps"]
83 |
84 | # optimizer setting
85 | self.learning_rate = config["optimizer"]["args"]["lr"]
86 | self.weight_decay = config["optimizer"]["args"]["weight_decay"]
87 |
88 | # vis setting
89 | self.vis_device = config["vis"]["vis_device"]
90 | self.n_vis_iter = config["vis"]["n_vis_iter"]
91 | self.live_voxel_size = config["vis"]["live_voxel_size"]
92 | self.grid_dim = config["vis"]["grid_dim"]
--------------------------------------------------------------------------------
/render_rays.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def occupancy_activation(alpha, distances=None):
5 | # occ = 1.0 - torch.exp(-alpha * distances)
6 | occ = torch.sigmoid(alpha) # unisurf
7 |
8 | return occ
9 |
10 | def alpha_to_occupancy(depths, dirs, alpha, add_last=False):
11 | interval_distances = depths[..., 1:] - depths[..., :-1]
12 | if add_last:
13 | last_distance = torch.empty(
14 | (depths.shape[0], 1),
15 | device=depths.device,
16 | dtype=depths.dtype).fill_(0.1)
17 | interval_distances = torch.cat(
18 | [interval_distances, last_distance], dim=-1)
19 |
20 | dirs_norm = torch.norm(dirs, dim=-1)
21 | interval_distances = interval_distances * dirs_norm[:, None]
22 | occ = occupancy_activation(alpha, interval_distances)
23 |
24 | return occ
25 |
26 | def occupancy_to_termination(occupancy, is_batch=False):
27 | if is_batch:
28 | first = torch.ones(list(occupancy.shape[:2]) + [1], device=occupancy.device)
29 | free_probs = (1. - occupancy + 1e-10)[:, :, :-1]
30 | else:
31 | first = torch.ones([occupancy.shape[0], 1], device=occupancy.device)
32 | free_probs = (1. - occupancy + 1e-10)[:, :-1]
33 | free_probs = torch.cat([first, free_probs], dim=-1)
34 | term_probs = occupancy * torch.cumprod(free_probs, dim=-1)
35 |
36 | # using escape probability
37 | # occupancy = occupancy[:, :-1]
38 | # first = torch.ones([occupancy.shape[0], 1], device=occupancy.device)
39 | # free_probs = (1. - occupancy + 1e-10)
40 | # free_probs = torch.cat([first, free_probs], dim=-1)
41 | # last = torch.ones([occupancy.shape[0], 1], device=occupancy.device)
42 | # occupancy = torch.cat([occupancy, last], dim=-1)
43 | # term_probs = occupancy * torch.cumprod(free_probs, dim=-1)
44 |
45 | return term_probs
46 |
47 | def render(termination, vals, dim=-1):
48 | weighted_vals = termination * vals
49 | render = weighted_vals.sum(dim=dim)
50 |
51 | return render
52 |
53 | def render_loss(render, gt, loss="L1", normalise=False):
54 | residual = render - gt
55 | if loss == "L2":
56 | loss_mat = residual ** 2
57 | elif loss == "L1":
58 | loss_mat = torch.abs(residual)
59 | else:
60 | print("loss type {} not implemented!".format(loss))
61 |
62 | if normalise:
63 | loss_mat = loss_mat / gt
64 |
65 | return loss_mat
66 |
67 | def reduce_batch_loss(loss_mat, var=None, avg=True, mask=None, loss_type="L1"):
68 | mask_num = torch.sum(mask, dim=-1)
69 | if (mask_num == 0).any(): # no valid sample, return 0 loss
70 | loss = torch.zeros_like(loss_mat)
71 | if avg:
72 | loss = torch.mean(loss, dim=-1)
73 | return loss
74 | if var is not None:
75 | eps = 1e-4
76 | if loss_type == "L2":
77 | information = 1.0 / (var + eps)
78 | elif loss_type == "L1":
79 | information = 1.0 / (torch.sqrt(var) + eps)
80 |
81 | loss_weighted = loss_mat * information
82 | else:
83 | loss_weighted = loss_mat
84 |
85 | if avg:
86 | if mask is not None:
87 | loss = (torch.sum(loss_weighted, dim=-1)/(torch.sum(mask, dim=-1)+1e-10))
88 | if (loss > 100000).any():
89 | print("loss explode")
90 | exit(-1)
91 | else:
92 | loss = torch.mean(loss_weighted, dim=-1).sum()
93 | else:
94 | loss = loss_weighted
95 |
96 | return loss
97 |
98 | def make_3D_grid(occ_range=[-1., 1.], dim=256, device="cuda:0", transform=None, scale=None):
99 | t = torch.linspace(occ_range[0], occ_range[1], steps=dim, device=device)
100 | grid = torch.meshgrid(t, t, t)
101 | grid_3d = torch.cat(
102 | (grid[0][..., None],
103 | grid[1][..., None],
104 | grid[2][..., None]), dim=3
105 | )
106 |
107 | if scale is not None:
108 | grid_3d = grid_3d * scale
109 | if transform is not None:
110 | R1 = transform[None, None, None, 0, :3]
111 | R2 = transform[None, None, None, 1, :3]
112 | R3 = transform[None, None, None, 2, :3]
113 |
114 | grid1 = (R1 * grid_3d).sum(-1, keepdim=True)
115 | grid2 = (R2 * grid_3d).sum(-1, keepdim=True)
116 | grid3 = (R3 * grid_3d).sum(-1, keepdim=True)
117 | grid_3d = torch.cat([grid1, grid2, grid3], dim=-1)
118 |
119 | trans = transform[None, None, None, :3, 3]
120 | grid_3d = grid_3d + trans
121 |
122 | return grid_3d
123 |
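# Usage sketch added for illustration (mirrors loss.step_batch_loss): turn raw
# network alphas into occupancies and termination probabilities, then render a
# per-ray depth as the termination-weighted sum of the sample depths.
if __name__ == "__main__":
    num_batch, num_ray, pts_per_ray = 2, 5, 9
    alpha = torch.randn(num_batch, num_ray, pts_per_ray)
    z_vals = torch.linspace(0.1, 3.0, pts_per_ray).expand(num_batch, num_ray, pts_per_ray)
    occ = occupancy_activation(alpha)
    termination = occupancy_to_termination(occ, is_batch=True)
    depth = render(termination, z_vals)
    print(depth.shape)  # (num_batch, num_ray)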
--------------------------------------------------------------------------------
/metric/eval_3D_obj.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import trimesh
4 | from metrics import accuracy, completion, completion_ratio
5 | import os
6 | import json
7 |
8 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000):
9 | """
10 | 3D reconstruction metric.
11 | """
12 | metrics = [[] for _ in range(4)]
13 | transform, extents = trimesh.bounds.oriented_bounds(mesh_gt)
14 |     extents = extents / 0.9  # enlarge the gt bounding box by a factor of 1/0.9
15 | box = trimesh.creation.box(extents=extents, transform=np.linalg.inv(transform))
16 | mesh_rec = mesh_rec.slice_plane(box.facets_origin, -box.facets_normal)
17 | if mesh_rec.vertices.shape[0] == 0:
18 | print("no mesh found")
19 | return
20 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N)
21 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])
22 |
23 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N)
24 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])
25 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices)
26 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices)
27 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05)
28 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01)
29 |
30 | # accuracy_rec *= 100 # convert to cm
31 | # completion_rec *= 100 # convert to cm
32 | # completion_ratio_rec *= 100 # convert to %
33 | # print('accuracy: ', accuracy_rec)
34 | # print('completion: ', completion_rec)
35 | # print('completion ratio: ', completion_ratio_rec)
36 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1)
37 | metrics[0].append(accuracy_rec)
38 | metrics[1].append(completion_rec)
39 | metrics[2].append(completion_ratio_rec_1)
40 | metrics[3].append(completion_ratio_rec)
41 | return metrics
42 |
43 | def get_gt_bg_mesh(gt_dir, background_cls_list):
44 | with open(os.path.join(gt_dir, "info_semantic.json")) as f:
45 | label_obj_list = json.load(f)["objects"]
46 |
47 | bg_meshes = []
48 | for obj in label_obj_list:
49 | if int(obj["class_id"]) in background_cls_list:
50 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(int(obj["id"])) + ".ply")
51 | obj_mesh = trimesh.load(obj_file)
52 | bg_meshes.append(obj_mesh)
53 |
54 | bg_mesh = trimesh.util.concatenate(bg_meshes)
55 | return bg_mesh
56 |
57 | def get_obj_ids(obj_dir):
58 | files = os.listdir(obj_dir)
59 | obj_ids = []
60 | for f in files:
61 | obj_id = f.split("obj")[1][:-1]
62 | if obj_id == '':
63 | continue
64 | obj_ids.append(int(obj_id))
65 | return obj_ids
66 |
67 |
68 | if __name__ == "__main__":
69 | background_cls_list = [5, 12, 30, 31, 40, 60, 92, 93, 95, 97, 98, 79]
70 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
71 | data_dir = "/home/xin/data/vmap/"
72 | log_dir = "../logs/iMAP/"
73 | # log_dir = "../logs/vMAP/"
74 |
75 | for exp in tqdm(exp_name):
76 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat")
77 | exp_dir = os.path.join(log_dir, exp)
78 | mesh_dir = os.path.join(exp_dir, "scene_mesh")
79 | output_path = os.path.join(exp_dir, "eval_mesh")
80 | os.makedirs(output_path, exist_ok=True)
81 | metrics_3D = [[] for _ in range(4)]
82 |
83 | # get obj ids
84 | # obj_ids = np.loadtxt() # todo use a pre-defined obj list or use vMAP results
85 | obj_ids = get_obj_ids(mesh_dir.replace("iMAP", "vMAP"))
86 | for obj_id in tqdm(obj_ids):
87 | if obj_id == 0: # for bg
88 | N = 200000
89 | mesh_gt = get_gt_bg_mesh(gt_dir, background_cls_list)
90 | else: # for obj
91 | N = 10000
92 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(obj_id) + ".ply")
93 | mesh_gt = trimesh.load(obj_file)
94 |
95 | if "vMAP" in exp_dir:
96 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj"+str(obj_id)+".obj")
97 | elif "iMAP" in exp_dir:
98 | rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj")
99 | else:
100 |             print("Not implemented")
101 | exit(-1)
102 |
103 | mesh_rec = trimesh.load(rec_meshfile)
104 | # mesh_rec.invert() # niceslam mesh face needs invert
105 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N) # for objs use 10k, for scene use 200k points
106 | if metrics is None:
107 | continue
108 | np.save(output_path + '/metric_obj{}.npy'.format(obj_id), np.array(metrics))
109 | metrics_3D[0].append(metrics[0]) # acc
110 | metrics_3D[1].append(metrics[1]) # comp
111 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm
112 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm
113 | metrics_3D = np.array(metrics_3D)
114 | np.save(output_path + '/metrics_3D_obj.npy', metrics_3D)
115 | print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1))
116 | print("-----------------------------------------")
117 | print("finish exp ", exp)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # vMAP: Vectorised Object Mapping for Neural Field SLAM
2 |
3 | Xin Kong · Shikun Liu · Marwan Taher · Andrew Davison
4 |
5 | ![vMAP teaser](./media/teaser.png)
6 |
28 | vMAP builds an object-level map from a real-time RGB-D input stream. Each object is represented by a separate MLP neural field model, all optimised in parallel via vectorised training.
29 |
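For intuition, below is a minimal, self-contained sketch of this vectorised-training idea (a toy example, not the repo's trainer: the tiny MLP, batch sizes, and loss are assumptions). It relies on the same functorch `combine_state_for_ensemble` mechanism used by `update_vmap` in `utils.py`.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

# One small MLP per object (toy architecture; the real object models live in model.py).
models = [torch.nn.Sequential(torch.nn.Linear(3, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
          for _ in range(5)]

# Stack the per-object weights so every MLP can be evaluated in a single vectorised call.
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]
optimiser = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(5, 128, 3)       # 128 sampled 3D points for each of the 5 objects
target = torch.zeros(5, 128, 4)  # dummy supervision, only to make the example runnable

pred = vmap(fmodel)(params, buffers, x)  # forward pass of all object MLPs in parallel
loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()
optimiser.step()
```

Because each layer's weights are stacked into one tensor across objects, a single forward/backward pass and optimiser step updates every object model at once.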
30 |
31 |
32 | We provide the implementation of the following neural-field SLAM frameworks:
33 | - **vMAP** [Official Implementation]
34 | - **iMAP** [Simplified and Improved Re-Implementation, with depth-guided sampling]
35 |
36 |
37 |
38 | ## Install
39 | First, let's create a conda environment with the required dependencies.
40 | ```bash
41 | conda env create -f environment.yml
42 | ```
43 |
44 | ## Dataset
45 | Please download the following datasets to reproduce our results.
46 |
47 | * [Replica Demo](https://huggingface.co/datasets/kxic/vMAP/resolve/main/demo_replica_room_0.zip) - Replica Room 0 only for faster experimentation.
48 | * [Replica](https://huggingface.co/datasets/kxic/vMAP/resolve/main/vmap.zip) - All Pre-generated Replica sequences. For Replica data generation, please refer to directory `data_generation`.
49 | * [ScanNet](https://github.com/ScanNet/ScanNet) - Official ScanNet sequences.
50 | Each dataset contains a sequence of RGB-D images, together with their corresponding camera poses and object instance labels; see the Replica per-frame layout sketched below the extraction command.
51 | To extract RGB-D frames, camera poses, and intrinsics from ScanNet `.sens` files, run ScanNet's `reader.py` script in a Python 2 environment:
52 | ```bash
53 | conda activate py2
54 | python2 reader.py --filename ~/data/ScanNet/scannet/scans/scene0024_00/scene0024_00.sens --output_path ~/data/ScanNet/objnerf/ --export_depth_images --export_color_images --export_poses --export_intrinsics
55 | ```
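For reference, the pre-generated Replica sequences are read by `dataset.py` using the per-frame layout sketched below (the sequence path is a hypothetical placeholder; point it at wherever you unzipped the data).

```python
import os

seq_dir = "path/to/vmap/room_0/vmap/00"  # hypothetical sequence folder
idx = 0                                  # frame index

frame = {
    "poses": os.path.join(seq_dir, "traj_w_c.txt"),  # 4x4 camera-to-world poses for the whole trajectory
    "rgb": os.path.join(seq_dir, "rgb", "rgb_{}.png".format(idx)),
    "depth": os.path.join(seq_dir, "depth", "depth_{}.png".format(idx)),  # uint16 depth in mm (as written by habitat_renderer.py)
    "semantic_class": os.path.join(seq_dir, "semantic_class", "semantic_class_{}.png".format(idx)),
    "semantic_instance": os.path.join(seq_dir, "semantic_instance", "semantic_instance_{}.png".format(idx)),
}
print(frame)
```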
56 |
57 | ## Config
58 |
59 | Then update the config files in `configs/` (e.g. `configs/Replica/config_replica_room0_vMAP.json`) with your dataset paths, as well as other training hyper-parameters.
60 | ```json
61 | "dataset": {
62 | "path": "path/to/ims/folder/",
63 | }
64 | ```
65 |
66 | ## Running vMAP / iMAP
67 | The following commands will run vMAP / iMAP in a single-thread setting.
68 |
69 | #### vMAP
70 | ```bash
71 | python ./train.py --config ./configs/Replica/config_replica_room0_vMAP.json --logdir ./logs/vMAP/room0 --save_ckpt True
72 | ```
73 | #### iMAP
74 | ```bash
75 | python ./train.py --config ./configs/Replica/config_replica_room0_iMAP.json --logdir ./logs/iMAP/room0 --save_ckpt True
76 | ```
77 |
78 | [comment]: <> (#### Multi thread demo)
79 |
80 | [comment]: <> (```bash)
81 |
82 | [comment]: <> (./parallel_train.py --config "config_file.json" --logdir ./logs)
83 |
84 | [comment]: <> (```)
85 |
86 | ## Evaluation
87 | To evaluate the quality of the reconstructed scenes, we provide two different methods:
88 | #### 3D Scene-level Evaluation
89 | Following the original iMAP evaluation, we compare the reconstructed scene mesh against the GT scene mesh using **Accuracy**, **Completion** and **Completion Ratio**.
90 | ```bash
91 | python ./metric/eval_3D_scene.py
92 | ```
93 | #### 3D Object-level Evaluation
94 | We also provide object-level evaluation, computing the same metrics per object and averaging them across all objects in a scene.
95 | ```bash
96 | python ./metric/eval_3D_obj.py
97 | ```
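For post-processing, here is a minimal sketch (not part of the repo) of how the arrays written by `metric/eval_3D_obj.py` can be inspected; the log path below is a hypothetical example matching the training command above.

```python
import numpy as np

# eval_3D_obj.py writes this file to <logdir>/<scene>/eval_mesh/ for each scene.
metrics = np.load("./logs/vMAP/room0/eval_mesh/metrics_3D_obj.npy")  # shape (4, num_objects, 1)

# Row order follows the script: Acc, Comp, Comp Ratio @1cm, Comp Ratio @5cm.
acc, comp, ratio_1cm, ratio_5cm = metrics.mean(axis=1).squeeze()
print("Acc {:.4f} | Comp {:.4f} | Ratio@1cm {:.4f} | Ratio@5cm {:.4f}".format(
    acc, comp, ratio_1cm, ratio_5cm))
```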
98 |
99 | [comment]: <> (### Novel View Synthesis)
100 |
101 | [comment]: <> (##### 2D Novel View Eval)
102 |
103 | [comment]: <> (We rendered a new trajectory in each scene and randomly choose novel view pose from it, evaluating 2D rendering performance)
104 |
105 | [comment]: <> (```bash)
106 |
107 | [comment]: <> (./metric/eval_2D_view.py)
108 |
109 | [comment]: <> (```)
110 |
111 | ## Results
112 | We provide raw results for vMAP and iMAP*, including 3D meshes, 2D novel-view renderings, and the evaluated metrics, for easier comparison.
113 |
114 | * [Replica](https://huggingface.co/datasets/kxic/vMAP/resolve/main/vMAP_Replica_Results.zip)
115 |
116 | ## Acknowledgement
117 | We would like to thank the following open-source repositories that we have built upon for the implementation of this work: [NICE-SLAM](https://github.com/cvg/nice-slam) and [functorch](https://github.com/pytorch/functorch).
118 |
119 | ## Citation
120 | If you find this code/work useful in your own research, please consider citing the following:
121 | ```bibtex
122 | @article{kong2023vmap,
123 | title={vMAP: Vectorised Object Mapping for Neural Field SLAM},
124 | author={Kong, Xin and Liu, Shikun and Taher, Marwan and Davison, Andrew J},
125 | journal={arXiv preprint arXiv:2302.01838},
126 | year={2023}
127 | }
128 | ```
129 |
130 | ```bibtex
131 | @inproceedings{sucar2021imap,
132 | title={iMAP: Implicit mapping and positioning in real-time},
133 | author={Sucar, Edgar and Liu, Shikun and Ortiz, Joseph and Davison, Andrew J},
134 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
135 | pages={6229--6238},
136 | year={2021}
137 | }
138 | ```
139 |
140 |
--------------------------------------------------------------------------------
/data_generation/settings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import habitat_sim
6 | import habitat_sim.agent
7 |
8 | default_sim_settings = {
9 | # settings shared by example.py and benchmark.py
10 | "max_frames": 1000,
11 | "width": 640,
12 | "height": 480,
13 | "default_agent": 0,
14 | "sensor_height": 1.5,
15 | "hfov": 90,
16 | "color_sensor": True, # RGB sensor (default: ON)
17 | "semantic_sensor": False, # semantic sensor (default: OFF)
18 | "depth_sensor": False, # depth sensor (default: OFF)
19 | "ortho_rgba_sensor": False, # Orthographic RGB sensor (default: OFF)
20 | "ortho_depth_sensor": False, # Orthographic depth sensor (default: OFF)
21 | "ortho_semantic_sensor": False, # Orthographic semantic sensor (default: OFF)
22 | "fisheye_rgba_sensor": False,
23 | "fisheye_depth_sensor": False,
24 | "fisheye_semantic_sensor": False,
25 | "equirect_rgba_sensor": False,
26 | "equirect_depth_sensor": False,
27 | "equirect_semantic_sensor": False,
28 | "seed": 1,
29 | "silent": False, # do not print log info (default: OFF)
30 | # settings exclusive to example.py
31 | "save_png": False, # save the pngs to disk (default: OFF)
32 | "print_semantic_scene": False,
33 | "print_semantic_mask_stats": False,
34 | "compute_shortest_path": False,
35 | "compute_action_shortest_path": False,
36 | "scene": "data/scene_datasets/habitat-test-scenes/skokloster-castle.glb",
37 | "test_scene_data_url": "http://dl.fbaipublicfiles.com/habitat/habitat-test-scenes.zip",
38 | "goal_position": [5.047, 0.199, 11.145],
39 | "enable_physics": False,
40 | "enable_gfx_replay_save": False,
41 | "physics_config_file": "./data/default.physics_config.json",
42 | "num_objects": 10,
43 | "test_object_index": 0,
44 | "frustum_culling": True,
45 | }
46 |
47 | # build SimulatorConfiguration
48 | def make_cfg(settings):
49 | sim_cfg = habitat_sim.SimulatorConfiguration()
50 | if "frustum_culling" in settings:
51 | sim_cfg.frustum_culling = settings["frustum_culling"]
52 | else:
53 | sim_cfg.frustum_culling = False
54 | if "enable_physics" in settings:
55 | sim_cfg.enable_physics = settings["enable_physics"]
56 | if "physics_config_file" in settings:
57 | sim_cfg.physics_config_file = settings["physics_config_file"]
58 | # if not settings["silent"]:
59 | # print("sim_cfg.physics_config_file = " + sim_cfg.physics_config_file)
60 | if "scene_light_setup" in settings:
61 | sim_cfg.scene_light_setup = settings["scene_light_setup"]
62 | sim_cfg.gpu_device_id = 0
63 | if not hasattr(sim_cfg, "scene_id"):
64 | raise RuntimeError(
65 | "Error: Please upgrade habitat-sim. SimulatorConfig API version mismatch"
66 | )
67 | sim_cfg.scene_id = settings["scene_file"]
68 |
69 | # define default sensor parameters (see src/esp/Sensor/Sensor.h)
70 | sensor_specs = []
71 |
72 | def create_camera_spec(**kw_args):
73 | camera_sensor_spec = habitat_sim.CameraSensorSpec()
74 | camera_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
75 | camera_sensor_spec.resolution = [settings["height"], settings["width"]]
76 | camera_sensor_spec.position = [0, settings["sensor_height"], 0]
77 | for k in kw_args:
78 | setattr(camera_sensor_spec, k, kw_args[k])
79 | return camera_sensor_spec
80 |
81 | if settings["color_sensor"]:
82 | color_sensor_spec = create_camera_spec(
83 | uuid="color_sensor",
84 | # hfov=settings["hfov"],
85 | sensor_type=habitat_sim.SensorType.COLOR,
86 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE,
87 | )
88 | sensor_specs.append(color_sensor_spec)
89 |
90 | if settings["depth_sensor"]:
91 | depth_sensor_spec = create_camera_spec(
92 | uuid="depth_sensor",
93 | # hfov=settings["hfov"],
94 | sensor_type=habitat_sim.SensorType.DEPTH,
95 | channels=1,
96 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE,
97 | )
98 | sensor_specs.append(depth_sensor_spec)
99 |
100 | if settings["semantic_sensor"]:
101 | semantic_sensor_spec = create_camera_spec(
102 | uuid="semantic_sensor",
103 | # hfov=settings["hfov"],
104 | sensor_type=habitat_sim.SensorType.SEMANTIC,
105 | channels=1,
106 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE,
107 | )
108 | sensor_specs.append(semantic_sensor_spec)
109 |
110 | # if settings["ortho_rgba_sensor"]:
111 | # ortho_rgba_sensor_spec = create_camera_spec(
112 | # uuid="ortho_rgba_sensor",
113 | # sensor_type=habitat_sim.SensorType.COLOR,
114 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC,
115 | # )
116 | # sensor_specs.append(ortho_rgba_sensor_spec)
117 | #
118 | # if settings["ortho_depth_sensor"]:
119 | # ortho_depth_sensor_spec = create_camera_spec(
120 | # uuid="ortho_depth_sensor",
121 | # sensor_type=habitat_sim.SensorType.DEPTH,
122 | # channels=1,
123 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC,
124 | # )
125 | # sensor_specs.append(ortho_depth_sensor_spec)
126 | #
127 | # if settings["ortho_semantic_sensor"]:
128 | # ortho_semantic_sensor_spec = create_camera_spec(
129 | # uuid="ortho_semantic_sensor",
130 | # sensor_type=habitat_sim.SensorType.SEMANTIC,
131 | # channels=1,
132 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC,
133 | # )
134 | # sensor_specs.append(ortho_semantic_sensor_spec)
135 |
136 | # TODO Figure out how to implement copying of specs
137 | def create_fisheye_spec(**kw_args):
138 | fisheye_sensor_spec = habitat_sim.FisheyeSensorDoubleSphereSpec()
139 | fisheye_sensor_spec.uuid = "fisheye_sensor"
140 | fisheye_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
141 | fisheye_sensor_spec.sensor_model_type = (
142 | habitat_sim.FisheyeSensorModelType.DOUBLE_SPHERE
143 | )
144 |
145 | # The default value (alpha, xi) is set to match the lens "GoPro" found in Table 3 of this paper:
146 | # Vladyslav Usenko, Nikolaus Demmel and Daniel Cremers: The Double Sphere
147 | # Camera Model, The International Conference on 3D Vision (3DV), 2018
148 | # You can find the intrinsic parameters for the other lenses in the same table as well.
149 | fisheye_sensor_spec.xi = -0.27
150 | fisheye_sensor_spec.alpha = 0.57
151 | fisheye_sensor_spec.focal_length = [364.84, 364.86]
152 |
153 | fisheye_sensor_spec.resolution = [settings["height"], settings["width"]]
154 | # The default principal_point_offset is the middle of the image
155 | fisheye_sensor_spec.principal_point_offset = None
156 | # default: fisheye_sensor_spec.principal_point_offset = [i/2 for i in fisheye_sensor_spec.resolution]
157 | fisheye_sensor_spec.position = [0, settings["sensor_height"], 0]
158 | for k in kw_args:
159 | setattr(fisheye_sensor_spec, k, kw_args[k])
160 | return fisheye_sensor_spec
161 |
162 | # if settings["fisheye_rgba_sensor"]:
163 | # fisheye_rgba_sensor_spec = create_fisheye_spec(uuid="fisheye_rgba_sensor")
164 | # sensor_specs.append(fisheye_rgba_sensor_spec)
165 | # if settings["fisheye_depth_sensor"]:
166 | # fisheye_depth_sensor_spec = create_fisheye_spec(
167 | # uuid="fisheye_depth_sensor",
168 | # sensor_type=habitat_sim.SensorType.DEPTH,
169 | # channels=1,
170 | # )
171 | # sensor_specs.append(fisheye_depth_sensor_spec)
172 | # if settings["fisheye_semantic_sensor"]:
173 | # fisheye_semantic_sensor_spec = create_fisheye_spec(
174 | # uuid="fisheye_semantic_sensor",
175 | # sensor_type=habitat_sim.SensorType.SEMANTIC,
176 | # channels=1,
177 | # )
178 | # sensor_specs.append(fisheye_semantic_sensor_spec)
179 |
180 | def create_equirect_spec(**kw_args):
181 | equirect_sensor_spec = habitat_sim.EquirectangularSensorSpec()
182 | equirect_sensor_spec.uuid = "equirect_rgba_sensor"
183 | equirect_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
184 | equirect_sensor_spec.resolution = [settings["height"], settings["width"]]
185 | equirect_sensor_spec.position = [0, settings["sensor_height"], 0]
186 | for k in kw_args:
187 | setattr(equirect_sensor_spec, k, kw_args[k])
188 | return equirect_sensor_spec
189 |
190 | # if settings["equirect_rgba_sensor"]:
191 | # equirect_rgba_sensor_spec = create_equirect_spec(uuid="equirect_rgba_sensor")
192 | # sensor_specs.append(equirect_rgba_sensor_spec)
193 | #
194 | # if settings["equirect_depth_sensor"]:
195 | # equirect_depth_sensor_spec = create_equirect_spec(
196 | # uuid="equirect_depth_sensor",
197 | # sensor_type=habitat_sim.SensorType.DEPTH,
198 | # channels=1,
199 | # )
200 | # sensor_specs.append(equirect_depth_sensor_spec)
201 | #
202 | # if settings["equirect_semantic_sensor"]:
203 | # equirect_semantic_sensor_spec = create_equirect_spec(
204 | # uuid="equirect_semantic_sensor",
205 | # sensor_type=habitat_sim.SensorType.SEMANTIC,
206 | # channels=1,
207 | # )
208 | # sensor_specs.append(equirect_semantic_sensor_spec)
209 |
210 | # create agent specifications
211 | agent_cfg = habitat_sim.agent.AgentConfiguration()
212 | agent_cfg.sensor_specifications = sensor_specs
213 | agent_cfg.action_space = {
214 | "move_forward": habitat_sim.agent.ActionSpec(
215 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.25)
216 | ),
217 | "turn_left": habitat_sim.agent.ActionSpec(
218 | "turn_left", habitat_sim.agent.ActuationSpec(amount=10.0)
219 | ),
220 | "turn_right": habitat_sim.agent.ActionSpec(
221 | "turn_right", habitat_sim.agent.ActuationSpec(amount=10.0)
222 | ),
223 | }
224 |
225 | # override action space to no-op to test physics
226 | if sim_cfg.enable_physics:
227 | agent_cfg.action_space = {
228 | "move_forward": habitat_sim.agent.ActionSpec(
229 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.0)
230 | )
231 | }
232 |
233 | return habitat_sim.Configuration(sim_cfg, [agent_cfg])
234 |
--------------------------------------------------------------------------------
/data_generation/habitat_renderer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os, sys, argparse
3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5 | import cv2
6 | import logging
7 | import habitat_sim as hs
8 | import numpy as np
9 | import quaternion
10 | import yaml
11 | import json
12 | from typing import Any, Dict, List, Tuple, Union
13 | from imgviz import label_colormap
14 | from PIL import Image
15 | import matplotlib.pyplot as plt
16 | import transformation
17 | import imgviz
18 | from datetime import datetime
19 | import time
20 | from settings import make_cfg
21 |
22 | # Custom type definitions
23 | Config = Dict[str, Any]
24 | Observation = hs.sensor.Observation
25 | Sim = hs.Simulator
26 |
27 | def init_habitat(config):
28 | """Initialize the Habitat simulator with sensors and scene file"""
29 | _cfg = make_cfg(config)
30 | sim = Sim(_cfg)
31 | sim_cfg = hs.SimulatorConfiguration()
32 | sim_cfg.gpu_device_id = 0
33 | # Note: all sensors must have the same resolution
34 | camera_resolution = [config["height"], config["width"]]
35 | sensors = {
36 | "color_sensor": {
37 | "sensor_type": hs.SensorType.COLOR,
38 | "resolution": camera_resolution,
39 | "position": [0.0, config["sensor_height"], 0.0],
40 | },
41 | "depth_sensor": {
42 | "sensor_type": hs.SensorType.DEPTH,
43 | "resolution": camera_resolution,
44 | "position": [0.0, config["sensor_height"], 0.0],
45 | },
46 | "semantic_sensor": {
47 | "sensor_type": hs.SensorType.SEMANTIC,
48 | "resolution": camera_resolution,
49 | "position": [0.0, config["sensor_height"], 0.0],
50 | },
51 | }
52 |
53 | sensor_specs = []
54 | for sensor_uuid, sensor_params in sensors.items():
55 | if config[sensor_uuid]:
56 | sensor_spec = hs.SensorSpec()
57 | sensor_spec.uuid = sensor_uuid
58 | sensor_spec.sensor_type = sensor_params["sensor_type"]
59 | sensor_spec.resolution = sensor_params["resolution"]
60 | sensor_spec.position = sensor_params["position"]
61 |
62 | sensor_specs.append(sensor_spec)
63 |
64 | # Here you can specify the amount of displacement in a forward action and the turn angle
65 | agent_cfg = hs.agent.AgentConfiguration()
66 | agent_cfg.sensor_specifications = sensor_specs
67 | agent_cfg.action_space = {
68 | "move_forward": hs.agent.ActionSpec(
69 | "move_forward", hs.agent.ActuationSpec(amount=0.25)
70 | ),
71 | "turn_left": hs.agent.ActionSpec(
72 | "turn_left", hs.agent.ActuationSpec(amount=30.0)
73 | ),
74 | "turn_right": hs.agent.ActionSpec(
75 | "turn_right", hs.agent.ActuationSpec(amount=30.0)
76 | ),
77 | }
78 |
79 | hs_cfg = hs.Configuration(sim_cfg, [agent_cfg])
80 | # sim = Sim(hs_cfg)
81 |
82 | if config["enable_semantics"]: # extract instance to class mapping function
83 | assert os.path.exists(config["instance2class_mapping"])
84 | with open(config["instance2class_mapping"], "r") as f:
85 | annotations = json.load(f)
86 | instance_id_to_semantic_label_id = np.array(annotations["id_to_label"])
87 | num_classes = len(annotations["classes"])
88 | label_colour_map = label_colormap()
89 | config["instance2semantic"] = instance_id_to_semantic_label_id
90 | config["classes"] = annotations["classes"]
91 | config["objects"] = annotations["objects"]
92 |
93 | config["num_classes"] = num_classes
94 | config["label_colour_map"] = label_colormap()
95 | config["instance_colour_map"] = label_colormap(500)
96 |
97 |
98 | # add camera intrinsic
99 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov']) * np.pi / 180.
100 | # https://aihabitat.org/docs/habitat-api/view-transform-warp.html
101 | # config['K'] = K
102 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]],
103 | # dtype=np.float64)
104 |
105 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov'])
106 | # fx = 1.0 / np.tan(hfov / 2.0)
107 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]],
108 | # dtype=np.float64)
109 |
110 | # Get the intrinsic camera parameters
111 |
112 |
113 | logging.info('Habitat simulator initialized')
114 |
115 | return sim, hs_cfg, config
116 |
117 | def save_renders(save_path, observation, enable_semantic, suffix=""):
118 | save_path_rgb = os.path.join(save_path, "rgb")
119 | save_path_depth = os.path.join(save_path, "depth")
120 | save_path_sem_class = os.path.join(save_path, "semantic_class")
121 | save_path_sem_instance = os.path.join(save_path, "semantic_instance")
122 |
123 | if not os.path.exists(save_path_rgb):
124 | os.makedirs(save_path_rgb)
125 | if not os.path.exists(save_path_depth):
126 | os.makedirs(save_path_depth)
127 | if not os.path.exists(save_path_sem_class):
128 | os.makedirs(save_path_sem_class)
129 | if not os.path.exists(save_path_sem_instance):
130 | os.makedirs(save_path_sem_instance)
131 |
132 | cv2.imwrite(os.path.join(save_path_rgb, "rgb{}.png".format(suffix)), observation["color_sensor"][:,:,::-1]) # change from RGB to BGR for opencv write
133 | cv2.imwrite(os.path.join(save_path_depth, "depth{}.png".format(suffix)), observation["depth_sensor_mm"])
134 |
135 | if enable_semantic:
136 | cv2.imwrite(os.path.join(save_path_sem_class, "semantic_class{}.png".format(suffix)), observation["semantic_class"])
137 | cv2.imwrite(os.path.join(save_path_sem_class, "vis_sem_class{}.png".format(suffix)), observation["vis_sem_class"][:,:,::-1])
138 |
139 | cv2.imwrite(os.path.join(save_path_sem_instance, "semantic_instance{}.png".format(suffix)), observation["semantic_instance"])
140 | cv2.imwrite(os.path.join(save_path_sem_instance, "vis_sem_instance{}.png".format(suffix)), observation["vis_sem_instance"][:,:,::-1])
141 |
142 |
143 | def render(sim, config):
144 | """Return the sensor observations and ground truth pose"""
145 | observation = sim.get_sensor_observations()
146 |
147 |     # process rgb image: change from RGBA to RGB
148 | observation['color_sensor'] = observation['color_sensor'][..., 0:3]
149 | rgb_img = observation['color_sensor']
150 |
151 | # process depth
152 | depth_mm = (observation['depth_sensor'].copy()*1000).astype(np.uint16) # change meters to mm
153 | observation['depth_sensor_mm'] = depth_mm
154 |
155 | # process semantics
156 | if config['enable_semantics']:
157 |
158 | # Assuming the scene has no more than 65534 objects
159 | observation['semantic_instance'] = np.clip(observation['semantic_sensor'].astype(np.uint16), 0, 65535)
160 | # observation['semantic_instance'][observation['semantic_instance']==12]=0 # mask out certain instance
161 | # Convert instance IDs to class IDs
162 |
163 |
164 | # observation['semantic_classes'] = np.zeros(observation['semantic'].shape, dtype=np.uint8)
165 | # TODO make this conversion more efficient
166 | semantic_class = config["instance2semantic"][observation['semantic_instance']]
167 | semantic_class[semantic_class < 0] = 0
168 |
169 | vis_sem_class = config["label_colour_map"][semantic_class]
170 | vis_sem_instance = config["instance_colour_map"][observation['semantic_instance']] # may cause error when having more than 255 instances in the scene
171 |
172 | observation['semantic_class'] = semantic_class.astype(np.uint8)
173 | observation["vis_sem_class"] = vis_sem_class.astype(np.uint8)
174 | observation["vis_sem_instance"] = vis_sem_instance.astype(np.uint8)
175 |
176 | # del observation["semantic_sensor"]
177 |
178 | # Get the camera ground truth pose (T_HC) in the habitat frame from the
179 | # position and orientation
180 | t_HC = sim.get_agent(0).get_state().position
181 | q_HC = sim.get_agent(0).get_state().rotation
182 | T_HC = transformation.combine_pose(t_HC, q_HC)
183 |
184 | observation['T_HC'] = T_HC
185 | observation['T_WC'] = transformation.Thc_to_Twc(T_HC)
186 |
187 | return observation
188 |
189 | def set_agent_position(sim, pose):
190 | # Move the agent
191 | R = pose[:3, :3]
192 | orientation_quat = quaternion.from_rotation_matrix(R)
193 | t = pose[:3, 3]
194 | position = t
195 |
196 | orientation = [orientation_quat.x, orientation_quat.y, orientation_quat.z, orientation_quat.w]
197 | agent = sim.get_agent(0)
198 | agent_state = hs.agent.AgentState(position, orientation)
199 | # agent.set_state(agent_state, reset_sensors=False)
200 | agent.set_state(agent_state)
201 |
202 | def main():
203 |     parser = argparse.ArgumentParser(description='Render colour, depth, semantic and instance labels from Habitat-Sim.')
204 | parser.add_argument('--config_file', type=str,
205 | default="./data_generation/replica_render_config_vMAP.yaml",
206 | help='the path to custom config file.')
207 | args = parser.parse_args()
208 |
209 | """Initialize the config dict and Habitat simulator"""
210 | # Read YAML file
211 | with open(args.config_file, 'r') as f:
212 | config = yaml.safe_load(f)
213 |
214 | config["save_path"] = os.path.join(config["save_path"])
215 | if not os.path.exists(config["save_path"]):
216 | os.makedirs(config["save_path"])
217 |
218 | T_wc = np.loadtxt(config["pose_file"]).reshape(-1, 4, 4)
219 | Ts_cam2world = T_wc
220 |
221 | print("-----Initialise and Set Habitat-Sim-----")
222 | sim, hs_cfg, config = init_habitat(config)
223 | # Set agent state
224 | sim.initialize_agent(config["default_agent"])
225 |
226 | """Set agent state"""
227 | print("-----Render Images from Habitat-Sim-----")
228 | with open(os.path.join(config["save_path"], 'render_config.yaml'), 'w') as outfile:
229 | yaml.dump(config, outfile, default_flow_style=False)
230 | start_time = time.time()
231 | total_render_num = Ts_cam2world.shape[0]
232 | for i in range(total_render_num):
233 | if i % 100 == 0 :
234 | print("Rendering Process: {}/{}".format(i, total_render_num))
235 | set_agent_position(sim, transformation.Twc_to_Thc(Ts_cam2world[i]))
236 |
237 | # replica mode
238 | observation = render(sim, config)
239 | save_renders(config["save_path"], observation, config["enable_semantics"], suffix="_{}".format(i))
240 |
241 | end_time = time.time()
242 | print("-----Finish Habitat Rendering, Showing Trajectories.-----")
243 | print("Average rendering time per image is {} seconds.".format((end_time-start_time)/Ts_cam2world.shape[0]))
244 |
245 | if __name__ == "__main__":
246 | main()
247 |
248 |
249 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | vMAP SOFTWARE
2 |
3 | LICENCE AGREEMENT
4 |
5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College
6 | London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY
7 | ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING
8 | AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE.
9 | BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE
10 | TERMS OF THE AGREEMENT.
11 |
12 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
13 |
14 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully
15 | paid-up, royalty free, non-transferable, non-sub- licensable licence (the
16 | “Licence”) to use the vMAP source code, including any modification,
17 | part or derivative (the “Software”).
18 |
19 | Ownership and Licence. Your rights to use and download the Software onto your
20 | computer, and all other copies that You are authorised to make, are specified
21 | in this Agreement. However, we (or our licensors) retain all rights, including
22 | but not limited to all copyright and other intellectual property rights
23 | anywhere in the world, in the Software not expressly granted to You in this
24 | Agreement.
25 |
26 | 2. Permitted use of the Licence:
27 |
28 | (a) You may download and install the Software onto one computer or server for
29 | use in accordance with Clause 2(b) of this Agreement provided that You ensure
30 | that the Software is not accessible by other users unless they have themselves
31 | accepted the terms of this licence agreement.
32 |
33 | (b) You may use the Software solely for non-commercial, internal or academic
34 | research purposes and only in accordance with the terms of this Agreement. You
35 | may not use the Software for commercial purposes, including but not limited to
36 | (1) integration of all or part of the source code or the Software into a
37 | product for sale or licence by or on behalf of You to third parties or (2) use
38 | of the Software or any derivative of it for research to develop software
39 | products for sale or licence to a third party or (3) use of the Software or any
40 | derivative of it for research to develop non-software products for sale or
41 | licence to a third party, or (4) use of the Software to provide any service to
42 | an external organisation for which payment is received.
43 |
44 | Should You wish to use the Software for commercial purposes, You shall
45 | email researchcontracts.engineering@imperial.ac.uk .
46 |
47 | (c) Right to Copy. You may copy the Software for back-up and archival purposes,
48 | provided that each copy is kept in your possession and provided You reproduce
49 | our copyright notice (set out in Schedule 1) on each copy.
50 |
51 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software
52 | and You may not transmit, transfer or sub-license this licence to use the
53 | Software or any of your rights or obligations under this Agreement to another
54 | party.
55 |
56 | (e) Identity of Licensee. The licence granted herein is personal to You. You
57 | shall not permit any third party to access, modify or otherwise use the
58 | Software nor shall You access modify or otherwise use the Software on behalf of
59 | any third party. If You wish to obtain a licence for multiple users or a site
60 | licence for the Software please contact us
61 | at researchcontracts.engineering@imperial.ac.uk .
62 |
63 | (f) Publications and presentations. You may make public, results or data
64 | obtained from, dependent on or arising from research carried out using the
65 | Software, provided that any such presentation or publication identifies the
66 | Software as the source of the results or the data, including the Copyright
67 | Notice given in each element of the Software, and stating that the Software has
68 | been made available for use by You under licence from Imperial College London
69 | and You provide a copy of any such publication to Imperial College London.
70 |
71 | 3. Prohibited Uses. You may not, without written permission from us
72 | at researchcontracts.engineering@imperial.ac.uk :
73 |
74 | (a) Use, copy, modify, merge, or transfer copies of the Software or any
75 | documentation provided by us which relates to the Software except as provided
76 | in this Agreement;
77 |
78 | (b) Use any back-up or archival copies of the Software (or allow anyone else to
79 | use such copies) for any purpose other than to replace the original copy in the
80 | event it is destroyed or becomes defective; or
81 |
82 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner
83 | decode the Software for any reason.
84 |
85 | 4. Warranty Disclaimer
86 |
87 | (a) Disclaimer. The Software has been developed for research purposes only. You
88 | acknowledge that we are providing the Software to You under this licence
89 | agreement free of charge and on condition that the disclaimer set out below
90 | shall apply. We make no representation or warranty as to: (i) the
91 | quality, accuracy or reliability of the Software; (ii) the suitability of the
92 | Software for any particular use or for use under any specific conditions; and
93 | (iii) whether use of the Software will infringe third-party rights.
94 |
95 | You acknowledge that You have reviewed and evaluated the Software to determine
96 | that it meets your needs and that You assume all responsibility and liability
97 | for determining the suitability of the Software as fit for your particular
98 | purposes and requirements. Subject to Clause 4(b), we exclude and expressly
99 | disclaim all express and implied representations, warranties, conditions and
100 | terms not stated herein (including the implied conditions or warranties of
101 | satisfactory quality, merchantable quality, merchantability and fitness for
102 | purpose).
103 |
104 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or
105 | impose obligations upon us which cannot, in whole or in part, be excluded,
106 | restricted or modified or otherwise do not allow the exclusion of implied
107 | warranties, conditions or terms, in which case the above warranty disclaimer
108 | and exclusion will only apply to You to the extent permitted in the relevant
109 | jurisdiction and does not in any event exclude any implied warranties,
110 | conditions or terms which may not under applicable law be excluded.
111 |
112 | (c) Imperial College London disclaims all responsibility for the use which is
113 | made of the Software and any liability for the outcomes arising from using the
114 | Software.
115 |
116 | 5. Limitation of Liability
117 |
118 | (a) You acknowledge that we are providing the Software to You under this
119 | licence agreement free of charge and on condition that the limitation of
120 | liability set out below shall apply. Accordingly, subject to Clause 5(b), we
121 | exclude all liability whether in contract, tort, negligence or otherwise, in
122 | respect of the Software and/or any related documentation provided to You by us
123 | including, but not limited to, liability for loss or corruption of data, loss
124 | of contracts, loss of income, loss of profits, loss of cover and any
125 | consequential or indirect loss or damage of any kind arising out of or in
126 | connection with this licence agreement, however caused. This exclusion shall
127 | apply even if we have been advised of the possibility of such loss or damage.
128 |
129 | (b) You agree to indemnify Imperial College London and hold it harmless from
130 | and against any and all claims, damages and liabilities asserted by third
131 | parties (including claims for negligence) which arise directly or indirectly
132 | from the use of the Software or any derivative of it or the sale of any
133 | products based on the Software. You undertake to make no liability claim
134 | against any employee, student, agent or appointee of Imperial College London,
135 | in connection with this Licence or the Software.
136 |
137 | (c) Nothing in this Agreement shall have the effect of excluding or limiting
138 | our statutory liability.
139 |
140 | (d) Some jurisdictions do not allow these limitations or exclusions either
141 | wholly or in part, and, to that extent, they may not apply to you. Nothing in
142 | this licence agreement will affect your statutory rights or other relevant
143 | statutory provisions which cannot be excluded, restricted or modified, and its
144 | terms and conditions must be read and construed subject to any such statutory
145 | rights and/or provisions.
146 |
147 | 6. Confidentiality. You agree not to disclose any confidential information
148 | provided to You by us pursuant to this Agreement to any third party without our
149 | prior written consent. The obligations in this Clause 6 shall survive the
150 | termination of this Agreement for any reason.
151 |
152 | 7. Termination.
153 |
154 | (a) We may terminate this licence agreement and your right to use the Software
155 | at any time with immediate effect upon written notice to You.
156 |
157 | (b) This licence agreement and your right to use the Software automatically
158 | terminate if You:
159 |
160 | (i) fail to comply with any provisions of this Agreement; or
161 |
162 | (ii) destroy the copies of the Software in your possession, or voluntarily
163 | return the Software to us.
164 |
165 | (c) Upon termination You will destroy all copies of the Software.
166 |
167 | (d) Otherwise, the restrictions on your rights to use the Software will expire
168 | 10 (ten) years after first use of the Software under this licence agreement.
169 |
170 | 8. Miscellaneous Provisions.
171 |
172 | (a) This Agreement will be governed by and construed in accordance with the
173 | substantive laws of England and Wales whose courts shall have exclusive
174 | jurisdiction over all disputes which may arise between us.
175 |
176 | (b) This is the entire agreement between us relating to the Software, and
177 | supersedes any prior purchase order, communications, advertising or
178 | representations concerning the Software.
179 |
180 | (c) No change or modification of this Agreement will be valid unless it is in
181 | writing, and is signed by us.
182 |
183 | (d) The unenforceability or invalidity of any part of this Agreement will not
184 | affect the enforceability or validity of the remaining parts.
185 |
186 | BSD Elements of the Software
187 |
188 | For BSD elements of the Software, the following terms shall apply:
189 | Copyright as indicated in the header of the individual element of the Software.
190 | All rights reserved.
191 |
192 | Redistribution and use in source and binary forms, with or without
193 | modification, are permitted provided that the following conditions are met:
194 |
195 | 1. Redistributions of source code must retain the above copyright notice, this
196 | list of conditions and the following disclaimer.
197 |
198 | 2. Redistributions in binary form must reproduce the above copyright notice,
199 | this list of conditions and the following disclaimer in the documentation
200 | and/or other materials provided with the distribution.
201 |
202 | 3. Neither the name of the copyright holder nor the names of its contributors
203 | may be used to endorse or promote products derived from this software without
204 | specific prior written permission.
205 |
206 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
207 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
208 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
209 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
210 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
211 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
212 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
213 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
214 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
215 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
216 |
217 | SCHEDULE 1
218 |
219 | The Software
220 |
221 | vMAP is an object-level mapping system in which each object is represented by a small MLP that can be efficiently trained in a vectorised manner.
222 | It is based on the techniques described in the following publication:
223 |
224 | • Xin Kong, Shikun Liu, Marwan Taher, Andrew J. Davison. vMAP: Vectorised Object Mapping for Neural Field SLAM. ArXiv Preprint, 2023
225 | _________________________
226 |
227 | Acknowledgments
228 |
229 | If you use the software, you should reference the following paper in any
230 | publication:
231 |
232 | • Xin Kong, Shikun Liu, Marwan Taher, Andrew J. Davison. vMAP: Vectorised Object Mapping for Neural Field SLAM. ArXiv Preprint, 2023
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import imgviz
2 | from torch.utils.data import Dataset, DataLoader
3 | import torch
4 | import numpy as np
5 | import cv2
6 | import os
7 | from utils import enlarge_bbox, get_bbox2d, get_bbox2d_batch, box_filter
8 | import glob
9 | from torchvision import transforms
10 | import image_transforms
11 | import open3d
12 | import time
13 |
14 | def next_live_data(track_to_map_IDT, inited):
15 | while True:
16 | if track_to_map_IDT.empty():
17 | if inited:
18 | return None # no new frame, use kf buffer
19 | else: # blocking until get the first frame
20 | continue
21 | else:
22 | Buffer_data = track_to_map_IDT.get(block=False)
23 | break
24 |
25 |
26 | if Buffer_data is not None:
27 | image, depth, T, obj, bbox_dict, kf_id = Buffer_data
28 | del Buffer_data
29 | T_obj = torch.eye(4)
30 | sample = {"image": image, "depth": depth, "T": T, "T_obj": T_obj,
31 | "obj": obj, "bbox_dict": bbox_dict, "frame_id": kf_id}
32 |
33 | return sample
34 | else:
35 | print("getting nothing?")
36 | exit(-1)
37 | # return None
38 |
39 | def init_loader(cfg, multi_worker=True):
40 | if cfg.dataset_format == "Replica":
41 | dataset = Replica(cfg)
42 | elif cfg.dataset_format == "ScanNet":
43 | dataset = ScanNet(cfg)
44 | else:
45 | print("Dataset format {} not found".format(cfg.dataset_format))
46 | exit(-1)
47 |
48 | # init dataloader
49 | if multi_worker:
50 | # multi worker loader
51 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, sampler=None,
52 | batch_sampler=None, num_workers=4, collate_fn=None,
53 | pin_memory=True, drop_last=False, timeout=0,
54 | worker_init_fn=None, generator=None, prefetch_factor=2,
55 | persistent_workers=True)
56 | else:
57 | # single worker loader
58 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, sampler=None,
59 | batch_sampler=None, num_workers=0)
60 |
61 | return dataloader
62 |
63 | class Replica(Dataset):
64 | def __init__(self, cfg):
65 | self.imap_mode = cfg.imap_mode
66 | self.root_dir = cfg.dataset_dir
67 | traj_file = os.path.join(self.root_dir, "traj_w_c.txt")
68 | self.Twc = np.loadtxt(traj_file, delimiter=" ").reshape([-1, 4, 4])
69 | self.depth_transform = transforms.Compose(
70 | [image_transforms.DepthScale(cfg.depth_scale),
71 | image_transforms.DepthFilter(cfg.max_depth)])
72 |
73 | # background semantic classes: undefined--1, undefined-0 beam-5 blinds-12 curtain-30 ceiling-31 floor-40 pillar-60 vent-92 wall-93 wall-plug-95 window-97 rug-98
74 | self.background_cls_list = [5,12,30,31,40,60,92,93,95,97,98,79]
75 | # Not sure: door-37 handrail-43 lamp-47 pipe-62 rack-66 shower-stall-73 stair-77 switch-79 wall-cabinet-94 picture-59
76 |         self.bbox_scale = 0.2  # relative margin used when enlarging 2D object bboxes
77 |
78 | def __len__(self):
79 | return len(os.listdir(os.path.join(self.root_dir, "depth")))
80 |
81 | def __getitem__(self, idx):
82 | bbox_dict = {}
83 | rgb_file = os.path.join(self.root_dir, "rgb", "rgb_" + str(idx) + ".png")
84 | depth_file = os.path.join(self.root_dir, "depth", "depth_" + str(idx) + ".png")
85 | inst_file = os.path.join(self.root_dir, "semantic_instance", "semantic_instance_" + str(idx) + ".png")
86 | obj_file = os.path.join(self.root_dir, "semantic_class", "semantic_class_" + str(idx) + ".png")
87 | depth = cv2.imread(depth_file, -1).astype(np.float32).transpose(1,0)
88 | image = cv2.imread(rgb_file).astype(np.uint8)
89 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(1,0,2)
90 | obj = cv2.imread(obj_file, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) # uint16 -> int32
91 | inst = cv2.imread(inst_file, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) # uint16 -> int32
92 |
93 | bbox_scale = self.bbox_scale
94 |
95 | if self.imap_mode:
96 | obj = np.zeros_like(obj)
97 | else:
98 | obj_ = np.zeros_like(obj)
99 | inst_list = []
100 | batch_masks = []
101 | for inst_id in np.unique(inst):
102 | inst_mask = inst == inst_id
103 | # if np.sum(inst_mask) <= 2000: # too small 20 400
104 | # continue
105 | sem_cls = np.unique(obj[inst_mask]) # sem label, only interested obj
106 | assert sem_cls.shape[0] != 0
107 | if sem_cls in self.background_cls_list:
108 | continue
109 | obj_mask = inst == inst_id
110 | batch_masks.append(obj_mask)
111 | inst_list.append(inst_id)
112 | if len(batch_masks) > 0:
113 | batch_masks = torch.from_numpy(np.stack(batch_masks))
114 | cmins, cmaxs, rmins, rmaxs = get_bbox2d_batch(batch_masks)
115 |
116 | for i in range(batch_masks.shape[0]):
117 | w = rmaxs[i] - rmins[i]
118 | h = cmaxs[i] - cmins[i]
119 | if w <= 10 or h <= 10: # too small todo
120 | continue
121 | bbox_enlarged = enlarge_bbox([rmins[i], cmins[i], rmaxs[i], cmaxs[i]], scale=bbox_scale,
122 | w=obj.shape[1], h=obj.shape[0])
123 | # inst_list.append(inst_id)
124 | inst_id = inst_list[i]
125 | obj_[batch_masks[i]] = 1
126 | # bbox_dict.update({int(inst_id): torch.from_numpy(np.array(bbox_enlarged).reshape(-1))}) # batch format
127 | bbox_dict.update({inst_id: torch.from_numpy(np.array(
128 | [bbox_enlarged[1], bbox_enlarged[3], bbox_enlarged[0], bbox_enlarged[2]]))}) # bbox order
129 |
130 | inst[obj_ == 0] = 0 # for background
131 | obj = inst
132 |
133 | bbox_dict.update({0: torch.from_numpy(np.array([int(0), int(obj.shape[0]), 0, int(obj.shape[1])]))}) # bbox order
134 |
135 | T = self.Twc[idx] # could change to ORB-SLAM pose or else
136 | T_obj = np.eye(4) # obj pose, if dynamic
137 | sample = {"image": image, "depth": depth, "T": T, "T_obj": T_obj,
138 | "obj": obj, "bbox_dict": bbox_dict, "frame_id": idx}
139 |
140 | if image is None or depth is None:
141 | print(rgb_file)
142 | print(depth_file)
143 | raise ValueError
144 |
145 | if self.depth_transform:
146 | sample["depth"] = self.depth_transform(sample["depth"])
147 |
148 | return sample
149 |
150 | class ScanNet(Dataset):
151 | def __init__(self, cfg):
152 | self.imap_mode = cfg.imap_mode
153 | self.root_dir = cfg.dataset_dir
154 | self.color_paths = sorted(glob.glob(os.path.join(
155 | self.root_dir, 'color', '*.jpg')), key=lambda x: int(os.path.basename(x)[:-4]))
156 | self.depth_paths = sorted(glob.glob(os.path.join(
157 | self.root_dir, 'depth', '*.png')), key=lambda x: int(os.path.basename(x)[:-4]))
158 | self.inst_paths = sorted(glob.glob(os.path.join(
159 | self.root_dir, 'instance-filt', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) # instance-filt
160 | self.sem_paths = sorted(glob.glob(os.path.join(
161 | self.root_dir, 'label-filt', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) # label-filt
162 | self.load_poses(os.path.join(self.root_dir, 'pose'))
163 | self.n_img = len(self.color_paths)
164 | self.depth_transform = transforms.Compose(
165 | [image_transforms.DepthScale(cfg.depth_scale),
166 | image_transforms.DepthFilter(cfg.max_depth)])
167 | # self.rgb_transform = rgb_transform
168 | self.W = cfg.W
169 | self.H = cfg.H
170 | self.fx = cfg.fx
171 | self.fy = cfg.fy
172 | self.cx = cfg.cx
173 | self.cy = cfg.cy
174 | self.edge = cfg.mw
175 | self.intrinsic_open3d = open3d.camera.PinholeCameraIntrinsic(
176 | width=self.W,
177 | height=self.H,
178 | fx=self.fx,
179 | fy=self.fy,
180 | cx=self.cx,
181 | cy=self.cy,
182 | )
183 |
184 | self.min_pixels = 1500
185 | # from scannetv2-labels.combined.tsv
186 | #1-wall, 3-floor, 16-window, 41-ceiling, 232-light switch 0-unknown? 21-pillar 161-doorframe, shower walls-128, curtain-21, windowsill-141
187 | self.background_cls_list = [-1, 0, 1, 3, 16, 41, 232, 21, 161, 128, 21]
188 | self.bbox_scale = 0.2
189 | self.inst_dict = {}
190 |
191 | def load_poses(self, path):
192 | self.poses = []
193 | pose_paths = sorted(glob.glob(os.path.join(path, '*.txt')),
194 | key=lambda x: int(os.path.basename(x)[:-4]))
195 | for pose_path in pose_paths:
196 | with open(pose_path, "r") as f:
197 | lines = f.readlines()
198 | ls = []
199 | for line in lines:
200 | l = list(map(float, line.split(' ')))
201 | ls.append(l)
202 | c2w = np.array(ls).reshape(4, 4)
203 | self.poses.append(c2w)
204 |
205 | def __len__(self):
206 | return self.n_img
207 |
208 | def __getitem__(self, index):
209 | bbox_scale = self.bbox_scale
210 | color_path = self.color_paths[index]
211 | depth_path = self.depth_paths[index]
212 | inst_path = self.inst_paths[index]
213 | sem_path = self.sem_paths[index]
214 | color_data = cv2.imread(color_path).astype(np.uint8)
215 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
216 | depth_data = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
217 | depth_data = np.nan_to_num(depth_data, nan=0.)
218 | T = None
219 | if self.poses is not None:
220 | T = self.poses[index]
221 | if np.any(np.isinf(T)):
222 | if index + 1 == self.__len__():
223 | print("pose inf!")
224 | return None
225 | return self.__getitem__(index + 1)
226 |
227 | H, W = depth_data.shape
228 | color_data = cv2.resize(color_data, (W, H), interpolation=cv2.INTER_LINEAR)
229 | if self.edge:
230 |             edge = self.edge  # crop the image border, since there are invalid values at the edge of the color image
231 | color_data = color_data[edge:-edge, edge:-edge]
232 | depth_data = depth_data[edge:-edge, edge:-edge]
233 | if self.depth_transform:
234 | depth_data = self.depth_transform(depth_data)
235 | bbox_dict = {}
236 | if self.imap_mode:
237 | inst_data = np.zeros_like(depth_data).astype(np.int32)
238 | else:
239 | inst_data = cv2.imread(inst_path, cv2.IMREAD_UNCHANGED)
240 | inst_data = cv2.resize(inst_data, (W, H), interpolation=cv2.INTER_NEAREST).astype(np.int32)
241 | sem_data = cv2.imread(sem_path, cv2.IMREAD_UNCHANGED)#.astype(np.int32)
242 | sem_data = cv2.resize(sem_data, (W, H), interpolation=cv2.INTER_NEAREST)
243 | if self.edge:
244 | edge = self.edge
245 | inst_data = inst_data[edge:-edge, edge:-edge]
246 | sem_data = sem_data[edge:-edge, edge:-edge]
247 | inst_data += 1 # shift from 0->1 , 0 is for background
248 |
249 | # box filter
250 | track_start = time.time()
251 | masks = []
252 | classes = []
253 | # convert to list of arrays
254 | obj_ids = np.unique(inst_data)
255 | for obj_id in obj_ids:
256 | mask = inst_data == obj_id
257 | sem_cls = np.unique(sem_data[mask])
258 | if sem_cls in self.background_cls_list:
259 | inst_data[mask] = 0 # set to background
260 | continue
261 | masks.append(mask)
262 | classes.append(obj_id)
263 | T_CW = np.linalg.inv(T)
264 | inst_data = box_filter(masks, classes, depth_data, self.inst_dict, self.intrinsic_open3d, T_CW, min_pixels=self.min_pixels)
265 |
266 | merged_obj_ids = np.unique(inst_data)
267 | for obj_id in merged_obj_ids:
268 | mask = inst_data == obj_id
269 | bbox2d = get_bbox2d(mask, bbox_scale=bbox_scale)
270 | if bbox2d is None:
271 | inst_data[mask] = 0 # set to bg
272 | else:
273 | min_x, min_y, max_x, max_y = bbox2d
274 | bbox_dict.update({int(obj_id): torch.from_numpy(np.array([min_x, max_x, min_y, max_y]).reshape(-1))}) # batch format
275 | bbox_time = time.time()
276 | # print("bbox time ", bbox_time - filter_time)
277 | cv2.imshow("inst", imgviz.label2rgb(inst_data))
278 | cv2.waitKey(1)
279 | print("frame {} track time {}".format(index, bbox_time-track_start))
280 |
281 | bbox_dict.update({0: torch.from_numpy(np.array([int(0), int(inst_data.shape[1]), 0, int(inst_data.shape[0])]))}) # bbox order
282 | # wrap data to frame dict
283 | T_obj = np.identity(4)
284 | sample = {"image": color_data.transpose(1,0,2), "depth": depth_data.transpose(1,0), "T": T, "T_obj": T_obj}
285 | if color_data is None or depth_data is None:
286 | print(color_path)
287 | print(depth_path)
288 | raise ValueError
289 |
290 | sample.update({"obj": inst_data.transpose(1,0)})
291 | sample.update({"bbox_dict": bbox_dict})
292 | return sample
293 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imgviz
3 | import numpy as np
4 | import torch
5 | from functorch import combine_state_for_ensemble
6 | import open3d
7 | import queue
8 | import copy
9 | import torch.utils.dlpack
10 |
11 | class BoundingBox():
12 | def __init__(self):
13 | super(BoundingBox, self).__init__()
14 | self.extent = None
15 | self.R = None
16 | self.center = None
17 | self.points3d = None # (8,3)
18 |
19 | def bbox_open3d2bbox(bbox_o3d):
20 | bbox = BoundingBox()
21 | bbox.extent = bbox_o3d.extent
22 | bbox.R = bbox_o3d.R
23 | bbox.center = bbox_o3d.center
24 | return bbox
25 |
26 | def bbox_bbox2open3d(bbox):
27 | bbox_o3d = open3d.geometry.OrientedBoundingBox(bbox.center, bbox.R, bbox.extent)
28 | return bbox_o3d
29 |
30 | def update_vmap(models, optimiser):
31 |     fmodel, params, buffers = combine_state_for_ensemble(models)  # stack per-object MLPs into one vectorised ensemble
32 |     [p.requires_grad_() for p in params]
33 |     optimiser.add_param_group({"params": params})  # register the stacked parameters with the shared optimiser
34 |     return (fmodel, params, buffers)
35 |
36 | def enlarge_bbox(bbox, scale, w, h):
37 | assert scale >= 0
38 | # print(bbox)
39 | min_x, min_y, max_x, max_y = bbox
40 | margin_x = int(0.5 * scale * (max_x - min_x))
41 | margin_y = int(0.5 * scale * (max_y - min_y))
42 | if margin_y == 0 or margin_x == 0:
43 | return None
44 | # assert margin_x != 0
45 | # assert margin_y != 0
46 | min_x -= margin_x
47 | max_x += margin_x
48 | min_y -= margin_y
49 | max_y += margin_y
50 |
51 | min_x = np.clip(min_x, 0, w-1)
52 | min_y = np.clip(min_y, 0, h-1)
53 | max_x = np.clip(max_x, 0, w-1)
54 | max_y = np.clip(max_y, 0, h-1)
55 |
56 | bbox_enlarged = [int(min_x), int(min_y), int(max_x), int(max_y)]
57 | return bbox_enlarged
58 |
59 | def get_bbox2d(obj_mask, bbox_scale=1.0):
60 | contours, hierarchy = cv2.findContours(obj_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[
61 | -2:]
62 | # # Find the index of the largest contour
63 | # areas = [cv2.contourArea(c) for c in contours]
64 | # max_index = np.argmax(areas)
65 | # cnt = contours[max_index]
66 | # Concatenate all contours
67 | if len(contours) == 0:
68 | return None
69 | cnt = np.concatenate(contours)
70 |         x, y, w, h = cv2.boundingRect(cnt) # todo if multiple contours, choose the outermost one?
71 | # x, y, w, h = cv2.boundingRect(contours)
72 | bbox_enlarged = enlarge_bbox([x, y, x + w, y + h], scale=bbox_scale, w=obj_mask.shape[1], h=obj_mask.shape[0])
73 | return bbox_enlarged
74 |
75 | def get_bbox2d_batch(img):
76 | b,h,w = img.shape[:3]
77 | rows = torch.any(img, axis=2)
78 | cols = torch.any(img, axis=1)
79 | rmins = torch.argmax(rows.float(), dim=1)
80 | rmaxs = h - torch.argmax(rows.float().flip(dims=[1]), dim=1)
81 | cmins = torch.argmax(cols.float(), dim=1)
82 | cmaxs = w - torch.argmax(cols.float().flip(dims=[1]), dim=1)
83 |
84 | return rmins, rmaxs, cmins, cmaxs
85 |
86 | def get_latest_queue(q):
87 | message = None
88 | while(True):
89 | try:
90 | message_latest = q.get(block=False)
91 | if message is not None:
92 | del message
93 | message = message_latest
94 |
95 | except queue.Empty:
96 | break
97 |
98 | return message
99 |
100 | # for association/tracking
101 | class InstData:
102 | def __init__(self):
103 | super(InstData, self).__init__()
104 | self.bbox3D = None
105 | self.inst_id = None # instance
106 | self.class_id = None # semantic
107 | self.pc_sample = None
108 | self.merge_cnt = 0 # merge times counting
109 | self.cmp_cnt = 0
110 |
111 |
112 | def box_filter(masks, classes, depth, inst_dict, intrinsic_open3d, T_CW, min_pixels=500, voxel_size=0.01):
113 | bbox3d_scale = 1.0 # 1.05
114 |     inst_data = np.zeros_like(depth, dtype=np.int32)  # np.int was removed in recent NumPy
115 | for i in range(len(masks)):
116 | diff_mask = None
117 | inst_mask = masks[i]
118 | inst_id = classes[i]
119 | if inst_id == 0:
120 | continue
121 | inst_depth = np.copy(depth)
122 | inst_depth[~inst_mask] = 0. # inst_mask
123 | # proj_time = time.time()
124 | inst_pc = unproject_pointcloud(inst_depth, intrinsic_open3d, T_CW)
125 | # print("proj time ", time.time()-proj_time)
126 | if len(inst_pc.points) <= 10: # too small
127 | inst_data[inst_mask] = 0 # set to background
128 | continue
129 | if inst_id in inst_dict.keys():
130 | candidate_inst = inst_dict[inst_id]
131 | # iou_time = time.time()
132 | IoU, indices = check_inside_ratio(inst_pc, candidate_inst.bbox3D)
133 | # print("iou time ", time.time()-iou_time)
134 | # if indices empty
135 | candidate_inst.cmp_cnt += 1
136 | if len(indices) >= 1:
137 | candidate_inst.pc += inst_pc.select_by_index(indices) # only merge pcs inside scale*bbox
138 | # todo check indices follow valid depth
139 |             valid_depth_mask = np.zeros_like(inst_depth, dtype=bool)  # np.bool was removed in recent NumPy
140 | valid_pc_mask = valid_depth_mask[inst_depth!=0]
141 | valid_pc_mask[indices] = True
142 | valid_depth_mask[inst_depth != 0] = valid_pc_mask
143 | valid_mask = valid_depth_mask
144 | diff_mask = np.zeros_like(inst_mask)
145 | # uv_opencv, _ = cv2.projectPoints(np.array(inst_pc.select_by_index(indices).points), T_CW[:3, :3],
146 | # T_CW[:3, 3], intrinsic_open3d.intrinsic_matrix[:3, :3], None)
147 | # uv = np.round(uv_opencv).squeeze().astype(int)
148 | # u = uv[:, 0].reshape(-1, 1)
149 | # v = uv[:, 1].reshape(-1, 1)
150 | # vu = np.concatenate([v, u], axis=-1)
151 | # valid_mask = np.zeros_like(inst_mask)
152 | # valid_mask[tuple(vu.T)] = True
153 | # # cv2.imshow("valid", (inst_depth!=0).astype(np.uint8)*255)
154 | # # cv2.waitKey(1)
155 | diff_mask[(inst_depth != 0) & (~valid_mask)] = True
156 | # cv2.imshow("diff_mask", diff_mask.astype(np.uint8) * 255)
157 | # cv2.waitKey(1)
158 | else: # merge all for scannet
159 | # print("too few pcs obj ", inst_id)
160 | inst_data[inst_mask] = -1
161 | continue
162 | # downsample_time = time.time()
163 | # adapt_voxel_size = np.maximum(np.max(candidate_inst.bbox3D.extent)/100, 0.1)
164 | candidate_inst.pc = candidate_inst.pc.voxel_down_sample(voxel_size) # adapt_voxel_size
165 | # candidate_inst.pc = candidate_inst.pc.farthest_point_down_sample(500)
166 | # candidate_inst.pc = candidate_inst.pc.random_down_sample(np.minimum(len(candidate_inst.pc.points)/500.,1))
167 | # print("downsample time ", time.time() - downsample_time) # 0.03s even
168 | # bbox_time = time.time()
169 | try:
170 | candidate_inst.bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(candidate_inst.pc.points)
171 | except RuntimeError:
172 | # print("too few pcs obj ", inst_id)
173 | inst_data[inst_mask] = -1
174 | continue
175 | # enlarge
176 | candidate_inst.bbox3D.scale(bbox3d_scale, candidate_inst.bbox3D.get_center())
177 | else: # new inst
178 | # init new inst and new sem
179 | new_inst = InstData()
180 | new_inst.inst_id = inst_id
181 | smaller_mask = cv2.erode(inst_mask.astype(np.uint8), np.ones((5, 5)), iterations=3).astype(bool)
182 | if np.sum(smaller_mask) < min_pixels:
183 | # print("too few pcs obj ", inst_id)
184 | inst_data[inst_mask] = 0
185 | continue
186 | inst_depth_small = depth.copy()
187 | inst_depth_small[~smaller_mask] = 0
188 | inst_pc_small = unproject_pointcloud(inst_depth_small, intrinsic_open3d, T_CW)
189 | new_inst.pc = inst_pc_small
190 | new_inst.pc = new_inst.pc.voxel_down_sample(voxel_size)
191 | try:
192 | inst_bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(new_inst.pc.points)
193 | except RuntimeError:
194 | # print("too few pcs obj ", inst_id)
195 | inst_data[inst_mask] = 0
196 | continue
197 | # scale up
198 | inst_bbox3D.scale(bbox3d_scale, inst_bbox3D.get_center())
199 | new_inst.bbox3D = inst_bbox3D
200 | # update inst_dict
201 | inst_dict.update({inst_id: new_inst}) # init new sem
202 |
203 | # update inst_data
204 | inst_data[inst_mask] = inst_id
205 | if diff_mask is not None:
206 | inst_data[diff_mask] = -1 # unsure area
207 |
208 | return inst_data
209 |
210 | def load_matrix_from_txt(path, shape=(4, 4)):
211 | with open(path) as f:
212 | txt = f.readlines()
213 | txt = ''.join(txt).replace('\n', ' ')
214 | matrix = [float(v) for v in txt.split()]
215 | return np.array(matrix).reshape(shape)
216 |
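# A minimal usage sketch (not part of the original file) for load_matrix_from_txt: it accepts
# any whitespace/newline-separated text file; the path below is hypothetical, for illustration only.
def _demo_load_matrix_from_txt(path="/tmp/pose_example.txt"):
    np.savetxt(path, np.eye(4))              # write a 4x4 identity pose, one row per line
    T = load_matrix_from_txt(path)           # default shape=(4, 4)
    assert T.shape == (4, 4) and np.allclose(T, np.eye(4))
    return T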
217 | def check_mask_order(obj_masks, depth_np, obj_ids):
218 | print(len(obj_masks))
219 | print(len(obj_ids))
220 |
221 | assert len(obj_masks) == len(obj_ids)
222 | depth = torch.from_numpy(depth_np)
223 | obj_masked_modified = copy.deepcopy(obj_masks[:])
224 | for i in range(len(obj_masks) - 1):
225 |
226 | mask1 = obj_masks[i].float()
227 |         mask1_ = obj_masked_modified[i].float()  # note: .float() copies unless the masks are already float tensors
228 | for j in range(i + 1, len(obj_masks)):
229 | mask2 = obj_masks[j].float()
230 | mask2_ = obj_masked_modified[j].float()
231 | # case 1: if they don't intersect we don't touch them
232 | if ((mask1 + mask2) == 2).sum() == 0:
233 | continue
234 | # case 2: the entire object 1 is inside of object 2, we say object 1 is in front of object 2:
235 | elif (((mask1 + mask2) == 2).float() - mask1).sum() == 0:
236 | mask2_ -= mask1_
237 | # case 3: the entire object 2 is inside of object 1, we say object 2 is in front of object 1:
238 | elif (((mask1 + mask2) == 2).float() - mask2).sum() == 0:
239 | mask1_ -= mask2_
240 | # case 4: use depth to check object order:
241 | else:
242 | # object 1 is closer
243 | if (depth * mask1).sum() / mask1.sum() > (depth * mask2).sum() / mask2.sum():
244 | mask2_ -= ((mask1 + mask2) == 2).float()
245 | # object 2 is closer
246 | if (depth * mask1).sum() / mask1.sum() < (depth * mask2).sum() / mask2.sum():
247 | mask1_ -= ((mask1 + mask2) == 2).float()
248 |
249 | final_mask = torch.zeros_like(depth, dtype=torch.int)
250 | # instance_labels = {}
251 | for i in range(len(obj_masked_modified)):
252 | final_mask = final_mask.masked_fill(obj_masked_modified[i] > 0, obj_ids[i])
253 | # instance_labels[i] = obj_ids[i].item()
254 | return final_mask.cpu().numpy()
255 |
256 |
257 | def unproject_pointcloud(depth, intrinsic_open3d, T_CW):
258 | # depth, mask, intrinsic, extrinsic -> point clouds
259 | pc_sample = open3d.geometry.PointCloud.create_from_depth_image(depth=open3d.geometry.Image(depth),
260 | intrinsic=intrinsic_open3d,
261 | extrinsic=T_CW,
262 | depth_scale=1.0,
263 | project_valid_depth_only=True)
264 | return pc_sample
265 |
266 | def check_inside_ratio(pc, bbox3D):
267 | # pc, bbox3d -> inside ratio
268 | indices = bbox3D.get_point_indices_within_bounding_box(pc.points)
269 | assert len(pc.points) > 0
270 | ratio = len(indices) / len(pc.points)
271 | # print("ratio ", ratio)
272 | return ratio, indices
273 |
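# A minimal sketch (not part of the original file) of how the two helpers above cooperate
# during data association: back-project an instance's depth into a world-frame point cloud,
# then measure how much of it falls inside a candidate's 3D bounding box. The intrinsics and
# depth values below are synthetic, chosen only for illustration.
def _demo_inside_ratio():
    intrinsic_open3d = open3d.camera.PinholeCameraIntrinsic(
        width=64, height=48, fx=60.0, fy=60.0, cx=32.0, cy=24.0)
    rng = np.random.default_rng(0)
    depth = rng.uniform(1.5, 2.5, size=(48, 64)).astype(np.float32)   # rough surface ~2m away
    T_CW = np.eye(4)                                                  # camera at the world origin
    pc = unproject_pointcloud(depth, intrinsic_open3d, T_CW)
    bbox = open3d.geometry.OrientedBoundingBox.create_from_points(pc.points)
    ratio, indices = check_inside_ratio(pc, bbox)                     # ~1.0, the box was fit to this cloud
    return ratio, indices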
274 | def track_instance(masks, classes, depth, inst_list, sem_dict, intrinsic_open3d, T_CW, IoU_thresh=0.5, voxel_size=0.1,
275 | min_pixels=2000, erode=True, clip_features=None, class_names=None):
276 | device = masks.device
277 | inst_data_dict = {}
278 | inst_data_dict.update({0: torch.zeros(depth.shape, dtype=torch.int, device=device)})
279 | inst_ids = []
280 | bbox3d_scale = 1.0 # todo 1.0
281 | min_extent = 0.05
282 | depth = torch.from_numpy(depth).to(device)
283 | for i in range(len(masks)):
284 | inst_data = torch.zeros(depth.shape, dtype=torch.int, device=device)
285 | smaller_mask = cv2.erode(masks[i].detach().cpu().numpy().astype(np.uint8), np.ones((5, 5)), iterations=3).astype(bool)
286 | inst_depth_small = depth.detach().cpu().numpy()
287 | inst_depth_small[~smaller_mask] = 0
288 | inst_pc_small = unproject_pointcloud(inst_depth_small, intrinsic_open3d, T_CW)
289 | diff_mask = None
290 | if np.sum(smaller_mask) <= min_pixels: # too small 20 400 # todo use sem to set background
291 | inst_data[masks[i]] = 0 # set to background
292 | continue
293 | inst_pc_voxel = inst_pc_small.voxel_down_sample(voxel_size)
294 | if len(inst_pc_voxel.points) <= 10: # too small 20 400 # todo use sem to set background
295 | inst_data[masks[i]] = 0 # set to background
296 | continue
297 | is_merged = False
298 | inst_id = None
299 | inst_mask = masks[i] #smaller_mask #masks[i] # todo only
300 | inst_class = classes[i]
301 | inst_depth = depth.detach().cpu().numpy()
302 | inst_depth[~masks[i].detach().cpu().numpy()] = 0. # inst_mask
303 | inst_pc = unproject_pointcloud(inst_depth, intrinsic_open3d, T_CW)
304 | sem_inst_list = []
305 | if clip_features is not None: # check similar sems based on clip feature distance
306 | sem_thr = 200 #300. for table #320. # 260.
307 | for sem_exist in sem_dict.keys():
308 | if torch.abs(clip_features[class_names[inst_class]] - clip_features[class_names[sem_exist]]).sum() < sem_thr:
309 | sem_inst_list.extend(sem_dict[sem_exist])
310 |         else: # no CLIP features, fall back to a strict semantic-class check
311 | if inst_class in sem_dict.keys():
312 | sem_inst_list.extend(sem_dict[inst_class])
313 |
314 | for candidate_inst in sem_inst_list:
315 | # if True: # only consider 3D bbox, merge them if they are spatial together
316 | IoU, indices = check_inside_ratio(inst_pc, candidate_inst.bbox3D)
317 | candidate_inst.cmp_cnt += 1
318 | if IoU > IoU_thresh:
319 | # merge inst to candidate
320 | is_merged = True
321 | candidate_inst.merge_cnt += 1
322 | candidate_inst.pc += inst_pc.select_by_index(indices)
323 | # inst_uv = inst_pc.select_by_index(indices).project_to_depth_image(masks[i].shape[1], masks[i].shape[0], intrinsic_open3d, T_CW, depth_scale=1.0, depth_max=10.0)
324 | # # inst_uv = torch.utils.dlpack.from_dlpack(uv_opencv.as_tensor().to_dlpack())
325 | # valid_mask = inst_uv.squeeze() > 0. # shape --> H, W
326 | # diff_mask = (inst_depth > 0.) & (~valid_mask)
327 | diff_mask = torch.zeros_like(inst_mask)
328 | uv_opencv, _ = cv2.projectPoints(np.array(inst_pc.select_by_index(indices).points), T_CW[:3, :3],
329 | T_CW[:3, 3], intrinsic_open3d.intrinsic_matrix[:3,:3], None)
330 | uv = np.round(uv_opencv).squeeze().astype(int)
331 | u = uv[:, 0].reshape(-1, 1)
332 | v = uv[:, 1].reshape(-1, 1)
333 | vu = np.concatenate([v, u], axis=-1)
334 |                 valid_mask = np.zeros(inst_mask.shape, dtype=bool)  # np.bool is deprecated; use the builtin bool
335 | valid_mask[tuple(vu.T)] = True
336 | diff_mask[(inst_depth!=0) & (~valid_mask)] = True
337 | # downsample pcs
338 | candidate_inst.pc = candidate_inst.pc.voxel_down_sample(voxel_size)
339 | # candidate_inst.pc.random_down_sample(np.minimum(500//len(candidate_inst.pc.points),1))
340 | candidate_inst.bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(candidate_inst.pc.points)
341 | # enlarge
342 | candidate_inst.bbox3D.scale(bbox3d_scale, candidate_inst.bbox3D.get_center())
343 | candidate_inst.bbox3D.extent = np.maximum(candidate_inst.bbox3D.extent, min_extent) # at least bigger than min_extent
344 | inst_id = candidate_inst.inst_id
345 | break
346 | # if candidate_inst.cmp_cnt >= 20 and candidate_inst.merge_cnt <= 5:
347 | # sem_inst_list.remove(candidate_inst)
348 |
349 | if not is_merged:
350 | # init new inst and new sem
351 | new_inst = InstData()
352 | new_inst.inst_id = len(inst_list) + 1
353 | new_inst.class_id = inst_class
354 |
355 | new_inst.pc = inst_pc_small
356 | new_inst.pc = new_inst.pc.voxel_down_sample(voxel_size)
357 | inst_bbox3D = open3d.geometry.OrientedBoundingBox.create_from_points(new_inst.pc.points)
358 | # scale up
359 | inst_bbox3D.scale(bbox3d_scale, inst_bbox3D.get_center())
360 | inst_bbox3D.extent = np.maximum(inst_bbox3D.extent, min_extent)
361 | new_inst.bbox3D = inst_bbox3D
362 | inst_list.append(new_inst)
363 | inst_id = new_inst.inst_id
364 | # update sem_dict
365 | if inst_class in sem_dict.keys():
366 | sem_dict[inst_class].append(new_inst) # append new inst to exist sem
367 | else:
368 | sem_dict.update({inst_class: [new_inst]}) # init new sem
369 | # update inst_data
370 | inst_data[inst_mask] = inst_id
371 | if diff_mask is not None:
372 | inst_data[diff_mask] = -1 # unsure area
373 | if inst_id not in inst_ids:
374 | inst_data_dict.update({inst_id: inst_data})
375 | else:
376 | continue
377 | # idx = inst_ids.index(inst_id)
378 | # inst_data_list[idx] = inst_data_list[idx] & torch.from_numpy(inst_data) # merge them? todo
379 | # return inst_data
380 | mask_bg = torch.stack(list(inst_data_dict.values())).sum(0) != 0
381 | inst_data_dict.update({0: mask_bg.int()})
382 | return inst_data_dict
383 |
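# A minimal sketch (not part of the original file) of how track_instance's output might be
# consumed: every entry maps an instance id to a per-pixel tensor (the id where the object was
# observed, -1 where the assignment is uncertain), while key 0 stores an aggregate coverage
# mask rather than per-instance labels. Overlaying the entries recovers a single instance image.
def _demo_compose_instance_map(inst_data_dict):
    inst_map = None
    for inst_id, inst_data in inst_data_dict.items():
        if inst_map is None:
            inst_map = torch.zeros_like(inst_data)
        if inst_id == 0:
            continue                                  # skip the aggregate mask stored under key 0
        inst_map[inst_data == inst_id] = inst_id      # assign labelled pixels
        inst_map[inst_data == -1] = -1                # keep uncertain pixels marked
    return inst_map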
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import loss
3 | from vmap import *
4 | import utils
5 | import open3d
6 | import dataset
7 | import vis
8 | from functorch import vmap
9 | import argparse
10 | from cfg import Config
11 | import shutil
12 |
13 | if __name__ == "__main__":
14 | #############################################
15 | # init config
16 | torch.backends.cudnn.enabled = True
17 | torch.backends.cudnn.benchmark = True
18 |
19 | # setting params
20 | parser = argparse.ArgumentParser(description='Model training for single GPU')
21 | parser.add_argument('--logdir', default='./logs/debug',
22 | type=str)
23 | parser.add_argument('--config',
24 | default='./configs/Replica/config_replica_room0_vMAP.json',
25 | type=str)
26 | parser.add_argument('--save_ckpt',
27 | default=False,
28 |                         type=lambda x: str(x).lower() in ("true", "1", "yes"))  # plain type=bool would treat any non-empty string, even "False", as True
29 | args = parser.parse_args()
30 |
31 | log_dir = args.logdir
32 | config_file = args.config
33 | save_ckpt = args.save_ckpt
34 | os.makedirs(log_dir, exist_ok=True) # saving logs
35 | shutil.copy(config_file, log_dir)
36 | cfg = Config(config_file) # config params
37 | n_sample_per_step = cfg.n_per_optim
38 | n_sample_per_step_bg = cfg.n_per_optim_bg
39 |
40 | # param for vis
41 | vis3d = open3d.visualization.Visualizer()
42 | vis3d.create_window(window_name="3D mesh vis",
43 | width=cfg.W,
44 | height=cfg.H,
45 | left=600, top=50)
46 | view_ctl = vis3d.get_view_control()
47 | view_ctl.set_constant_z_far(10.)
48 |
49 | # set camera
50 | cam_info = cameraInfo(cfg)
51 | intrinsic_open3d = open3d.camera.PinholeCameraIntrinsic(
52 | width=cfg.W,
53 | height=cfg.H,
54 | fx=cfg.fx,
55 | fy=cfg.fy,
56 | cx=cfg.cx,
57 | cy=cfg.cy)
58 |
59 | # init obj_dict
60 | obj_dict = {} # only objs
61 | vis_dict = {} # including bg
62 |
63 | # init for training
64 | AMP = False
65 | if AMP:
66 | scaler = torch.cuda.amp.GradScaler() # amp https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
67 | optimiser = torch.optim.AdamW([torch.autograd.Variable(torch.tensor(0))], lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
68 |
69 | #############################################
70 | # init data stream
71 | if not cfg.live_mode:
72 | # load dataset
73 | dataloader = dataset.init_loader(cfg)
74 | dataloader_iterator = iter(dataloader)
75 | dataset_len = len(dataloader)
76 | else:
77 | dataset_len = 1000000
78 | # # init ros node
79 | # torch.multiprocessing.set_start_method('spawn') # spawn
80 | # import ros_nodes
81 | # track_to_map_Buffer = torch.multiprocessing.Queue(maxsize=5)
82 | # # track_to_vis_T_WC = torch.multiprocessing.Queue(maxsize=1)
83 | # kfs_que = torch.multiprocessing.Queue(maxsize=5) # to store one more buffer
84 | # track_p = torch.multiprocessing.Process(target=ros_nodes.Tracking,
85 | # args=(
86 | # (cfg), (track_to_map_Buffer), (None),
87 | # (kfs_que), (True),))
88 | # track_p.start()
89 |
90 |
91 | # init vmap
92 | fc_models, pe_models = [], []
93 | scene_bg = None
94 |
95 | for frame_id in tqdm(range(dataset_len)):
96 | print("*********************************************")
97 | # get new frame data
98 | with performance_measure(f"getting next data"):
99 | if not cfg.live_mode:
100 | # get data from dataloader
101 | sample = next(dataloader_iterator)
102 | else:
103 | pass
104 |
105 | if sample is not None: # new frame
106 | last_frame_time = time.time()
107 | with performance_measure(f"Appending data"):
108 | rgb = sample["image"].to(cfg.data_device)
109 | depth = sample["depth"].to(cfg.data_device)
110 | twc = sample["T"].to(cfg.data_device)
111 | bbox_dict = sample["bbox_dict"]
112 | if "frame_id" in sample.keys():
113 | live_frame_id = sample["frame_id"]
114 | else:
115 | live_frame_id = frame_id
116 | if not cfg.live_mode:
117 | inst = sample["obj"].to(cfg.data_device)
118 | obj_ids = torch.unique(inst)
119 | else:
120 | inst_data_dict = sample["obj"]
121 | obj_ids = inst_data_dict.keys()
122 | # append new frame info to objs in current view
123 | for obj_id in obj_ids:
124 |                 if obj_id == -1: # unsure area
125 | continue
126 | obj_id = int(obj_id)
127 | # convert inst mask to state
128 | if not cfg.live_mode:
129 | state = torch.zeros_like(inst, dtype=torch.uint8, device=cfg.data_device)
130 | state[inst == obj_id] = 1
131 | state[inst == -1] = 2
132 | else:
133 | inst_mask = inst_data_dict[obj_id].permute(1,0)
134 | label_list = torch.unique(inst_mask).tolist()
135 | state = torch.zeros_like(inst_mask, dtype=torch.uint8, device=cfg.data_device)
136 | state[inst_mask == obj_id] = 1
137 | state[inst_mask == -1] = 2
138 | bbox = bbox_dict[obj_id]
139 | if obj_id in vis_dict.keys():
140 | scene_obj = vis_dict[obj_id]
141 | scene_obj.append_keyframe(rgb, depth, state, bbox, twc, live_frame_id)
142 | else: # init scene_obj
143 | if len(obj_dict.keys()) >= cfg.max_n_models:
144 | print("models full!!!! current num ", len(obj_dict.keys()))
145 | continue
146 | print("init new obj ", obj_id)
147 | if cfg.do_bg and obj_id == 0: # todo param
148 | scene_bg = sceneObject(cfg, obj_id, rgb, depth, state, bbox, twc, live_frame_id)
149 | # scene_bg.init_obj_center(intrinsic_open3d, depth, state, twc)
150 | optimiser.add_param_group({"params": scene_bg.trainer.fc_occ_map.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay})
151 | optimiser.add_param_group({"params": scene_bg.trainer.pe.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay})
152 | vis_dict.update({obj_id: scene_bg})
153 | else:
154 | scene_obj = sceneObject(cfg, obj_id, rgb, depth, state, bbox, twc, live_frame_id)
155 | # scene_obj.init_obj_center(intrinsic_open3d, depth, state, twc)
156 | obj_dict.update({obj_id: scene_obj})
157 | vis_dict.update({obj_id: scene_obj})
158 | # params = [scene_obj.trainer.fc_occ_map.parameters(), scene_obj.trainer.pe.parameters()]
159 | optimiser.add_param_group({"params": scene_obj.trainer.fc_occ_map.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay})
160 | optimiser.add_param_group({"params": scene_obj.trainer.pe.parameters(), "lr": cfg.learning_rate, "weight_decay": cfg.weight_decay})
161 | if cfg.training_strategy == "vmap":
162 | update_vmap_model = True
163 | fc_models.append(obj_dict[obj_id].trainer.fc_occ_map)
164 | pe_models.append(obj_dict[obj_id].trainer.pe)
165 |
166 | # ###################################
167 | # # measure trainable params in total
168 | # total_params = 0
169 | # obj_k = obj_dict[obj_id]
170 | # for p in obj_k.trainer.fc_occ_map.parameters():
171 | # if p.requires_grad:
172 | # total_params += p.numel()
173 | # for p in obj_k.trainer.pe.parameters():
174 | # if p.requires_grad:
175 | # total_params += p.numel()
176 | # print("total param ", total_params)
177 |
178 | # dynamically add vmap
179 | with performance_measure(f"add vmap"):
180 | if cfg.training_strategy == "vmap" and update_vmap_model == True:
181 | fc_model, fc_param, fc_buffer = utils.update_vmap(fc_models, optimiser)
182 | pe_model, pe_param, pe_buffer = utils.update_vmap(pe_models, optimiser)
183 | update_vmap_model = False
184 |
185 |
186 | ##################################################################
187 |         # training data preparation, get training data for all objs
188 | Batch_N_gt_depth = []
189 | Batch_N_gt_rgb = []
190 | Batch_N_depth_mask = []
191 | Batch_N_obj_mask = []
192 | Batch_N_input_pcs = []
193 | Batch_N_sampled_z = []
194 |
195 | with performance_measure(f"Sampling over {len(obj_dict.keys())} objects,"):
196 | if cfg.do_bg and scene_bg is not None:
197 | gt_rgb, gt_depth, valid_depth_mask, obj_mask, input_pcs, sampled_z \
198 | = scene_bg.get_training_samples(cfg.n_iter_per_frame * cfg.win_size_bg, cfg.n_samples_per_frame_bg,
199 | cam_info.rays_dir_cache)
200 | bg_gt_depth = gt_depth.reshape([gt_depth.shape[0] * gt_depth.shape[1]])
201 | bg_gt_rgb = gt_rgb.reshape([gt_rgb.shape[0] * gt_rgb.shape[1], gt_rgb.shape[2]])
202 | bg_valid_depth_mask = valid_depth_mask
203 | bg_obj_mask = obj_mask
204 | bg_input_pcs = input_pcs.reshape(
205 | [input_pcs.shape[0] * input_pcs.shape[1], input_pcs.shape[2], input_pcs.shape[3]])
206 | bg_sampled_z = sampled_z.reshape([sampled_z.shape[0] * sampled_z.shape[1], sampled_z.shape[2]])
207 |
208 | for obj_id, obj_k in obj_dict.items():
209 | gt_rgb, gt_depth, valid_depth_mask, obj_mask, input_pcs, sampled_z \
210 | = obj_k.get_training_samples(cfg.n_iter_per_frame * cfg.win_size, cfg.n_samples_per_frame,
211 | cam_info.rays_dir_cache)
212 | # merge first two dims, sample_per_frame*num_per_frame
213 | Batch_N_gt_depth.append(gt_depth.reshape([gt_depth.shape[0] * gt_depth.shape[1]]))
214 | Batch_N_gt_rgb.append(gt_rgb.reshape([gt_rgb.shape[0] * gt_rgb.shape[1], gt_rgb.shape[2]]))
215 | Batch_N_depth_mask.append(valid_depth_mask)
216 | Batch_N_obj_mask.append(obj_mask)
217 | Batch_N_input_pcs.append(input_pcs.reshape([input_pcs.shape[0] * input_pcs.shape[1], input_pcs.shape[2], input_pcs.shape[3]]))
218 | Batch_N_sampled_z.append(sampled_z.reshape([sampled_z.shape[0] * sampled_z.shape[1], sampled_z.shape[2]]))
219 |
220 | # # vis sampled points in open3D
221 | # # sampled pcs
222 | # pc = open3d.geometry.PointCloud()
223 | # pc.points = open3d.utility.Vector3dVector(input_pcs.cpu().numpy().reshape(-1,3))
224 | # open3d.visualization.draw_geometries([pc])
225 | # rgb_np = rgb.cpu().numpy().astype(np.uint8).transpose(1,0,2)
226 | # # print("rgb ", rgb_np.shape)
227 | # # print(rgb_np)
228 | # # cv2.imshow("rgb", rgb_np)
229 | # # cv2.waitKey(1)
230 | # depth_np = depth.cpu().numpy().astype(np.float32).transpose(1,0)
231 | # twc_np = twc.cpu().numpy()
232 | # rgbd = open3d.geometry.RGBDImage.create_from_color_and_depth(
233 | # open3d.geometry.Image(rgb_np),
234 | # open3d.geometry.Image(depth_np),
235 | # depth_trunc=max_depth,
236 | # depth_scale=1,
237 | # convert_rgb_to_intensity=False,
238 | # )
239 | # T_CW = np.linalg.inv(twc_np)
240 | # # input image pc
241 | # input_pc = open3d.geometry.PointCloud.create_from_rgbd_image(
242 | # image=rgbd,
243 | # intrinsic=intrinsic_open3d,
244 | # extrinsic=T_CW)
245 | # input_pc.points = open3d.utility.Vector3dVector(np.array(input_pc.points) - obj_k.obj_center.cpu().numpy())
246 | # open3d.visualization.draw_geometries([pc, input_pc])
247 |
248 |
249 | ####################################################
250 | # training
251 | assert len(Batch_N_input_pcs) > 0
252 | # move data to GPU (n_obj, n_iter_per_frame, win_size*num_per_frame, 3)
253 | with performance_measure(f"stacking and moving to gpu: "):
254 |
255 | Batch_N_input_pcs = torch.stack(Batch_N_input_pcs).to(cfg.training_device)
256 | Batch_N_gt_depth = torch.stack(Batch_N_gt_depth).to(cfg.training_device)
257 | Batch_N_gt_rgb = torch.stack(Batch_N_gt_rgb).to(cfg.training_device) / 255. # todo
258 | Batch_N_depth_mask = torch.stack(Batch_N_depth_mask).to(cfg.training_device)
259 | Batch_N_obj_mask = torch.stack(Batch_N_obj_mask).to(cfg.training_device)
260 | Batch_N_sampled_z = torch.stack(Batch_N_sampled_z).to(cfg.training_device)
261 | if cfg.do_bg:
262 | bg_input_pcs = bg_input_pcs.to(cfg.training_device)
263 | bg_gt_depth = bg_gt_depth.to(cfg.training_device)
264 | bg_gt_rgb = bg_gt_rgb.to(cfg.training_device) / 255.
265 | bg_valid_depth_mask = bg_valid_depth_mask.to(cfg.training_device)
266 | bg_obj_mask = bg_obj_mask.to(cfg.training_device)
267 | bg_sampled_z = bg_sampled_z.to(cfg.training_device)
268 |
269 | with performance_measure(f"Training over {len(obj_dict.keys())} objects,"):
270 | for iter_step in range(cfg.n_iter_per_frame):
271 | data_idx = slice(iter_step*n_sample_per_step, (iter_step+1)*n_sample_per_step)
272 | batch_input_pcs = Batch_N_input_pcs[:, data_idx, ...]
273 | batch_gt_depth = Batch_N_gt_depth[:, data_idx, ...]
274 | batch_gt_rgb = Batch_N_gt_rgb[:, data_idx, ...]
275 | batch_depth_mask = Batch_N_depth_mask[:, data_idx, ...]
276 | batch_obj_mask = Batch_N_obj_mask[:, data_idx, ...]
277 | batch_sampled_z = Batch_N_sampled_z[:, data_idx, ...]
278 | if cfg.training_strategy == "forloop":
279 | # for loop training
280 | batch_alpha = []
281 | batch_color = []
282 | for k, obj_id in enumerate(obj_dict.keys()):
283 | obj_k = obj_dict[obj_id]
284 | embedding_k = obj_k.trainer.pe(batch_input_pcs[k])
285 | alpha_k, color_k = obj_k.trainer.fc_occ_map(embedding_k)
286 | batch_alpha.append(alpha_k)
287 | batch_color.append(color_k)
288 |
289 | batch_alpha = torch.stack(batch_alpha)
290 | batch_color = torch.stack(batch_color)
291 | elif cfg.training_strategy == "vmap":
292 | # batched training
293 | batch_embedding = vmap(pe_model)(pe_param, pe_buffer, batch_input_pcs)
294 | batch_alpha, batch_color = vmap(fc_model)(fc_param, fc_buffer, batch_embedding)
295 | # print("batch alpha ", batch_alpha.shape)
296 | else:
297 | print("training strategy {} is not implemented ".format(cfg.training_strategy))
298 | exit(-1)
299 |
300 |
301 | # step loss
302 | # with performance_measure(f"Batch LOSS"):
303 | batch_loss, _ = loss.step_batch_loss(batch_alpha, batch_color,
304 | batch_gt_depth.detach(), batch_gt_rgb.detach(),
305 | batch_obj_mask.detach(), batch_depth_mask.detach(),
306 | batch_sampled_z.detach())
307 |
308 | if cfg.do_bg:
309 | bg_data_idx = slice(iter_step * n_sample_per_step_bg, (iter_step + 1) * n_sample_per_step_bg)
310 | bg_embedding = scene_bg.trainer.pe(bg_input_pcs[bg_data_idx, ...])
311 | bg_alpha, bg_color = scene_bg.trainer.fc_occ_map(bg_embedding)
312 | bg_loss, _ = loss.step_batch_loss(bg_alpha[None, ...], bg_color[None, ...],
313 | bg_gt_depth[None, bg_data_idx, ...].detach(), bg_gt_rgb[None, bg_data_idx].detach(),
314 | bg_obj_mask[None, bg_data_idx, ...].detach(), bg_valid_depth_mask[None, bg_data_idx, ...].detach(),
315 | bg_sampled_z[None, bg_data_idx, ...].detach())
316 | batch_loss += bg_loss
317 |
318 | # with performance_measure(f"Backward"):
319 | if AMP:
320 | scaler.scale(batch_loss).backward()
321 | scaler.step(optimiser)
322 | scaler.update()
323 | else:
324 | batch_loss.backward()
325 | optimiser.step()
326 | optimiser.zero_grad(set_to_none=True)
327 | # print("loss ", batch_loss.item())
328 |
329 | # update each origin model params
330 | # todo find a better way # https://github.com/pytorch/functorch/issues/280
331 | with performance_measure(f"updating vmap param"):
332 | if cfg.training_strategy == "vmap":
333 | with torch.no_grad():
334 | for model_id, (obj_id, obj_k) in enumerate(obj_dict.items()):
335 | for i, param in enumerate(obj_k.trainer.fc_occ_map.parameters()):
336 | param.copy_(fc_param[i][model_id])
337 | for i, param in enumerate(obj_k.trainer.pe.parameters()):
338 | param.copy_(pe_param[i][model_id])
339 |
340 |
341 | ####################################################################
342 | # live vis mesh
343 | if (((frame_id % cfg.n_vis_iter) == 0 or frame_id == dataset_len-1) or
344 | (cfg.live_mode and time.time()-last_frame_time>cfg.keep_live_time)) and frame_id >= 10:
345 | vis3d.clear_geometries()
346 | for obj_id, obj_k in vis_dict.items():
347 | bound = obj_k.get_bound(intrinsic_open3d)
348 | if bound is None:
349 | print("get bound failed obj ", obj_id)
350 | continue
351 | adaptive_grid_dim = int(np.minimum(np.max(bound.extent)//cfg.live_voxel_size+1, cfg.grid_dim))
352 | mesh = obj_k.trainer.meshing(bound, obj_k.obj_center, grid_dim=adaptive_grid_dim)
353 | if mesh is None:
354 | print("meshing failed obj ", obj_id)
355 | continue
356 |
357 | # save to dir
358 | obj_mesh_output = os.path.join(log_dir, "scene_mesh")
359 | os.makedirs(obj_mesh_output, exist_ok=True)
360 | mesh.export(os.path.join(obj_mesh_output, "frame_{}_obj{}.obj".format(frame_id, str(obj_id))))
361 |
362 | # live vis
363 | open3d_mesh = vis.trimesh_to_open3d(mesh)
364 | vis3d.add_geometry(open3d_mesh)
365 | vis3d.add_geometry(bound)
366 | # update vis3d
367 | vis3d.poll_events()
368 | vis3d.update_renderer()
369 |
370 | if False: # follow cam
371 | cam = view_ctl.convert_to_pinhole_camera_parameters()
372 | T_CW_np = np.linalg.inv(twc.cpu().numpy())
373 | cam.extrinsic = T_CW_np
374 | view_ctl.convert_from_pinhole_camera_parameters(cam)
375 | vis3d.poll_events()
376 | vis3d.update_renderer()
377 |
378 | with performance_measure("saving ckpt"):
379 | if save_ckpt and ((((frame_id % cfg.n_vis_iter) == 0 or frame_id == dataset_len - 1) or
380 | (cfg.live_mode and time.time() - last_frame_time > cfg.keep_live_time)) and frame_id >= 10):
381 | for obj_id, obj_k in vis_dict.items():
382 | ckpt_dir = os.path.join(log_dir, "ckpt", str(obj_id))
383 | os.makedirs(ckpt_dir, exist_ok=True)
384 | bound = obj_k.get_bound(intrinsic_open3d) # update bound
385 | obj_k.save_checkpoints(ckpt_dir, frame_id)
386 | # save current cam pose
387 | cam_dir = os.path.join(log_dir, "cam_pose")
388 | os.makedirs(cam_dir, exist_ok=True)
389 | torch.save({"twc": twc,}, os.path.join(cam_dir, "twc_frame_{}".format(frame_id) + ".pth"))
390 |
391 |
392 |
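# A minimal standalone sketch (not part of the original script) of the functorch ensemble
# pattern behind the "vmap" training strategy above: stack the parameters of several
# identically-shaped models and evaluate all of them with one batched call. The toy models and
# sizes are illustrative only; utils.update_vmap presumably wraps an equivalent mechanism.
def _demo_functorch_ensemble(n_models=3, n_rays=8, n_pts=10, feat=16):
    import torch.nn as nn
    from functorch import combine_state_for_ensemble

    models = [nn.Sequential(nn.Linear(feat, 32), nn.ReLU(), nn.Linear(32, 4))
              for _ in range(n_models)]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    x = torch.randn(n_models, n_rays, n_pts, feat)    # one input batch per model
    out = vmap(fmodel)(params, buffers, x)            # (n_models, n_rays, n_pts, 4)
    return out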
--------------------------------------------------------------------------------
/vmap.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from time import perf_counter_ns
5 | from tqdm import tqdm
6 | import trainer
7 | import open3d
8 | import trimesh
9 | import scipy
10 | from bidict import bidict
11 | import copy
12 | import os
13 |
14 | import utils
15 |
16 |
17 | class performance_measure:
18 |
19 | def __init__(self, name) -> None:
20 | self.name = name
21 |
22 | def __enter__(self):
23 | self.start_time = perf_counter_ns()
24 |
25 | def __exit__(self, type, value, tb):
26 | self.end_time = perf_counter_ns()
27 | self.exec_time = self.end_time - self.start_time
28 |
29 |         print(f"{self.name} execution time: {(self.exec_time)/1000000:.2f} ms")
30 |
31 | def origin_dirs_W(T_WC, dirs_C):
32 |
33 | assert T_WC.shape[0] == dirs_C.shape[0]
34 | assert T_WC.shape[1:] == (4, 4)
35 | assert dirs_C.shape[2] == 3
36 |
37 | dirs_W = (T_WC[:, None, :3, :3] @ dirs_C[..., None]).squeeze()
38 |
39 | origins = T_WC[:, :3, -1]
40 |
41 | return origins, dirs_W
42 |
43 |
44 | # @torch.jit.script
45 | def stratified_bins(min_depth, max_depth, n_bins, n_rays, type=torch.float32, device = "cuda:0"):
46 | # type: (Tensor, Tensor, int, int) -> Tensor
47 |
48 | bin_limits_scale = torch.linspace(0, 1, n_bins+1, dtype=type, device=device)
49 |
50 | if not torch.is_tensor(min_depth):
51 | min_depth = torch.ones(n_rays, dtype=type, device=device) * min_depth
52 |
53 | if not torch.is_tensor(max_depth):
54 | max_depth = torch.ones(n_rays, dtype=type, device=device) * max_depth
55 |
56 | depth_range = max_depth - min_depth
57 |
58 | lower_limits_scale = depth_range[..., None] * bin_limits_scale + min_depth[..., None]
59 | lower_limits_scale = lower_limits_scale[:, :-1]
60 |
61 | assert lower_limits_scale.shape == (n_rays, n_bins)
62 |
63 | bin_length_scale = depth_range / n_bins
64 | increments_scale = torch.rand(
65 | n_rays, n_bins, device=device,
66 | dtype=torch.float32) * bin_length_scale[..., None]
67 |
68 | z_vals_scale = lower_limits_scale + increments_scale
69 |
70 | assert z_vals_scale.shape == (n_rays, n_bins)
71 |
72 | return z_vals_scale
73 |
74 | # @torch.jit.script
75 | def normal_bins_sampling(depth, n_bins, n_rays, delta, device = "cuda:0"):
76 | # type: (Tensor, int, int, float) -> Tensor
77 |
78 | # device = "cpu"
79 |     # bins = torch.normal(0.0, delta / 3., size=[n_rays, n_bins],
80 |     #                     device=device).sort().values
81 | bins = torch.empty(n_rays, n_bins, dtype=torch.float32, device=device).normal_(mean=0.,std=delta / 3.).sort().values
82 | bins = torch.clip(bins, -delta, delta)
83 | z_vals = depth[:, None] + bins
84 |
85 | assert z_vals.shape == (n_rays, n_bins)
86 |
87 | return z_vals
88 |
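# A minimal CPU sketch (not part of the original file) of the two samplers above:
# stratified_bins spreads bins evenly between two depths with uniform jitter inside each bin,
# while normal_bins_sampling draws sorted Gaussian offsets around the measured depth, clipped
# to +/- delta.
def _demo_depth_samplers(n_rays=4, n_bins=8):
    z_strat = stratified_bins(0.1, 2.0, n_bins, n_rays, device="cpu")
    depth = torch.full((n_rays,), 1.5)
    z_surf = normal_bins_sampling(depth, n_bins, n_rays, delta=0.1, device="cpu")
    assert z_strat.shape == (n_rays, n_bins) and z_surf.shape == (n_rays, n_bins)
    assert (z_surf >= 1.4 - 1e-5).all() and (z_surf <= 1.6 + 1e-5).all()   # clipped to depth +/- delta
    return z_strat, z_surf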
89 |
90 | class sceneObject:
91 | """
92 |     Per-object instance mapping:
93 |     stores keyframes, provides training samples, and optimises the object's MLP map.
94 | """
95 |
96 | def __init__(self, cfg, obj_id, rgb:torch.tensor, depth:torch.tensor, mask:torch.tensor, bbox_2d:torch.tensor, t_wc:torch.tensor, live_frame_id) -> None:
97 | self.do_bg = cfg.do_bg
98 | self.obj_id = obj_id
99 | self.data_device = cfg.data_device
100 | self.training_device = cfg.training_device
101 |
102 | assert rgb.shape[:2] == depth.shape
103 | assert rgb.shape[:2] == mask.shape
104 | assert bbox_2d.shape == (4,)
105 | assert t_wc.shape == (4, 4,)
106 |
107 |         if self.do_bg and self.obj_id == 0: # do separate bg
108 | self.obj_scale = cfg.bg_scale
109 | self.hidden_feature_size = cfg.hidden_feature_size_bg
110 | self.n_bins_cam2surface = cfg.n_bins_cam2surface_bg
111 | self.keyframe_step = cfg.keyframe_step_bg
112 | else:
113 | self.obj_scale = cfg.obj_scale
114 | self.hidden_feature_size = cfg.hidden_feature_size
115 | self.n_bins_cam2surface = cfg.n_bins_cam2surface
116 | self.keyframe_step = cfg.keyframe_step
117 |
118 | self.frames_width = rgb.shape[0]
119 | self.frames_height = rgb.shape[1]
120 |
121 | self.min_bound = cfg.min_depth
122 | self.max_bound = cfg.max_depth
123 | self.n_bins = cfg.n_bins
124 | self.n_unidir_funcs = cfg.n_unidir_funcs
125 |
126 | self.surface_eps = cfg.surface_eps
127 | self.stop_eps = cfg.stop_eps
128 |
129 | self.n_keyframes = 1 # Number of keyframes
130 | self.kf_pointer = None
131 | self.keyframe_buffer_size = cfg.keyframe_buffer_size
132 | self.kf_id_dict = bidict({live_frame_id:0})
133 | self.kf_buffer_full = False
134 | self.frame_cnt = 0 # number of frames taken in
135 | self.lastest_kf_queue = []
136 |
137 |         self.bbox = torch.empty( # obj bounding box in the frame
138 | self.keyframe_buffer_size,
139 | 4,
140 | device=self.data_device) # [u low, u high, v low, v high]
141 | self.bbox[0] = bbox_2d
142 |
143 | # RGB + pixel state batch
144 | self.rgb_idx = slice(0, 3)
145 | self.state_idx = slice(3, 4)
146 | self.rgbs_batch = torch.empty(self.keyframe_buffer_size,
147 | self.frames_width,
148 | self.frames_height,
149 | 4,
150 | dtype=torch.uint8,
151 | device=self.data_device)
152 |
153 | # Pixel states:
154 | self.other_obj = 0 # pixel doesn't belong to obj
155 | self.this_obj = 1 # pixel belong to obj
156 | self.unknown_obj = 2 # pixel state is unknown
157 |
158 | # Initialize first frame rgb and pixel state
159 | self.rgbs_batch[0, :, :, self.rgb_idx] = rgb
160 | self.rgbs_batch[0, :, :, self.state_idx] = mask[..., None]
161 |
162 | self.depth_batch = torch.empty(self.keyframe_buffer_size,
163 | self.frames_width,
164 | self.frames_height,
165 | dtype=torch.float32,
166 | device=self.data_device)
167 |
168 | # Initialize first frame's depth
169 | self.depth_batch[0] = depth
170 | self.t_wc_batch = torch.empty(
171 | self.keyframe_buffer_size, 4, 4,
172 | dtype=torch.float32,
173 | device=self.data_device) # world to camera transform
174 |
175 | # Initialize first frame's world2cam transform
176 | self.t_wc_batch[0] = t_wc
177 |
178 | # neural field map
179 | trainer_cfg = copy.deepcopy(cfg)
180 | trainer_cfg.obj_id = self.obj_id
181 | trainer_cfg.hidden_feature_size = self.hidden_feature_size
182 | trainer_cfg.obj_scale = self.obj_scale
183 | self.trainer = trainer.Trainer(trainer_cfg)
184 |
185 | # 3D boundary
186 | self.bbox3d = None
187 | self.pc = []
188 |
189 | # init obj local frame
190 | # self.obj_center = self.init_obj_center(intrinsic, depth, mask, t_wc)
191 | self.obj_center = torch.tensor(0.0) # shouldn't make any difference because of frequency embedding
192 |
193 |
194 | def init_obj_center(self, intrinsic_open3d, depth, mask, t_wc):
195 | obj_depth = depth.cpu().clone()
196 | obj_depth[mask!=self.this_obj] = 0
197 | T_CW = np.linalg.inv(t_wc.cpu().numpy())
198 | pc_obj_init = open3d.geometry.PointCloud.create_from_depth_image(
199 | depth=open3d.geometry.Image(np.asarray(obj_depth.permute(1,0).numpy(), order="C")),
200 | intrinsic=intrinsic_open3d,
201 | extrinsic=T_CW,
202 | depth_trunc=self.max_bound,
203 | depth_scale=1.0)
204 | obj_center = torch.from_numpy(np.mean(pc_obj_init.points, axis=0)).float()
205 | return obj_center
206 |
207 | # @profile
208 | def append_keyframe(self, rgb:torch.tensor, depth:torch.tensor, mask:torch.tensor, bbox_2d:torch.tensor, t_wc:torch.tensor, frame_id:np.uint8=1):
209 | assert rgb.shape[:2] == depth.shape
210 | assert rgb.shape[:2] == mask.shape
211 | assert bbox_2d.shape == (4,)
212 | assert t_wc.shape == (4, 4,)
213 | assert self.n_keyframes <= self.keyframe_buffer_size - 1
214 | assert rgb.dtype == torch.uint8
215 | assert mask.dtype == torch.uint8
216 | assert depth.dtype == torch.float32
217 |
218 | # every kf_step choose one kf
219 | is_kf = (self.frame_cnt % self.keyframe_step == 0) or self.n_keyframes == 1
220 | # print("---------------------")
221 | # print("self.kf_id_dict ", self.kf_id_dict)
222 | # print("live frame id ", frame_id)
223 | # print("n_frames ", self.n_keyframes)
224 | if self.n_keyframes == self.keyframe_buffer_size - 1: # kf buffer full, need to prune
225 | self.kf_buffer_full = True
226 | if self.kf_pointer is None:
227 | self.kf_pointer = self.n_keyframes
228 |
229 | self.rgbs_batch[self.kf_pointer, :, :, self.rgb_idx] = rgb
230 | self.rgbs_batch[self.kf_pointer, :, :, self.state_idx] = mask[..., None]
231 | self.depth_batch[self.kf_pointer, ...] = depth
232 | self.t_wc_batch[self.kf_pointer, ...] = t_wc
233 | self.bbox[self.kf_pointer, ...] = bbox_2d
234 | self.kf_id_dict.inv[self.kf_pointer] = frame_id
235 |
236 | if is_kf:
237 | self.lastest_kf_queue.append(self.kf_pointer)
238 | pruned_frame_id, pruned_kf_id = self.prune_keyframe()
239 | self.kf_pointer = pruned_kf_id
240 | print("pruned kf id ", self.kf_pointer)
241 |
242 | else:
243 | if not is_kf: # not kf, replace
244 | self.rgbs_batch[self.n_keyframes-1, :, :, self.rgb_idx] = rgb
245 | self.rgbs_batch[self.n_keyframes-1, :, :, self.state_idx] = mask[..., None]
246 | self.depth_batch[self.n_keyframes-1, ...] = depth
247 | self.t_wc_batch[self.n_keyframes-1, ...] = t_wc
248 | self.bbox[self.n_keyframes-1, ...] = bbox_2d
249 | self.kf_id_dict.inv[self.n_keyframes-1] = frame_id
250 | else: # is kf, add new kf
251 | self.kf_id_dict[frame_id] = self.n_keyframes
252 | self.rgbs_batch[self.n_keyframes, :, :, self.rgb_idx] = rgb
253 | self.rgbs_batch[self.n_keyframes, :, :, self.state_idx] = mask[..., None]
254 | self.depth_batch[self.n_keyframes, ...] = depth
255 | self.t_wc_batch[self.n_keyframes, ...] = t_wc
256 | self.bbox[self.n_keyframes, ...] = bbox_2d
257 | self.lastest_kf_queue.append(self.n_keyframes)
258 | self.n_keyframes += 1
259 |
260 | # print("self.kf_id_dic ", self.kf_id_dict)
261 | self.frame_cnt += 1
262 | if len(self.lastest_kf_queue) > 2: # keep latest two frames
263 | self.lastest_kf_queue = self.lastest_kf_queue[-2:]
264 |
265 | def prune_keyframe(self):
266 | # simple strategy to prune, randomly choose
267 | key, value = random.choice(list(self.kf_id_dict.items())[:-2]) # do not prune latest two frames
268 | return key, value
269 |
270 | def get_bound(self, intrinsic_open3d):
271 | # get 3D boundary from posed depth img todo update sparse pc when append frame
272 | pcs = open3d.geometry.PointCloud()
273 | for kf_id in range(self.n_keyframes):
274 | mask = self.rgbs_batch[kf_id, :, :, self.state_idx].squeeze() == self.this_obj
275 | depth = self.depth_batch[kf_id].cpu().clone()
276 | twc = self.t_wc_batch[kf_id].cpu().numpy()
277 | depth[~mask] = 0
278 | depth = depth.permute(1,0).numpy().astype(np.float32)
279 | T_CW = np.linalg.inv(twc)
280 | pc = open3d.geometry.PointCloud.create_from_depth_image(depth=open3d.geometry.Image(np.asarray(depth, order="C")), intrinsic=intrinsic_open3d, extrinsic=T_CW)
281 | # self.pc += pc
282 | pcs += pc
283 |
284 | # # get minimal oriented 3d bbox
285 | # try:
286 | # bbox3d = open3d.geometry.OrientedBoundingBox.create_from_points(pcs.points)
287 | # except RuntimeError:
288 | # print("too few pcs obj ")
289 | # return None
290 | # trimesh has a better minimal bbox implementation than open3d
291 | try:
292 | transform, extents = trimesh.bounds.oriented_bounds(np.array(pcs.points)) # pc
293 | transform = np.linalg.inv(transform)
294 | except scipy.spatial._qhull.QhullError:
295 | print("too few pcs obj ")
296 | return None
297 |
298 | for i in range(extents.shape[0]):
299 | extents[i] = np.maximum(extents[i], 0.10) # at least rendering 10cm
300 | bbox = utils.BoundingBox()
301 | bbox.center = transform[:3, 3]
302 | bbox.R = transform[:3, :3]
303 | bbox.extent = extents
304 | bbox3d = open3d.geometry.OrientedBoundingBox(bbox.center, bbox.R, bbox.extent)
305 |
306 | min_extent = 0.05
307 | bbox3d.extent = np.maximum(min_extent, bbox3d.extent)
308 | bbox3d.color = (255,0,0)
309 | self.bbox3d = utils.bbox_open3d2bbox(bbox_o3d=bbox3d)
310 | # self.pc = []
311 | print("obj ", self.obj_id)
312 | print("bound ", bbox3d)
313 | print("kf id dict ", self.kf_id_dict)
314 | # open3d.visualization.draw_geometries([bbox3d, pcs])
315 | return bbox3d
316 |
317 |
318 |
319 | def get_training_samples(self, n_frames, n_samples, cached_rays_dir):
320 | # Sample pixels
321 | if self.n_keyframes > 2: # make sure latest 2 frames are sampled todo if kf pruned, this is not the latest frame
322 | keyframe_ids = torch.randint(low=0,
323 | high=self.n_keyframes,
324 | size=(n_frames - 2,),
325 | dtype=torch.long,
326 | device=self.data_device)
327 | # if self.kf_buffer_full:
328 | # latest_frame_ids = list(self.kf_id_dict.values())[-2:]
329 | latest_frame_ids = self.lastest_kf_queue[-2:]
330 | keyframe_ids = torch.cat([keyframe_ids,
331 | torch.tensor(latest_frame_ids, device=keyframe_ids.device)])
332 | # print("latest_frame_ids", latest_frame_ids)
333 | # else: # sample last 2 frames
334 | # keyframe_ids = torch.cat([keyframe_ids,
335 | # torch.tensor([self.n_keyframes-2, self.n_keyframes-1], device=keyframe_ids.device)])
336 | else:
337 | keyframe_ids = torch.randint(low=0,
338 | high=self.n_keyframes,
339 | size=(n_frames,),
340 | dtype=torch.long,
341 | device=self.data_device)
342 | keyframe_ids = torch.unsqueeze(keyframe_ids, dim=-1)
343 | idx_w = torch.rand(n_frames, n_samples, device=self.data_device)
344 | idx_h = torch.rand(n_frames, n_samples, device=self.data_device)
345 |
346 | # resizing idx_w and idx_h to be in the bbox range
347 | idx_w = idx_w * (self.bbox[keyframe_ids, 1] - self.bbox[keyframe_ids, 0]) + self.bbox[keyframe_ids, 0]
348 | idx_h = idx_h * (self.bbox[keyframe_ids, 3] - self.bbox[keyframe_ids, 2]) + self.bbox[keyframe_ids, 2]
349 |
350 | idx_w = idx_w.long()
351 | idx_h = idx_h.long()
352 |
353 | sampled_rgbs = self.rgbs_batch[keyframe_ids, idx_w, idx_h]
354 | sampled_depth = self.depth_batch[keyframe_ids, idx_w, idx_h]
355 |
356 | # Get ray directions for sampled pixels
357 | sampled_ray_dirs = cached_rays_dir[idx_w, idx_h]
358 |
359 | # Get sampled keyframe poses
360 | sampled_twc = self.t_wc_batch[keyframe_ids[:, 0], :, :]
361 |
362 | origins, dirs_w = origin_dirs_W(sampled_twc, sampled_ray_dirs)
363 |
364 | return self.sample_3d_points(sampled_rgbs, sampled_depth, origins, dirs_w)
365 |
366 | def sample_3d_points(self, sampled_rgbs, sampled_depth, origins, dirs_w):
367 | """
368 | 3D sampling strategy
369 |
370 | * For pixels with invalid depth:
371 | - N+M from minimum bound to max (stratified)
372 |
373 | * For pixels with valid depth:
374 | # Pixel belongs to this object
375 | - N from cam to surface (stratified)
376 | - M around surface (stratified/normal)
377 |         # Pixels that don't belong to this object
378 | - N from cam to surface (stratified)
379 | - M around surface (stratified)
380 | # Pixel with unknown state
381 | - Do nothing!
382 | """
383 |
384 | n_bins_cam2surface = self.n_bins_cam2surface
385 | n_bins = self.n_bins
386 | eps = self.surface_eps
387 | other_objs_max_eps = self.stop_eps #0.05 # todo 0.02
388 | # print("max depth ", torch.max(sampled_depth))
389 | sampled_z = torch.zeros(
390 | sampled_rgbs.shape[0] * sampled_rgbs.shape[1],
391 | n_bins_cam2surface + n_bins,
392 | dtype=self.depth_batch.dtype,
393 | device=self.data_device) # shape (N*n_rays, n_bins_cam2surface + n_bins)
394 |
395 | invalid_depth_mask = (sampled_depth <= self.min_bound).view(-1)
396 | # max_bound = self.max_bound
397 | max_bound = torch.max(sampled_depth)
398 | # sampling for points with invalid depth
399 | invalid_depth_count = invalid_depth_mask.count_nonzero()
400 | if invalid_depth_count:
401 | sampled_z[invalid_depth_mask, :] = stratified_bins(
402 | self.min_bound, max_bound,
403 | n_bins_cam2surface + n_bins, invalid_depth_count,
404 | device=self.data_device)
405 |
406 | # sampling for valid depth rays
407 | valid_depth_mask = ~invalid_depth_mask
408 | valid_depth_count = valid_depth_mask.count_nonzero()
409 |
410 |
411 | if valid_depth_count:
412 | # Sample between min bound and depth for all pixels with valid depth
413 | sampled_z[valid_depth_mask, :n_bins_cam2surface] = stratified_bins(
414 | self.min_bound, sampled_depth.view(-1)[valid_depth_mask]-eps,
415 | n_bins_cam2surface, valid_depth_count, device=self.data_device)
416 |
417 | # sampling around depth for this object
418 | obj_mask = (sampled_rgbs[..., -1] == self.this_obj).view(-1) & valid_depth_mask # todo obj_mask
419 | assert sampled_z.shape[0] == obj_mask.shape[0]
420 | obj_count = obj_mask.count_nonzero()
421 |
422 | if obj_count:
423 | sampling_method = "normal" # stratified or normal
424 | if sampling_method == "stratified":
425 | sampled_z[obj_mask, n_bins_cam2surface:] = stratified_bins(
426 | sampled_depth.view(-1)[obj_mask] - eps, sampled_depth.view(-1)[obj_mask] + eps,
427 | n_bins, obj_count, device=self.data_device)
428 |
429 | elif sampling_method == "normal":
430 | sampled_z[obj_mask, n_bins_cam2surface:] = normal_bins_sampling(
431 | sampled_depth.view(-1)[obj_mask],
432 | n_bins,
433 | obj_count,
434 | delta=eps,
435 | device=self.data_device)
436 |
437 | else:
438 |                 raise NotImplementedError(
439 |                     f"sampling method not implemented: {sampling_method}; "
440 |                     "only stratified and normal sampling are currently implemented."
441 |                 )
442 |
443 | # sampling around depth of other objects
444 | other_obj_mask = (sampled_rgbs[..., -1] != self.this_obj).view(-1) & valid_depth_mask
445 | other_objs_count = other_obj_mask.count_nonzero()
446 | if other_objs_count:
447 | sampled_z[other_obj_mask, n_bins_cam2surface:] = stratified_bins(
448 | sampled_depth.view(-1)[other_obj_mask] - eps,
449 | sampled_depth.view(-1)[other_obj_mask] + other_objs_max_eps,
450 | n_bins, other_objs_count, device=self.data_device)
451 |
452 | sampled_z = sampled_z.view(sampled_rgbs.shape[0],
453 | sampled_rgbs.shape[1],
454 |                                    -1)  # view as (n_frames, n_samples, n_bins_cam2surface + n_bins)
455 | input_pcs = origins[..., None, None, :] + (dirs_w[:, :, None, :] *
456 | sampled_z[..., None])
457 | input_pcs -= self.obj_center
458 | obj_labels = sampled_rgbs[..., -1].view(-1)
459 | return sampled_rgbs[..., :3], sampled_depth, valid_depth_mask, obj_labels, input_pcs, sampled_z
460 |
461 | def save_checkpoints(self, path, epoch):
462 | obj_id = self.obj_id
463 |         checkpoint_load_file = (path + "/obj_" + str(obj_id) + "_frame_" + str(epoch) + ".pth")
464 |
465 | torch.save(
466 | {
467 | "epoch": epoch,
468 | "FC_state_dict": self.trainer.fc_occ_map.state_dict(),
469 | "PE_state_dict": self.trainer.pe.state_dict(),
470 | "obj_id": self.obj_id,
471 | "bbox": self.bbox3d,
472 | "obj_scale": self.trainer.obj_scale
473 | },
474 |             checkpoint_load_file,
475 | )
476 | # optimiser?
477 |
478 | def load_checkpoints(self, ckpt_file):
479 | checkpoint_load_file = (ckpt_file)
480 | if not os.path.exists(checkpoint_load_file):
481 | print("ckpt not exist ", checkpoint_load_file)
482 | return
483 | checkpoint = torch.load(checkpoint_load_file)
484 | self.trainer.fc_occ_map.load_state_dict(checkpoint["FC_state_dict"])
485 | self.trainer.pe.load_state_dict(checkpoint["PE_state_dict"])
486 | self.obj_id = checkpoint["obj_id"]
487 | self.bbox3d = checkpoint["bbox"]
488 | self.trainer.obj_scale = checkpoint["obj_scale"]
489 |
490 | self.trainer.fc_occ_map.to(self.training_device)
491 | self.trainer.pe.to(self.training_device)
492 |
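# A minimal sketch (not part of the original file) of the oriented-bounding-box conversion
# used in sceneObject.get_bound above: trimesh.bounds.oriented_bounds returns the transform
# that moves the points into the box frame, so it is inverted to obtain the box pose in world
# coordinates before constructing the Open3D OrientedBoundingBox. The synthetic point cloud
# below is for illustration only.
def _demo_oriented_bbox():
    points = np.random.rand(500, 3) * np.array([2.0, 1.0, 0.5])
    to_box_frame, extents = trimesh.bounds.oriented_bounds(points)
    box_pose = np.linalg.inv(to_box_frame)                 # world-from-box transform
    bbox3d = open3d.geometry.OrientedBoundingBox(
        box_pose[:3, 3], box_pose[:3, :3], np.maximum(extents, 0.05))   # clamp tiny extents
    return bbox3d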
493 |
494 | class cameraInfo:
495 |
496 | def __init__(self, cfg) -> None:
497 | self.device = cfg.data_device
498 | self.width = cfg.W # Frame width
499 | self.height = cfg.H # Frame height
500 |
501 | self.fx = cfg.fx
502 | self.fy = cfg.fy
503 | self.cx = cfg.cx
504 | self.cy = cfg.cy
505 |
506 | self.rays_dir_cache = self.get_rays_dirs()
507 |
508 | def get_rays_dirs(self, depth_type="z"):
509 | idx_w = torch.arange(end=self.width, device=self.device)
510 | idx_h = torch.arange(end=self.height, device=self.device)
511 |
512 | dirs = torch.ones((self.width, self.height, 3), device=self.device)
513 |
514 | dirs[:, :, 0] = ((idx_w - self.cx) / self.fx)[:, None]
515 | dirs[:, :, 1] = ((idx_h - self.cy) / self.fy)
516 |
517 | if depth_type == "euclidean":
518 | raise Exception(
519 | "Get camera rays directions with euclidean depth not yet implemented"
520 | )
521 |             norm = torch.norm(dirs, dim=-1)  # only relevant for euclidean depth; unreachable until implemented
522 |             dirs = dirs * (1. / norm)[:, :, None]  # norm is (W, H); one added axis broadcasts over xyz
523 |
524 | return dirs
525 |
526 |
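# A minimal sketch (not part of the original file) tying cameraInfo and origin_dirs_W together:
# build the cached z-depth ray directions for a toy pinhole camera (the intrinsics below are
# illustrative), sample a few pixels, and rotate the directions into the world frame with an
# identity camera-to-world pose.
def _demo_camera_rays(n_samples=16):
    from types import SimpleNamespace
    cfg = SimpleNamespace(data_device="cpu", W=64, H=48, fx=60.0, fy=60.0, cx=32.0, cy=24.0)
    cam = cameraInfo(cfg)                                 # rays_dir_cache has shape (W, H, 3)
    idx_w = torch.randint(0, cfg.W, (1, n_samples))
    idx_h = torch.randint(0, cfg.H, (1, n_samples))
    dirs_C = cam.rays_dir_cache[idx_w, idx_h]             # (1, n_samples, 3); z component is 1
    T_WC = torch.eye(4)[None]                             # identity camera-to-world pose
    origins, dirs_W = origin_dirs_W(T_WC, dirs_C)
    return origins, dirs_W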
--------------------------------------------------------------------------------