├── README.md
├── assets
│   └── teaser.jpg
├── requirements.txt
└── scripts
    ├── build_graph_structure.py
    ├── edge_weights.py
    ├── node_feature.py
    ├── predict_masks.py
    ├── render_depth.py
    ├── run.py
    ├── sam_encoder_feature.py
    └── superpoint_projection.py
/README.md:
--------------------------------------------------------------------------------
1 | # SAM-guided Graph Cut for 3D Instance Segmentation
2 | ### [Project Page](https://zju3dv.github.io/sam_graph) | [Video](https://www.youtube.com/watch?v=daWiQiFPpZ0) | [Paper](https://arxiv.org/abs/2312.08372)
3 |
4 | > [SAM-guided Graph Cut for 3D Instance Segmentation](https://arxiv.org/abs/2312.08372)
5 | > [Haoyu Guo](https://github.com/ghy0324)\*, [He Zhu](https://github.com/Ada4321)\*, [Sida Peng](https://pengsida.net), [Yuang Wang](https://github.com/angshine), [Yujun Shen](https://shenyujun.github.io/), [Ruizhen Hu](https://csse.szu.edu.cn/staff/ruizhenhu/), [Xiaowei Zhou](https://xzhou.me)
6 |
7 |
8 | 
9 |
10 |
11 |
12 | ## TODO
13 |
14 | - [ ] Segmentation with / without GNN
15 | - [x] Graph construction and SAM based annotation
16 | - [ ] Processing of point clouds
17 | - [x] Processing of triangle meshes
18 |
19 |
20 | ## Setup
21 |
22 | The code is tested with Python 3.8 and PyTorch 1.12.0.
23 |
24 | 1. Clone the repository:
25 | ```
26 | git clone https://github.com/zju3dv/SAM_Graph.git
27 | ```
28 |
29 | 2. Install dependencies:
30 | ```
31 | pip install -r requirements.txt
32 | ```
33 | 3. Clone the [ScanNet](https://github.com/ScanNet/ScanNet) repository, build the [Segmentator](https://github.com/ScanNet/ScanNet/tree/master/Segmentator), and set `segmentor_path` in `scripts/run.py` to the resulting binary.
34 |
35 | 4. Download the [Segment Anything model checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and set `sam_ckpt_path` in `scripts/run.py`.
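Both paths are plain variables near the top of `scripts/run.py`; the snippet below shows them with the defaults that ship in this repository (replace them with your own local paths):

```python
# Excerpt from scripts/run.py -- edit these to match your setup
segmentor_path = "/home/guohaoyu/repos/ScanNet/Segmentator/segmentator"  # built ScanNet Segmentator binary
sam_ckpt_path = "../sam_vit_h_4b8939.pth"                                # downloaded SAM checkpoint
```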
36 |
37 | ## Data preparation
38 |
39 | For a quick test, download the example data from [here](https://drive.google.com/file/d/1cvrs9Hd6TOUza7OV2bd9XhHAxpTG6q7m/view?usp=drive_link) and set `data_path` in `scripts/run.py` accordingly. If you want to use your own data, organize it in the same format as the example data.
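For reference, `data_path` points at a single scene folder, and the scripts expect roughly the following layout (inferred from the paths used in `scripts/run.py`; the scene name below is just the example scene):

```
example_data/21d970d8de/
├── images/   # multi-view RGB images
├── sparse/   # COLMAP reconstruction (camera poses and intrinsics)
└── rgb.ply   # reconstructed mesh (or point cloud)
```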
40 |
41 | ## Run
42 |
43 | The pipeline of our method is illustrated as follows:
44 |
45 |
46 |
47 | ```mermaid
48 | graph TD
49 | subgraph Input
50 | A[multi-view images]
51 | B[sensor depth]
52 | C[camera pose]
53 | D[point cloud / mesh]
54 | end
55 |
56 | E[rendered depth]
57 | F[superpoints]
58 | G[sam encoder feature]
59 | H[depth difference]
60 | I[superpoint projections]
61 | J[graph structure]
62 | K[predicted masks]
63 | L[edge weights]
64 | M[node feature]
65 | N[graph segmentation]
66 |
67 | A --> G
68 | B --> H
69 | C --> E
70 | C --> I
71 | D --> E
72 | D --> F
73 | E --> H
74 | F --> I
75 | F --> J
76 | G --> K
77 | G --> M
78 | H --> I
79 | I --> K
80 | I --> M
81 | J --> L
82 | J --> N
83 | K --> L
84 | L --> N
85 | M --> N
86 | ```
87 |
88 | Note that the `depth difference` step is optional; it is recommended when accurate sensor depth is available and the point cloud / mesh contains large holes or missing regions.
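This step is not wired into `scripts/run.py` yet. As a rough sketch of the idea, the depth rendered from the mesh (`scripts/render_depth.py`) can be compared against the sensor depth to mask out pixels where the reconstruction does not explain the observation; the file names and the 5 cm tolerance below are illustrative assumptions, not values used by the released scripts:

```python
import numpy as np

# Hypothetical inputs for one view: depth rendered by scripts/render_depth.py
# and the corresponding sensor depth map, both in meters at the same resolution.
rendered = np.load('rendered_depth/0000.npz')['arr_0']
sensor = np.load('sensor_depth/0000.npy')

# Keep only pixels where both depths are valid and roughly agree; the rest would
# be treated as background when projecting superpoints into this view.
tolerance = 0.05  # illustrative 5 cm threshold
valid = (rendered > 0) & (sensor > 0) & (np.abs(rendered - sensor) < tolerance)
```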
89 |
90 | To run the pipeline, simply run:
91 |
92 | ```
93 | cd scripts
94 | python run.py
95 | ```
96 |
97 | The results of each step are saved in individual folders and files under `data_path`.
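With the default settings in `scripts/run.py`, the intermediate results are written next to the input data, roughly as follows:

```
<data_path>/
├── rgb.0.200000.segs.json   # superpoints from the ScanNet Segmentator
├── sam_encoder_feature/     # per-image SAM encoder features (sam_encoder_feature.py)
├── superpoint_projection/   # per-image superpoint projections (superpoint_projection.py)
├── sam_mask/                # per-image SAM prompts, masks and IoU predictions (predict_masks.py)
├── graph_structure.npz      # graph edges between nearby superpoints (build_graph_structure.py)
├── node_feature.npz         # per-superpoint SAM features (node_feature.py)
└── edge_weights.npz         # SAM-derived edge weights (edge_weights.py)
```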
98 |
99 | ## Citation
100 |
101 | ```bibtex
102 | @inproceedings{guo2024sam,
103 | title={SAM-guided Graph Cut for 3D Instance Segmentation},
104 | author={Guo, Haoyu and Zhu, He and Peng, Sida and Wang, Yuang and Shen, Yujun and Hu, Ruizhen and Zhou, Xiaowei},
105 | booktitle={ECCV},
106 | year={2024}
107 | }
108 | ```
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/SAM-Graph/8e58c46d93904525cfa94c16a4f7a6ce5aaea4d8/assets/teaser.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.0
2 | pytorch3d==0.7.4
3 | numpy==1.23.5
4 | trimesh==3.23.5
5 | scipy
6 | Pillow
7 | opencv-python
8 | segment_anything
9 | pycolmap
10 | pyrender
11 | tqdm
--------------------------------------------------------------------------------
/scripts/build_graph_structure.py:
--------------------------------------------------------------------------------
1 | import argparse, trimesh, numpy as np, json, os
2 | from tqdm import tqdm, trange
3 | from scipy.spatial import KDTree
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)")
7 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)")
8 | parser.add_argument("--graph_structure_path", type=str, help="Path to the output graph structure (.npz)")
9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
10 | args = parser.parse_args()
11 |
12 | if args.skip_if_exist and os.path.exists(args.graph_structure_path):
13 | print(f'{args.graph_structure_path} already exists, skip.')
14 | exit()
15 |
16 | m = trimesh.load(args.mesh_path)
17 | seg = json.load(open(args.superpoint_path))
18 | seg_indices = np.array(seg['segIndices'])
19 | sp_ids = np.unique(seg_indices)
20 | mapping = np.zeros((sp_ids.max() + 1, ), dtype=np.int32)
21 | for i, sp_id in enumerate(sp_ids):
22 | mapping[sp_id] = i
23 | seg_indices = mapping[seg_indices]
24 |
25 | pc_list = []
26 | for sp_id in np.unique(seg_indices):
27 | pc = m.vertices[seg_indices == sp_id]
28 | pc = np.asarray(pc)
29 | if pc.shape[0] > 50:
30 | pc = pc[np.random.choice(pc.shape[0], size=50, replace=False)]
31 | pc_list.append(pc)
32 |
33 | distances = {}
34 | for i in trange(len(pc_list) - 1, desc='Building graph structure'):
35 | tree_i = KDTree(pc_list[i])
36 | pc_concat = np.concatenate(pc_list[i + 1:])
37 | dist_concat = tree_i.query(pc_concat)[0]
38 | dist_split = np.split(dist_concat, np.cumsum([len(pc) for pc in pc_list[i+1:]]))[:-1]
39 | for j, dist in enumerate(dist_split):
40 | distances[(i, i + j + 1)] = dist.min()
41 | # for j in range(i + 1, len(pc_list)):
42 | # dist, _ = tree_i.query(pc_list[j])
43 | # dist = min(dist)
44 | # distances[(i, j)] = dist
45 |
46 | edges = []
47 | for (a, b), distance in distances.items():
48 | if distance < 0.3:
49 | edges.append((a, b))
50 |
51 | np.savez_compressed(args.graph_structure_path, np.asarray(edges))
52 |
--------------------------------------------------------------------------------
/scripts/edge_weights.py:
--------------------------------------------------------------------------------
1 | import argparse, numpy as np, os
2 | from tqdm import tqdm
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("--image_path", type=str, help="Path to the images")
6 | parser.add_argument("--graph_structure_path", type=str, help="Path to the graph structure (.npz)")
7 | parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks")
8 | parser.add_argument("--edge_weights_path", type=str, help="Path to the output edge weights (.npz)")
9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
10 | args = parser.parse_args()
11 |
12 | if args.skip_if_exist and os.path.exists(args.edge_weights_path):
13 | print(f'{args.edge_weights_path} already exists, skip.')
14 | exit()
15 |
16 | def select_mask(masks, iou_preds):
17 | if iou_preds[2] > iou_preds.max() - 0.05:
18 | return masks[2], iou_preds[2]
19 | elif iou_preds[1] > iou_preds[0] - 0.05:
20 | return masks[1], iou_preds[1]
21 | else:
22 | return masks[0], iou_preds[0]
23 |
24 | def weighted_average(data):
25 | score_sum = 0
26 | distance_sum = 0
27 | for k, (score, score1, score2, iou_pred1, iou_pred2, distance) in data.items():
28 | score_sum += max(score1, score2) * distance * iou_pred1 * iou_pred2
29 | distance_sum += distance * iou_pred1 * iou_pred2
30 | return score_sum / distance_sum
31 |
32 | edges = np.load(args.graph_structure_path)['arr_0']
33 |
34 | image_list = os.listdir(args.image_path)
35 |
36 | edge_weights = dict()
37 |
38 | for image_f in tqdm(image_list, desc='Calculating edge weights'):
39 | image_id = image_f.split('.')[0]
40 | mask_data = np.load(f'{args.sam_mask_path}/{image_id}.npz', allow_pickle=True)['arr_0'].tolist()
41 | points_per_instance, masks, iou_preds = mask_data['points_per_instance'], mask_data['masks'], mask_data['iou_preds']
42 |
43 | for p1, p2 in edges:
44 | if p1 not in points_per_instance.keys():
45 | continue
46 | if p2 not in points_per_instance.keys():
47 | continue
48 | if (p1, p2) not in edge_weights:
49 | edge_weights[(p1, p2)] = dict()
50 | assert (p2, p1) not in edge_weights
51 | edge_weights[(p2, p1)] = dict()
52 | xy1 = points_per_instance[p1]
53 | xy2 = points_per_instance[p2]
54 |
55 | distance2d = ((np.array(xy1).mean(0) - np.array(xy2).mean(0)) ** 2).sum() ** 0.5
56 | p1_id, p2_id = list(points_per_instance.keys()).index(p1), list(points_per_instance.keys()).index(p2)
57 |
58 | mask1, mask2 = masks[p1_id], masks[p2_id]
59 | iou_pred1, iou_pred2 = iou_preds[p1_id], iou_preds[p2_id]
60 |
61 | mask1, iou_pred1 = select_mask(mask1, iou_pred1)
62 | mask2, iou_pred2 = select_mask(mask2, iou_pred2)
63 |
64 | iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
65 | ioa = (mask1 & mask2).sum() / mask1.sum()
66 | iob = (mask1 & mask2).sum() / mask2.sum()
67 |
68 | edge_weights[(p1, p2)][image_id] = [iou, ioa, iob, iou_pred1, iou_pred2, distance2d]
69 | edge_weights[(p2, p1)][image_id] = [iou, iob, ioa, iou_pred2, iou_pred1, distance2d]
70 |
71 | for k, v in edge_weights.items():
72 | edge_weights[k] = weighted_average(v)
73 |
74 | np.savez(args.edge_weights_path, edge_weights)
75 |
--------------------------------------------------------------------------------
/scripts/node_feature.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, numpy as np, os
2 | from tqdm import tqdm
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("--image_path", type=str, help="Path to the images")
6 | parser.add_argument("--feature_path", type=str, help="Path to the SAM encoder features")
7 | parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks")
8 | parser.add_argument("--node_feature_path", type=str, help="Path to the output node features")
9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
10 | args = parser.parse_args()
11 |
12 | image_list = os.listdir(args.image_path)
13 |
14 | h, w = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2]
15 | # Note: here we simply assume that all images are of the same resolution, which can be modified if needed.
16 |
17 | feature_h, feature_w = 64, 64 # resolution of SAM encoder feature maps
18 | if h < w:
19 | feature_h = int(feature_w * h / w)
20 | else:
21 | feature_w = int(feature_h * w / h)
22 |
23 | feature_dict = dict()
24 |
25 | if args.skip_if_exist and os.path.exists(args.node_feature_path):
26 | print(f'{args.node_feature_path} already exists, skip.')
27 | exit()
28 |
29 | for image_f in tqdm(image_list, desc='Calculating node features'):
30 | image_id = image_f.split('.')[0]
31 | mask_data = np.load(f'{args.sam_mask_path}/{image_id}.npz', allow_pickle=True)['arr_0'].tolist()
32 | points_per_instance, _, _ = mask_data['points_per_instance'], mask_data['masks'], mask_data['iou_preds']
33 | feature_map = np.load(f'{args.feature_path}/{image_id}.npy')
34 | for sp_id, pts in points_per_instance.items():
35 | xy = np.array(pts).astype(np.float32).T
36 | xy[0] = xy[0] / w * feature_w
37 | xy[1] = xy[1] / h * feature_h
38 | xy1 = xy.astype(np.int32)
39 | xy2 = xy1 + 1
40 | xy2[0] = xy2[0].clip(max=feature_w - 1)
41 | xy2[1] = xy2[1].clip(max=feature_h - 1)
42 | feature = feature_map[xy1[1], xy1[0]] * (xy2[0] - xy[0])[:, None] * (xy2[1] - xy[1])[:, None] + \
43 | feature_map[xy2[1], xy1[0]] * (xy2[0] - xy[0])[:, None] * (xy[1] - xy1[1])[:, None] + \
44 | feature_map[xy1[1], xy2[0]] * (xy[0] - xy1[0])[:, None] * (xy2[1] - xy[1])[:, None] + \
45 | feature_map[xy2[1], xy2[0]] * (xy[0] - xy1[0])[:, None] * (xy[1] - xy1[1])[:, None]
46 | feature = feature.mean(0)
47 | if sp_id not in feature_dict:
48 | feature_dict[sp_id] = [feature]
49 | else:
50 | feature_dict[sp_id].append(feature)
51 |
52 | feature_mean_dict = dict()
53 | for k, v in feature_dict.items():
54 | feature_mean_dict[k] = sum(v) / len(v)
55 |
56 | np.savez_compressed(args.node_feature_path, feature_mean_dict)
57 |
--------------------------------------------------------------------------------
/scripts/predict_masks.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, numpy as np, os, json, torch
2 | np.set_printoptions(suppress=True)
3 | from PIL import Image
4 | from tqdm import tqdm
5 | from segment_anything import build_sam, SamPredictor
6 |
7 |
8 | def remove_small_masks_from_segmentation(segmentation_mask, min_size=100):
9 | unique_ids = np.unique(segmentation_mask)
10 | for instance_id in unique_ids:
11 | if instance_id == -1:
12 | continue
13 |
14 | instance_mask = (segmentation_mask == instance_id).astype(np.uint8)
15 |
16 | num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(instance_mask, connectivity=8)
17 |
18 | for i in range(1, num_labels):
19 | if stats[i, cv2.CC_STAT_AREA] < min_size:
20 | segmentation_mask[labels == i] = -1
21 | return segmentation_mask
22 |
23 |
24 | def sample_points_from_mask(mask, num_points=10):
25 | h, w = mask.shape
26 |
27 | extended_mask = np.zeros((h+2, w+2), dtype=np.uint8)
28 | extended_mask[1:-1, 1:-1] = mask
29 |
30 | distance_transform = cv2.distanceTransform(extended_mask, cv2.DIST_L2, 5)
31 |
32 | distance_transform = distance_transform[1:-1, 1:-1]
33 | sampled_points = []
34 |
35 | for _ in range(num_points):
36 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(distance_transform)
37 | sampled_points.append(max_loc)
38 |
39 | cv2.circle(distance_transform, max_loc, int(max_val), 0, -1)
40 |
41 | return sampled_points
42 |
43 |
44 | def sample_points_for_each_instance(segmentation_mask, num_points=10):
45 | ret = dict()
46 | unique_ids = np.unique(segmentation_mask)
47 |
48 | for instance_id in unique_ids:
49 | if instance_id == -1:
50 | continue
51 |
52 | instance_mask = (segmentation_mask == instance_id).astype(np.uint8) * 255
53 |
54 | points = sample_points_from_mask(instance_mask, num_points=num_points)
55 | ret[instance_id] = points
56 |
57 | return ret
58 |
59 |
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--image_path", type=str, help="Path to the images")
62 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images")
63 | parser.add_argument("--sam_ckpt_path", type=str, help="Path to the SAM model parameters")
64 | parser.add_argument("--feature_path", type=str, help="Path to the SAM encoder features")
65 | parser.add_argument("--superpoint_projection_path", type=str, help="Path to the superpoint projection masks")
66 | # parser.add_argument("--depth_difference_path", type=str, help="Path to the depth difference")
67 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)")
68 | parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks")
69 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
70 | args = parser.parse_args()
71 |
72 | sam = None
73 |
74 | os.makedirs(args.sam_mask_path, exist_ok=True)
75 |
76 | image_list = os.listdir(args.image_path)
77 |
78 | height_original, width_original = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2]
79 | w = int(args.image_width)
80 | h = int(w * height_original / width_original)
81 | # Note: here we simply assume that all images are of the same resolution, which can be modified if needed.
82 |
83 | for image_f in tqdm(image_list, desc='Predicting masks'):
84 | image_id = image_f.split('.')[0]
85 |
86 | if args.skip_if_exist and os.path.exists(f'{args.sam_mask_path}/{image_id}.npz'):
87 | continue
88 |
89 | if sam is None:
90 | sam = build_sam(checkpoint=args.sam_ckpt_path)
91 | sam.to(device='cuda')
92 | sam_predictor = SamPredictor(sam)
93 |
94 | seg = json.load(open(args.superpoint_path))
95 | seg_indices = np.array(seg['segIndices'])
96 | sp_ids = np.unique(seg_indices)
97 | mapping = np.zeros((sp_ids.max() + 1, ), dtype=np.int32)
98 | for i, sp_id in enumerate(sp_ids):
99 | mapping[sp_id] = i
100 | seg_indices = mapping[seg_indices]
101 |
102 | if not sam_predictor.is_image_set:
103 | img = cv2.imread(f'{args.image_path}/{image_f}')
104 | img = cv2.resize(img, (w, h))
105 | sam_predictor.set_image(img)
106 | else:
107 | sam_predictor.features = torch.from_numpy(np.load(f'{args.feature_path}/{image_id}.npy')).permute(2, 0, 1)[None].cuda()
108 |
109 | view_overseg = Image.open(f'{args.superpoint_projection_path}/{image_id}.png')
110 | view_overseg = np.asarray(view_overseg).astype(np.uint32)
111 | view_overseg = (view_overseg[..., 0] << 16) | (view_overseg[..., 1] << 8) | view_overseg[..., 2]
112 | bg = (view_overseg == 256 ** 3 - 1)
113 | # mask = Image.open(f'{args.rendered_depth_path}/{image_id}.png')
114 | # mask = np.array(mask) > 127
115 | # bg = bg | (~mask)
116 | view_overseg[bg] = 0
117 | view_overseg = mapping[view_overseg]
118 | view_overseg[bg] = -1
119 | view_overseg = remove_small_masks_from_segmentation(view_overseg)
120 |
121 | points_per_instance = sample_points_for_each_instance(view_overseg, num_points=5)
122 | points = np.array([list(_) for _ in points_per_instance.values()])
123 |
124 | transformed_points = sam_predictor.transform.apply_coords(points, (h, w))
125 | in_points = torch.as_tensor(transformed_points, device=sam_predictor.device)
126 | in_labels = torch.ones((in_points.shape[0], in_points.shape[1]), dtype=torch.int, device=in_points.device)
127 |
128 | masks_logits, iou_preds, _ = sam_predictor.predict_torch(
129 | in_points,
130 | in_labels,
131 | multimask_output=True,
132 | return_logits=True,
133 | )
134 |
135 | masks_logits = masks_logits.cpu().numpy()
136 | masks = masks_logits > 0
137 | iou_preds = iou_preds.cpu().numpy()
138 |
139 | data = {'points_per_instance': points_per_instance, 'masks': masks, 'iou_preds': iou_preds}
140 | np.savez_compressed(f'{args.sam_mask_path}/{image_id}.npz', data)
141 |
--------------------------------------------------------------------------------
/scripts/render_depth.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, trimesh, numpy as np, os, pyrender
2 | from tqdm import tqdm
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("--image_path", type=str, help="Path to the images")
6 | parser.add_argument("--pose_path", type=str, help="Path to the camera poses (camera to world)")
7 | parser.add_argument("--intrinsic_path", type=str, help="Path to the intrinsic matrix (.txt)")
8 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)")
9 | parser.add_argument("--rendered_depth_path", type=str, help="Path to the rendered depth maps")
10 | parser.add_argument("--rendered_depth_vis_path", type=str, default='', help="Path to the visualization of rendered depth maps (optional)")
11 | args = parser.parse_args()
12 |
13 | class Renderer():
14 | def __init__(self, height=1440, width=1440):
15 | self.renderer = pyrender.OffscreenRenderer(width, height)
16 | self.scene = pyrender.Scene()
17 | # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES
18 |
19 | def __call__(self, height, width, intrinsics, pose, mesh):
20 | self.renderer.viewport_height = height
21 | self.renderer.viewport_width = width
22 | self.scene.clear()
23 | self.scene.add(mesh)
24 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
25 | fx=intrinsics[0, 0], fy=intrinsics[1, 1])
26 | self.scene.add(cam, pose=self.fix_pose(pose))
27 | return self.renderer.render(self.scene) # , self.render_flags)
28 |
29 | def fix_pose(self, pose):
30 | # 3D Rotation about the x-axis.
31 | t = np.pi
32 | c = np.cos(t)
33 | s = np.sin(t)
34 | R = np.array([[1, 0, 0],
35 | [0, c, -s],
36 | [0, s, c]])
37 | axis_transform = np.eye(4)
38 | axis_transform[:3, :3] = R
39 | return pose @ axis_transform
40 |
41 | def mesh_opengl(self, mesh):
42 | material = pyrender.MetallicRoughnessMaterial(
43 | baseColorFactor=[1, 1, 1, 1.],
44 | metallicFactor=0.0,
45 | roughnessFactor=0.0,
46 | smooth=False,
47 | alphaMode='OPAQUE')
48 | return pyrender.Mesh.from_trimesh(mesh, material=material)
49 |
50 | def delete(self):
51 | self.renderer.delete()
52 |
53 | image_list = os.listdir(args.image_path)
54 |
55 | h, w = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2]
56 | # Note: here we simply assume that all images are of the same resolution
57 |
58 |
59 |
60 | ixt = np.loadtxt(args.intrinsic_path)
61 |
62 | c2ws = dict()
63 | for image_f in image_list:
64 | image_id = image_f.split('.')[0]
65 | c2w = np.loadtxt(f'{args.pose_path}/{image_id}.txt')
66 | c2ws[image_id] = c2w
67 |
68 | m = trimesh.load(args.mesh_path)
69 |
70 | renderer = Renderer()
71 | mesh_opengl = renderer.mesh_opengl(m)
72 |
73 | os.makedirs(args.rendered_depth_path, exist_ok=True)
74 | if args.rendered_depth_vis_path != '':
75 | os.makedirs(args.rendered_depth_vis_path, exist_ok=True)
76 |
77 | for k, c2w in tqdm(c2ws.items()):
78 | _, depth = renderer(h, w, ixt, c2w, mesh_opengl)
79 | np.savez_compressed(f'{args.rendered_depth_path}/{k}.npz', depth)
80 |
81 | if args.rendered_depth_vis_path != '':
82 | depth = (depth - depth.min()) / (depth.max() - depth.min())
83 | depth = (depth * 255).astype(np.uint8)
84 | depth_vis = cv2.applyColorMap(depth, cv2.COLORMAP_JET)
85 | cv2.imwrite(f'{args.rendered_depth_vis_path}/{k}.jpg', depth_vis)
86 |
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 |
4 | # Input paths
5 | data_path = "../example_data/21d970d8de"
6 | image_path = os.path.join(data_path, "images/")
7 | colmap_path = os.path.join(data_path, "sparse/")
8 | mesh_path = os.path.join(data_path, "rgb.ply")
9 |
10 | # Dependencies
11 | segmentor_path = "/home/guohaoyu/repos/ScanNet/Segmentator/segmentator"
12 | sam_ckpt_path = "../sam_vit_h_4b8939.pth"
13 |
14 | # Output paths
15 | superpoint_path = os.path.join(data_path, "rgb.0.200000.segs.json")
16 | feature_path = os.path.join(data_path, "sam_encoder_feature/")
17 | superpoint_projection_path = os.path.join(data_path, "superpoint_projection/")
18 | sam_mask_path = os.path.join(data_path, "sam_mask/")
19 | graph_structure_path = os.path.join(data_path, "graph_structure.npz")
20 | node_feature_path = os.path.join(data_path, "node_feature.npz")
21 | edge_weights_path = os.path.join(data_path, "edge_weights.npz")
22 | graph_segmentation_path = os.path.join(data_path, "graph_segmentation.npy")
23 | graph_segmentation_visualize_path = os.path.join(data_path, "graph_segmentation_visualize.ply")
24 |
25 | # Options
26 | image_width = 640 # resize images to this width
27 | skip_if_exist = 1
28 | skip_arg = "--skip_if_exist" if skip_if_exist else ""
29 |
30 | # Helper function to run shell commands
31 | def run_command(command):
32 | print(f"Running: {command}")
33 | subprocess.run(command, shell=True, check=True)
34 |
35 | # Commands
36 | run_command(
37 | "python sam_encoder_feature.py "
38 | f"--image_path {image_path} "
39 | f"--image_width {image_width} "
40 | f"--feature_path {feature_path} "
41 | f"--sam_ckpt_path {sam_ckpt_path} "
42 | f"{skip_arg}"
43 | )
44 |
45 | if os.path.exists(superpoint_path) and skip_if_exist:
46 | print("Superpoint file already exists, skipping segmentor step")
47 | else:
48 | run_command(f"{segmentor_path} {mesh_path} 0.2 50")
49 |
50 | run_command(
51 | "python superpoint_projection.py "
52 | f"--image_path {image_path} "
53 | f"--image_width {image_width} "
54 | f"--colmap_path {colmap_path} "
55 | f"--mesh_path {mesh_path} "
56 | f"--superpoint_path {superpoint_path} "
57 | f"--superpoint_projection_path {superpoint_projection_path} "
58 | f"{skip_arg}"
59 | )
60 |
61 | run_command(
62 | "python predict_masks.py "
63 | f"--image_path {image_path} "
64 | f"--image_width {image_width} "
65 | f"--feature_path {feature_path} "
66 | f"--sam_ckpt_path {sam_ckpt_path} "
67 | f"--superpoint_projection_path {superpoint_projection_path} "
68 | f"--superpoint_path {superpoint_path} "
69 | f"--sam_mask_path {sam_mask_path} "
70 | f"{skip_arg}"
71 | )
72 |
73 | run_command(
74 | "python build_graph_structure.py "
75 | f"--mesh_path {mesh_path} "
76 | f"--superpoint_path {superpoint_path} "
77 | f"--graph_structure_path {graph_structure_path} "
78 | f"{skip_arg}"
79 | )
80 |
81 | run_command(
82 | "python node_feature.py "
83 | f"--image_path {image_path} "
84 | f"--feature_path {feature_path} "
85 | f"--sam_mask_path {sam_mask_path} "
86 | f"--node_feature_path {node_feature_path} "
87 | f"{skip_arg}"
88 | )
89 |
90 | run_command(
91 | "python edge_weights.py "
92 | f"--image_path {image_path} "
93 | f"--graph_structure_path {graph_structure_path} "
94 | f"--sam_mask_path {sam_mask_path} "
95 | f"--edge_weights_path {edge_weights_path} "
96 | f"{skip_arg}"
97 | )
98 |
99 | # TODO
100 | # run_command(
101 | # "python graph_segmentation.py "
102 | # f"--image_path {image_path} "
103 | # f"--mesh_path {mesh_path} "
104 | # f"--superpoint_path {superpoint_path} "
105 | # f"--graph_structure_path {graph_structure_path} "
106 | # f"--sam_mask_path {sam_mask_path} "
107 | # f"--edge_weights_path {edge_weights_path} "
108 | # f"--graph_segmentation_path {graph_segmentation_path} "
109 | # f"--graph_segmentation_visualize_path {graph_segmentation_visualize_path} "
110 | # f"{skip_arg}"
111 | # )
112 |
--------------------------------------------------------------------------------
/scripts/sam_encoder_feature.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, numpy as np, os
2 | from tqdm import tqdm
3 | from segment_anything import build_sam, SamPredictor
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--image_path", type=str, help="Path to the images")
7 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images")
8 | parser.add_argument("--feature_path", type=str, help="Path to the output SAM encoder features")
9 | parser.add_argument("--sam_ckpt_path", type=str, help="Path to the SAM model parameters")
10 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
11 | args = parser.parse_args()
12 |
13 | sam = None
14 |
15 | # sam = build_sam(checkpoint=args.sam_ckpt_path)
16 | # sam.to(device='cuda')
17 | # sam_predictor = SamPredictor(sam)
18 |
19 | image_list = os.listdir(args.image_path)
20 |
21 | os.makedirs(args.feature_path, exist_ok=True)
22 |
23 | for image_f in tqdm(image_list, desc='Extracting SAM encoder features'):
24 | image_id = image_f.split('.')[0]
25 | if os.path.exists(f'{args.feature_path}/{image_id}.npy') and args.skip_if_exist:
26 | continue
27 | if sam is None:
28 | sam = build_sam(checkpoint=args.sam_ckpt_path)
29 | sam.to(device='cuda')
30 | sam_predictor = SamPredictor(sam)
31 | img = cv2.imread(f'{args.image_path}/{image_f}')
32 | height_original, width_original = img.shape[:2]
33 | w = int(args.image_width)
34 | h = int(w * height_original / width_original)
35 | img = cv2.resize(img, (w, h))
36 | sam_predictor.set_image(img)
37 | feature_map = sam_predictor.features[0].permute(1, 2, 0).cpu().numpy()
38 | np.save(f'{args.feature_path}/{image_id}.npy', feature_map)
39 |
--------------------------------------------------------------------------------
/scripts/superpoint_projection.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, trimesh, numpy as np, os, json, torch, pycolmap
2 | from tqdm import tqdm
3 | from pytorch3d.renderer import RasterizationSettings, MeshRasterizer, FoVPerspectiveCameras
4 | from pytorch3d.structures import Meshes
5 | from PIL import Image
6 |
7 | def getProjectionMatrixK(znear, zfar, K, H, W):
8 | P = torch.zeros(4, 4)
9 | z_sign = 1.0
10 |
11 | fx = K[..., 0, 0]
12 | fy = K[..., 1, 1]
13 | cx = K[..., 0, 2]
14 | cy = K[..., 1, 2]
15 | s = K[..., 0, 1]
16 |
17 | P[0, 0] = 2 * fx / H
18 | P[1, 1] = 2 * fy / H
19 | P[0, 1] = 2 * s / W
20 |
21 | P[2, 2] = z_sign * (zfar + znear) / (zfar - znear)
22 | P[2, 3] = -(zfar * znear) / (zfar - znear)
23 | P[3, 2] = z_sign
24 |
25 | P[0, 2] = 1 - 2 * (cx / W)
26 | P[1, 2] = 1 - 2 * (cy / H)
27 |
28 | return P
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--image_path", type=str, help="Path to the images")
32 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images")
33 | parser.add_argument("--colmap_path", type=str, help="Path to the colmap SfM results (camera poses and intrinsics)")
34 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)")
35 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)")
36 | parser.add_argument("--superpoint_projection_path", type=str, help="Path to the output superpoint projection masks")
37 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists")
38 | args = parser.parse_args()
39 |
40 | image_list = os.listdir(args.image_path)
41 |
42 | m = None
43 |
44 | os.makedirs(args.superpoint_projection_path, exist_ok=True)
45 |
46 | sfm_results = pycolmap.Reconstruction(args.colmap_path)
47 |
48 | for image in tqdm(sfm_results.images.values(), desc='Superpoint projection'):
49 | w2c = np.eye(4)
50 | w2c[:3] = image.cam_from_world.matrix()
51 | c2w = np.linalg.inv(w2c)
52 | fx, fy, cx, cy = image.camera.params[:4]
53 | ixt = np.array([[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
54 | height_original, width_original = image.camera.height, image.camera.width
55 | w = int(args.image_width)
56 | h = int(args.image_width * height_original / width_original)
57 | ixt[:2] *= w / width_original
58 |
59 | if os.path.exists(f'{args.superpoint_projection_path}/{image.name.split(".")[0]}.png') and args.skip_if_exist:
60 | continue
61 |
62 | if m is None:
63 | m = trimesh.load(args.mesh_path)
64 | seg = json.load(open(args.superpoint_path))
65 | seg_indices = np.array(seg['segIndices'])
66 |
67 | mesh = Meshes(verts=[torch.from_numpy(m.vertices).float()], faces=[torch.from_numpy(m.faces)]).cuda()
68 | faces = torch.tensor(m.faces)
69 |
70 | num_seg = seg_indices.max() + 1
71 | random_colors = torch.zeros((num_seg, 3), dtype=torch.uint8)
72 | for i in np.unique(seg_indices):
73 | random_colors[i][0] = (i >> 16) & 255
74 | random_colors[i][1] = (i >> 8) & 255
75 | random_colors[i][2] = i & 255
76 |
77 | K = getProjectionMatrixK(0.0001, 100, ixt, h, w).numpy()
78 |
79 | R, t = c2w[:3, :3].copy(), c2w[:3, 3].copy()
80 | R[:, 0] = -R[:, 0]
81 | R[:, 1] = -R[:, 1]
82 | c2w[:3, :3] = R
83 | c2w[:3, 3] = t
84 | w2c = np.linalg.inv(c2w.copy())
85 | R, t = w2c[:3, :3].copy(), w2c[:3, 3].copy()
86 | R = R.T
87 |
88 | cameras = FoVPerspectiveCameras(device=torch.device('cuda'), R=R[None].astype(np.float32), T=t[None].astype(np.float32), K=K[None].astype(np.float32))
89 | raster_settings = RasterizationSettings(image_size=(h, w))
90 | rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
91 | fragments = rasterizer(mesh)
92 |
93 | face_ids_per_pixel = fragments.pix_to_face[..., 0].cpu()
94 |
95 | vertice_id = faces[face_ids_per_pixel[0]]
96 | seg_id = torch.from_numpy(seg_indices)[vertice_id]
97 | consistent_mask = (seg_id[..., 0] == seg_id[..., 1]) & (seg_id[..., 1] == seg_id[..., 2])
98 | edge_mask = ~consistent_mask
99 | image_colors = random_colors[seg_id[..., 0]]
100 | background_color = torch.tensor([255, 255, 255], dtype=torch.uint8)
101 | mask = (face_ids_per_pixel[0] == -1) | edge_mask
102 | image_colors[mask] = background_color
103 |
104 | overseg_vis = image_colors.cpu().numpy()
105 |
106 | Image.fromarray(overseg_vis).save(f'{args.superpoint_projection_path}/{image.name.split(".")[0]}.png')
107 |
--------------------------------------------------------------------------------