├── README.md
├── assets
│   └── teaser.jpg
├── requirements.txt
└── scripts
    ├── build_graph_structure.py
    ├── edge_weights.py
    ├── node_feature.py
    ├── predict_masks.py
    ├── render_depth.py
    ├── run.py
    ├── sam_encoder_feature.py
    └── superpoint_projection.py
/README.md:
--------------------------------------------------------------------------------
1 | # SAM-guided Graph Cut for 3D Instance Segmentation
2 | ### [Project Page](https://zju3dv.github.io/sam_graph) | [Video](https://www.youtube.com/watch?v=daWiQiFPpZ0) | [Paper](https://arxiv.org/abs/2312.08372)
3 | 
4 | > [SAM-guided Graph Cut for 3D Instance Segmentation](https://arxiv.org/abs/2312.08372)
5 | > [Haoyu Guo](https://github.com/ghy0324)\*, [He Zhu](https://github.com/Ada4321)\*, [Sida Peng](https://pengsida.net), [Yuang Wang](https://github.com/angshine), [Yujun Shen](https://shenyujun.github.io/), [Ruizhen Hu](https://csse.szu.edu.cn/staff/ruizhenhu/), [Xiaowei Zhou](https://xzhou.me)
6 | 
7 | 
8 | ![introduction](./assets/teaser.jpg)
9 | 
10 | 
11 | 
12 | ## TODO
13 | 
14 | - [ ] Segmentation with / without GNN
15 | - [x] Graph construction and SAM based annotation
16 | - [ ] Processing of point clouds
17 | - [x] Processing of triangle meshes
18 | 
19 | 
20 | ## Setup
21 | 
22 | The code is tested with Python 3.8 and PyTorch 1.12.0.
23 | 
24 | 1. Clone the repository:
25 | ```
26 | git clone https://github.com/zju3dv/SAM_Graph.git
27 | ```
28 | 
29 | 2. Install dependencies:
30 | ```
31 | pip install -r requirements.txt
32 | ```
33 | Clone the ScanNet repository, build the [segmentor](https://github.com/ScanNet/ScanNet/tree/master/Segmentator), and modify `segmentor_path` in `scripts/run.py`.
34 | 
35 | Download the [checkpoint of the segment-anything model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and modify `sam_ckpt_path` in `scripts/run.py`.
36 | 
37 | ## Data preparation
38 | 
39 | Please download the example data from [here](https://drive.google.com/file/d/1cvrs9Hd6TOUza7OV2bd9XhHAxpTG6q7m/view?usp=drive_link) and modify `data_path` in `scripts/run.py` for fast testing. If you want to use your own data, please organize it in the same format as the example data.
40 | 
41 | ## Run
42 | 
43 | The pipeline of our method is illustrated as follows:
44 | 
45 | 
46 | 
47 | ```mermaid
48 | graph TD
49 | subgraph Input
50 | A[multi-view images]
51 | B[sensor depth]
52 | C[camera pose]
53 | D[point cloud / mesh]
54 | end
55 | 
56 | E[rendered depth]
57 | F[superpoints]
58 | G[sam encoder feature]
59 | H[depth difference]
60 | I[superpoint projections]
61 | J[graph structure]
62 | K[predicted masks]
63 | L[edge weights]
64 | M[node feature]
65 | N[graph segmentation]
66 | 
67 | A --> G
68 | B --> H
69 | C --> E
70 | C --> I
71 | D --> E
72 | D --> F
73 | E --> H
74 | F --> I
75 | F --> J
76 | G --> K
77 | G --> M
78 | H --> I
79 | I --> K
80 | I --> M
81 | J --> L
82 | J --> N
83 | K --> L
84 | L --> N
85 | M --> N
86 | ```
87 | 
88 | Note that the `depth difference` step is optional, but it is recommended if accurate sensor depth is available and the point cloud / mesh contains large holes or missing regions.
89 | 
90 | To run the pipeline, simply run:
91 | 
92 | ```
93 | cd scripts
94 | python run.py
95 | ```
96 | 
97 | The results of each step will be saved in individual folders.
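The intermediate outputs can be inspected directly with NumPy. The snippet below is only a minimal sketch, not the GNN-based graph cut described in the paper (the segmentation step is not released yet, see the TODO list): it loads the graph structure, edge weights, and node features produced above, and merges superpoints whose edge weight exceeds a hand-picked threshold with union-find, as a crude stand-in for the `graph segmentation` step. The `data_path` value and the `0.9` threshold are illustrative assumptions.

```python
# Minimal sketch only: NOT the paper's graph cut. Greedily merge superpoints whose
# SAM-derived edge weight exceeds a threshold; data_path and 0.9 are assumptions.
import numpy as np

data_path = "../example_data/21d970d8de"
edges = np.load(f"{data_path}/graph_structure.npz")["arr_0"]                            # (E, 2) superpoint index pairs
weights = np.load(f"{data_path}/edge_weights.npz", allow_pickle=True)["arr_0"].item()   # {(i, j): affinity in [0, 1]}
features = np.load(f"{data_path}/node_feature.npz", allow_pickle=True)["arr_0"].item()  # {i: averaged SAM encoder feature}
# (features is not used below; it is loaded only to show the output format.)

n = int(edges.max()) + 1              # assumes every superpoint occurs in at least one edge
parent = list(range(n))

def find(x):                          # union-find with path compression
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

for i, j in edges:
    w = weights.get((i, j))           # missing if the two superpoints never co-occur in a view
    if w is not None and w > 0.9:
        parent[find(int(i))] = find(int(j))

labels = np.array([find(i) for i in range(n)])
print(f"{n} superpoints merged into {len(np.unique(labels))} instances")
```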
98 | 99 | ## Citation 100 | 101 | ```bibtex 102 | @inproceedings{guo2024sam, 103 | title={SAM-guided Graph Cut for 3D Instance Segmentation}, 104 | author={Guo, Haoyu and Zhu, He and Peng, Sida and Wang, Yuang and Shen, Yujun and Hu, Ruizhen and Zhou, Xiaowei}, 105 | booktitle={ECCV}, 106 | year={2024} 107 | } 108 | ``` -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/SAM-Graph/8e58c46d93904525cfa94c16a4f7a6ce5aaea4d8/assets/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.0 2 | pytorch3d==0.7.4 3 | numpy==1.23.5 4 | trimesh==3.23.5 5 | scipy 6 | PIL 7 | opencv-python 8 | segment_anything 9 | pycolmap -------------------------------------------------------------------------------- /scripts/build_graph_structure.py: -------------------------------------------------------------------------------- 1 | import argparse, trimesh, numpy as np, json, os 2 | from tqdm import tqdm, trange 3 | from scipy.spatial import KDTree 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)") 7 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)") 8 | parser.add_argument("--graph_structure_path", type=str, help="Path to the output graph structure (.npz)") 9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 10 | args = parser.parse_args() 11 | 12 | if args.skip_if_exist and os.path.exists(args.graph_structure_path): 13 | print(f'{args.graph_structure_path} already exists, skip.') 14 | exit() 15 | 16 | m = trimesh.load(args.mesh_path) 17 | seg = json.load(open(args.superpoint_path)) 18 | seg_indices = np.array(seg['segIndices']) 19 | sp_ids = np.unique(seg_indices) 20 | mapping = np.zeros((sp_ids.max() + 1, ), dtype=np.int32) 21 | for i, sp_id in enumerate(sp_ids): 22 | mapping[sp_id] = i 23 | seg_indices = mapping[seg_indices] 24 | 25 | pc_list = [] 26 | for sp_id in np.unique(seg_indices): 27 | pc = m.vertices[seg_indices == sp_id] 28 | pc = np.asarray(pc) 29 | if pc.shape[0] > 50: 30 | pc = pc[np.random.choice(pc.shape[0], size=50, replace=False)] 31 | pc_list.append(pc) 32 | 33 | distances = {} 34 | for i in trange(len(pc_list) - 1, desc='Building graph structure'): 35 | tree_i = KDTree(pc_list[i]) 36 | pc_concat = np.concatenate(pc_list[i + 1:]) 37 | dist_concat = tree_i.query(pc_concat)[0] 38 | dist_split = np.split(dist_concat, np.cumsum([len(pc) for pc in pc_list[i+1:]]))[:-1] 39 | for j, dist in enumerate(dist_split): 40 | distances[(i, i + j + 1)] = dist.min() 41 | # for j in range(i + 1, len(pc_list)): 42 | # dist, _ = tree_i.query(pc_list[j]) 43 | # dist = min(dist) 44 | # distances[(i, j)] = dist 45 | 46 | edges = [] 47 | for (a, b), distance in distances.items(): 48 | if distance < 0.3: 49 | edges.append((a, b)) 50 | 51 | np.savez_compressed(args.graph_structure_path, np.asarray(edges)) 52 | -------------------------------------------------------------------------------- /scripts/edge_weights.py: -------------------------------------------------------------------------------- 1 | import argparse, numpy as np, os 2 | from tqdm import tqdm 3 | 4 | parser = 
argparse.ArgumentParser() 5 | parser.add_argument("--image_path", type=str, help="Path to the images") 6 | parser.add_argument("--graph_structure_path", type=str, help="Path to the graph structure (.npz)") 7 | parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks") 8 | parser.add_argument("--edge_weights_path", type=str, help="Path to the output edge weights (.npz)") 9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 10 | args = parser.parse_args() 11 | 12 | if args.skip_if_exist and os.path.exists(args.edge_weights_path): 13 | print(f'{args.edge_weights_path} already exists, skip.') 14 | exit() 15 | 16 | def select_mask(masks, iou_preds): 17 | if iou_preds[2] > iou_preds.max() - 0.05: 18 | return masks[2], iou_preds[2] 19 | elif iou_preds[1] > iou_preds[0] - 0.05: 20 | return masks[1], iou_preds[1] 21 | else: 22 | return masks[0], iou_preds[0] 23 | 24 | def weighted_average(data): 25 | score_sum = 0 26 | distance_sum = 0 27 | for k, (score, score1, score2, iou_pred1, iou_pred2, distance) in data.items(): 28 | score_sum += max(score1, score2) * distance * iou_pred1 * iou_pred2 29 | distance_sum += distance * iou_pred1 * iou_pred2 30 | return score_sum / distance_sum 31 | 32 | edges = np.load(args.graph_structure_path)['arr_0'] 33 | 34 | image_list = os.listdir(args.image_path) 35 | 36 | edge_weights = dict() 37 | 38 | for image_f in tqdm(image_list, desc='Calculating edge weights'): 39 | image_id = image_f.split('.')[0] 40 | mask_data = np.load(f'{args.sam_mask_path}/{image_id}.npz', allow_pickle=True)['arr_0'].tolist() 41 | points_per_instance, masks, iou_preds = mask_data['points_per_instance'], mask_data['masks'], mask_data['iou_preds'] 42 | 43 | for p1, p2 in edges: 44 | if p1 not in points_per_instance.keys(): 45 | continue 46 | if p2 not in points_per_instance.keys(): 47 | continue 48 | if (p1, p2) not in edge_weights: 49 | edge_weights[(p1, p2)] = dict() 50 | assert (p2, p1) not in edge_weights 51 | edge_weights[(p2, p1)] = dict() 52 | xy1 = points_per_instance[p1] 53 | xy2 = points_per_instance[p2] 54 | 55 | distance2d = ((np.array(xy1).mean(0) - np.array(xy2).mean(0)) ** 2).sum() ** 0.5 56 | p1_id, p2_id = list(points_per_instance.keys()).index(p1), list(points_per_instance.keys()).index(p2) 57 | 58 | mask1, mask2 = masks[p1_id], masks[p2_id] 59 | iou_pred1, iou_pred2 = iou_preds[p1_id], iou_preds[p2_id] 60 | 61 | mask1, iou_pred1 = select_mask(mask1, iou_pred1) 62 | mask2, iou_pred2 = select_mask(mask2, iou_pred2) 63 | 64 | iou = (mask1 & mask2).sum() / (mask1 | mask2).sum() 65 | ioa = (mask1 & mask2).sum() / mask1.sum() 66 | iob = (mask1 & mask2).sum() / mask2.sum() 67 | 68 | edge_weights[(p1, p2)][image_id] = [iou, ioa, iob, iou_pred1, iou_pred2, distance2d] 69 | edge_weights[(p2, p1)][image_id] = [iou, iob, ioa, iou_pred2, iou_pred1, distance2d] 70 | 71 | for k, v in edge_weights.items(): 72 | edge_weights[k] = weighted_average(v) 73 | 74 | np.savez(args.edge_weights_path, edge_weights) 75 | -------------------------------------------------------------------------------- /scripts/node_feature.py: -------------------------------------------------------------------------------- 1 | import argparse, cv2, numpy as np, os 2 | from tqdm import tqdm 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--image_path", type=str, help="Path to the images") 6 | parser.add_argument("--feature_path", type=str, help="Path to the SAM encoder features") 7 | 
parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks") 8 | parser.add_argument("--node_feature_path", type=str, help="Path to the output node features") 9 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 10 | args = parser.parse_args() 11 | 12 | image_list = os.listdir(args.image_path) 13 | 14 | h, w = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2] 15 | # Note: here we simply assume that all images are of the same resolution, which can be modified if needed. 16 | 17 | feature_h, feature_w = 64, 64 # resolution of SAM encoder feature maps 18 | if h < w: 19 | feature_h = int(feature_w * h / w) 20 | else: 21 | feature_w = int(feature_h * w / h) 22 | 23 | feature_dict = dict() 24 | 25 | if args.skip_if_exist and os.path.exists(args.node_feature_path): 26 | print(f'{args.node_feature_path} already exists, skip.') 27 | exit() 28 | 29 | for image_f in tqdm(image_list, desc='Calulating node features'): 30 | image_id = image_f.split('.')[0] 31 | mask_data = np.load(f'{args.sam_mask_path}/{image_id}.npz', allow_pickle=True)['arr_0'].tolist() 32 | points_per_instance, _, _ = mask_data['points_per_instance'], mask_data['masks'], mask_data['iou_preds'] 33 | feature_map = np.load(f'{args.feature_path}/{image_id}.npy') 34 | for sp_id, pts in points_per_instance.items(): 35 | xy = np.array(pts).astype(np.float32).T 36 | xy[0] = xy[0] / w * feature_w 37 | xy[1] = xy[1] / h * feature_h 38 | xy1 = xy.astype(np.int32) 39 | xy2 = xy1 + 1 40 | xy2[0] = xy2[0].clip(max=feature_w - 1) 41 | xy2[1] = xy2[1].clip(max=feature_h - 1) 42 | feature = feature_map[xy1[1], xy1[0]] * (xy2[0] - xy[0])[:, None] * (xy2[1] - xy[1])[:, None] + \ 43 | feature_map[xy2[1], xy1[0]] * (xy2[0] - xy[0])[:, None] * (xy[1] - xy1[1])[:, None] + \ 44 | feature_map[xy1[1], xy2[0]] * (xy[0] - xy1[0])[:, None] * (xy2[1] - xy[1])[:, None] + \ 45 | feature_map[xy2[1], xy2[0]] * (xy[0] - xy1[0])[:, None] * (xy[1] - xy1[1])[:, None] 46 | feature = feature.mean(0) 47 | if sp_id not in feature_dict: 48 | feature_dict[sp_id] = [feature] 49 | else: 50 | feature_dict[sp_id].append(feature) 51 | 52 | feature_mean_dict = dict() 53 | for k, v in feature_dict.items(): 54 | feature_mean_dict[k] = sum(v) / len(v) 55 | 56 | np.savez_compressed(args.node_feature_path, feature_mean_dict) 57 | -------------------------------------------------------------------------------- /scripts/predict_masks.py: -------------------------------------------------------------------------------- 1 | import argparse, cv2, numpy as np, os, json, torch 2 | np.set_printoptions(suppress=True) 3 | from PIL import Image 4 | from tqdm import tqdm 5 | from segment_anything import build_sam, SamPredictor 6 | 7 | 8 | def remove_small_masks_from_segmentation(segmentation_mask, min_size=100): 9 | unique_ids = np.unique(segmentation_mask) 10 | for instance_id in unique_ids: 11 | if instance_id == -1: 12 | continue 13 | 14 | instance_mask = (segmentation_mask == instance_id).astype(np.uint8) 15 | 16 | num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(instance_mask, connectivity=8) 17 | 18 | for i in range(1, num_labels): 19 | if stats[i, cv2.CC_STAT_AREA] < min_size: 20 | segmentation_mask[labels == i] = -1 21 | return segmentation_mask 22 | 23 | 24 | def sample_points_from_mask(mask, num_points=10): 25 | h, w = mask.shape 26 | 27 | extended_mask = np.zeros((h+2, w+2), dtype=np.uint8) 28 | extended_mask[1:-1, 1:-1] = mask 29 | 30 | distance_transform = 
cv2.distanceTransform(extended_mask, cv2.DIST_L2, 5) 31 | 32 | distance_transform = distance_transform[1:-1, 1:-1] 33 | sampled_points = [] 34 | 35 | for _ in range(num_points): 36 | min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(distance_transform) 37 | sampled_points.append(max_loc) 38 | 39 | cv2.circle(distance_transform, max_loc, int(max_val), 0, -1) 40 | 41 | return sampled_points 42 | 43 | 44 | def sample_points_for_each_instance(segmentation_mask, num_points=10): 45 | ret = dict() 46 | unique_ids = np.unique(segmentation_mask) 47 | 48 | for instance_id in unique_ids: 49 | if instance_id == -1: 50 | continue 51 | 52 | instance_mask = (segmentation_mask == instance_id).astype(np.uint8) * 255 53 | 54 | points = sample_points_from_mask(instance_mask, num_points=num_points) 55 | ret[instance_id] = points 56 | 57 | return ret 58 | 59 | 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--image_path", type=str, help="Path to the images") 62 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images") 63 | parser.add_argument("--sam_ckpt_path", type=str, help="Path to the SAM model parameters") 64 | parser.add_argument("--feature_path", type=str, help="Path to the SAM encoder features") 65 | parser.add_argument("--superpoint_projection_path", type=str, help="Path to the superpoint projection masks") 66 | # parser.add_argument("--depth_difference_path", type=str, help="Path to the depth difference") 67 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)") 68 | parser.add_argument("--sam_mask_path", type=str, help="Path to the SAM masks") 69 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 70 | args = parser.parse_args() 71 | 72 | sam = None 73 | 74 | os.makedirs(args.sam_mask_path, exist_ok=True) 75 | 76 | image_list = os.listdir(args.image_path) 77 | 78 | height_original, width_original = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2] 79 | w = int(args.image_width) 80 | h = int(w * height_original / width_original) 81 | # Note: here we simply assume that all images are of the same resolution, which can be modified if needed. 
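# The loop below handles one image at a time:
#   * the first image runs the full SAM encoder via set_image(); later images reuse the
#     features cached by sam_encoder_feature.py and only swap sam_predictor.features,
#   * the superpoint projection PNG is decoded (superpoint ids packed into 24-bit RGB,
#     white = background), remapped to contiguous indices, and cleaned of small components,
#   * 5 interior points are sampled per visible superpoint via the distance transform,
#   * SAM is prompted with these points (multimask output) and the per-superpoint masks
#     and IoU predictions are saved to <image_id>.npz.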
82 | 83 | for image_f in tqdm(image_list, desc='Predicting masks'): 84 | image_id = image_f.split('.')[0] 85 | 86 | if args.skip_if_exist and os.path.exists(f'{args.sam_mask_path}/{image_id}.npz'): 87 | continue 88 | 89 | if sam is None: 90 | sam = build_sam(checkpoint=args.sam_ckpt_path) 91 | sam.to(device='cuda') 92 | sam_predictor = SamPredictor(sam) 93 | 94 | seg = json.load(open(args.superpoint_path)) 95 | seg_indices = np.array(seg['segIndices']) 96 | sp_ids = np.unique(seg_indices) 97 | mapping = np.zeros((sp_ids.max() + 1, ), dtype=np.int32) 98 | for i, sp_id in enumerate(sp_ids): 99 | mapping[sp_id] = i 100 | seg_indices = mapping[seg_indices] 101 | 102 | if not sam_predictor.is_image_set: 103 | img = cv2.imread(f'{args.image_path}/{image_f}') 104 | img = cv2.resize(img, (w, h)) 105 | sam_predictor.set_image(img) 106 | else: 107 | sam_predictor.features = torch.from_numpy(np.load(f'{args.feature_path}/{image_id}.npy')).permute(2, 0, 1)[None].cuda() 108 | 109 | view_overseg = Image.open(f'{args.superpoint_projection_path}/{image_id}.png') 110 | view_overseg = np.asarray(view_overseg).astype(np.uint32) 111 | view_overseg = (view_overseg[..., 0] << 16) | (view_overseg[..., 1] << 8) | view_overseg[..., 2] 112 | bg = (view_overseg == 256 ** 3 - 1) 113 | # mask = Image.open(f'{args.rendered_depth_path}/{image_id}.png') 114 | # mask = np.array(mask) > 127 115 | # bg = bg | (~mask) 116 | view_overseg[bg] = 0 117 | view_overseg = mapping[view_overseg] 118 | view_overseg[bg] = -1 119 | view_overseg = remove_small_masks_from_segmentation(view_overseg) 120 | 121 | points_per_instance = sample_points_for_each_instance(view_overseg, num_points=5) 122 | points = np.array([list(_) for _ in points_per_instance.values()]) 123 | 124 | transformed_points = sam_predictor.transform.apply_coords(points, (h, w)) 125 | in_points = torch.as_tensor(transformed_points, device=sam_predictor.device) 126 | in_labels = torch.ones((in_points.shape[0], in_points.shape[1]), dtype=torch.int, device=in_points.device) 127 | 128 | masks_logits, iou_preds, _ = sam_predictor.predict_torch( 129 | in_points, 130 | in_labels, 131 | multimask_output=True, 132 | return_logits=True, 133 | ) 134 | 135 | masks_logits = masks_logits.cpu().numpy() 136 | masks = masks_logits > 0 137 | iou_preds = iou_preds.cpu().numpy() 138 | 139 | data = {'points_per_instance': points_per_instance, 'masks': masks, 'iou_preds': iou_preds} 140 | np.savez_compressed(f'{args.sam_mask_path}/{image_id}.npz', data) 141 | -------------------------------------------------------------------------------- /scripts/render_depth.py: -------------------------------------------------------------------------------- 1 | import argparse, cv2, trimesh, numpy as np, os, pyrender 2 | from tqdm import tqdm 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--image_path", type=str, help="Path to the images") 6 | parser.add_argument("--pose_path", type=str, help="Path to the camera poses (camera to world)") 7 | parser.add_argument("--intrinsic_path", type=str, help="Path to the intrinsic matrix (.txt)") 8 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)") 9 | parser.add_argument("--rendered_depth_path", type=str, help="Path to the rendered depth maps") 10 | parser.add_argument("--rendered_depth_vis_path", type=str, default='', help="Path to the visualization of rendered depth maps (optional)") 11 | args = parser.parse_args() 12 | 13 | class Renderer(): 14 | def __init__(self, height=1440, width=1440): 15 | 
self.renderer = pyrender.OffscreenRenderer(width, height) 16 | self.scene = pyrender.Scene() 17 | # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES 18 | 19 | def __call__(self, height, width, intrinsics, pose, mesh): 20 | self.renderer.viewport_height = height 21 | self.renderer.viewport_width = width 22 | self.scene.clear() 23 | self.scene.add(mesh) 24 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2], 25 | fx=intrinsics[0, 0], fy=intrinsics[1, 1]) 26 | self.scene.add(cam, pose=self.fix_pose(pose)) 27 | return self.renderer.render(self.scene) # , self.render_flags) 28 | 29 | def fix_pose(self, pose): 30 | # 3D Rotation about the x-axis. 31 | t = np.pi 32 | c = np.cos(t) 33 | s = np.sin(t) 34 | R = np.array([[1, 0, 0], 35 | [0, c, -s], 36 | [0, s, c]]) 37 | axis_transform = np.eye(4) 38 | axis_transform[:3, :3] = R 39 | return pose @ axis_transform 40 | 41 | def mesh_opengl(self, mesh): 42 | material = pyrender.MetallicRoughnessMaterial( 43 | baseColorFactor=[1, 1, 1, 1.], 44 | metallicFactor=0.0, 45 | roughnessFactor=0.0, 46 | smooth=False, 47 | alphaMode='OPAQUE') 48 | return pyrender.Mesh.from_trimesh(mesh, material=material) 49 | 50 | def delete(self): 51 | self.renderer.delete() 52 | 53 | image_list = os.listdir(args.image_path) 54 | 55 | h, w = cv2.imread(f'{args.image_path}/{image_list[0]}').shape[:2] 56 | # Note: here we simply assume that all images are of the same resolution 57 | 58 | os.makedirs(args.superpoint_projection_path, exist_ok=True) 59 | 60 | ixt = np.loadtxt(args.intrinsic_path) 61 | 62 | c2ws = dict() 63 | for image_f in image_list: 64 | image_id = image_f.split('.')[0] 65 | c2w = np.loadtxt(f'{args.pose_path}/{image_id}.txt') 66 | c2ws[image_id] = c2w 67 | 68 | m = trimesh.load(args.mesh_path) 69 | 70 | renderer = Renderer() 71 | mesh_opengl = renderer.mesh_opengl(m) 72 | 73 | os.makedirs(args.rendered_depth_path, exist_ok=True) 74 | if args.rendered_depth_vis_path != '': 75 | os.makedirs(args.rendered_depth_vis_path, exist_ok=True) 76 | 77 | for k, c2w in tqdm(c2ws.items()): 78 | _, depth = renderer(h, w, ixt, c2w, mesh_opengl) 79 | np.savez_compressed(f'{args.rendered_depth_path}/{k}.npz', depth) 80 | 81 | if args.rendered_depth_vis_path != '': 82 | depth = (depth - depth.min()) / (depth.max() - depth.min()) 83 | depth = (depth * 255).astype(np.uint8) 84 | depth_vis = cv2.applyColorMap(depth, cv2.COLORMAP_JET) 85 | cv2.imwrite(f'{args.rendered_depth_vis_path}/{k}.jpg', depth_vis) 86 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | # Input paths 5 | data_path = "../example_data/21d970d8de" 6 | image_path = os.path.join(data_path, "images/") 7 | colmap_path = os.path.join(data_path, "sparse/") 8 | mesh_path = os.path.join(data_path, "rgb.ply") 9 | 10 | # Dependencies 11 | segmentor_path = "/home/guohaoyu/repos/ScanNet/Segmentator/segmentator" 12 | sam_ckpt_path = "../sam_vit_h_4b8939.pth" 13 | 14 | # Output paths 15 | superpoint_path = os.path.join(data_path, "rgb.0.200000.segs.json") 16 | feature_path = os.path.join(data_path, "sam_encoder_feature/") 17 | superpoint_projection_path = os.path.join(data_path, "superpoint_projection/") 18 | sam_mask_path = os.path.join(data_path, "sam_mask/") 19 | graph_structure_path = os.path.join(data_path, "graph_structure.npz") 20 | node_feature_path = os.path.join(data_path, "node_feature.npz") 21 | edge_weights_path = 
os.path.join(data_path, "edge_weights.npz") 22 | graph_segmentation_path = os.path.join(data_path, "graph_segmentation.npy") 23 | graph_segmentation_visualize_path = os.path.join(data_path, "graph_segmentation_visualize.ply") 24 | 25 | # Options 26 | image_width = 640 # resize images to this width 27 | skip_if_exist = 1 28 | skip_arg = "--skip_if_exist" if skip_if_exist else "" 29 | 30 | # Helper function to run shell commands 31 | def run_command(command): 32 | print(f"Running: {command}") 33 | subprocess.run(command, shell=True, check=True) 34 | 35 | # Commands 36 | run_command( 37 | "python sam_encoder_feature.py " 38 | f"--image_path {image_path} " 39 | f"--image_width {image_width} " 40 | f"--feature_path {feature_path} " 41 | f"--sam_ckpt_path {sam_ckpt_path} " 42 | f"{skip_arg}" 43 | ) 44 | 45 | if os.path.exists(superpoint_path) and skip_if_exist: 46 | print("Superpoint file already exists, skipping segmentor step") 47 | else: 48 | run_command(f"{segmentor_path} {mesh_path} 0.2 50") 49 | 50 | run_command( 51 | "python superpoint_projection.py " 52 | f"--image_path {image_path} " 53 | f"--image_width {image_width} " 54 | f"--colmap_path {colmap_path} " 55 | f"--mesh_path {mesh_path} " 56 | f"--superpoint_path {superpoint_path} " 57 | f"--superpoint_projection_path {superpoint_projection_path} " 58 | f"{skip_arg}" 59 | ) 60 | 61 | run_command( 62 | "python predict_masks.py " 63 | f"--image_path {image_path} " 64 | f"--image_width {image_width} " 65 | f"--feature_path {feature_path} " 66 | f"--sam_ckpt_path {sam_ckpt_path} " 67 | f"--superpoint_projection_path {superpoint_projection_path} " 68 | f"--superpoint_path {superpoint_path} " 69 | f"--sam_mask_path {sam_mask_path} " 70 | f"{skip_arg}" 71 | ) 72 | 73 | run_command( 74 | "python build_graph_structure.py " 75 | f"--mesh_path {mesh_path} " 76 | f"--superpoint_path {superpoint_path} " 77 | f"--graph_structure_path {graph_structure_path} " 78 | f"{skip_arg}" 79 | ) 80 | 81 | run_command( 82 | "python node_feature.py " 83 | f"--image_path {image_path} " 84 | f"--feature_path {feature_path} " 85 | f"--sam_mask_path {sam_mask_path} " 86 | f"--node_feature_path {node_feature_path} " 87 | f"{skip_arg}" 88 | ) 89 | 90 | run_command( 91 | "python edge_weights.py " 92 | f"--image_path {image_path} " 93 | f"--graph_structure_path {graph_structure_path} " 94 | f"--sam_mask_path {sam_mask_path} " 95 | f"--edge_weights_path {edge_weights_path} " 96 | f"{skip_arg}" 97 | ) 98 | 99 | # TODO 100 | # run_command( 101 | # "python graph_segmentation.py " 102 | # f"--image_path {image_path} " 103 | # f"--mesh_path {mesh_path} " 104 | # f"--superpoint_path {superpoint_path} " 105 | # f"--graph_structure_path {graph_structure_path} " 106 | # f"--sam_mask_path {sam_mask_path} " 107 | # f"--edge_weights_path {edge_weights_path} " 108 | # f"--graph_segmentation_path {graph_segmentation_path} " 109 | # f"--graph_segmentation_visualize_path {graph_segmentation_visualize_path} " 110 | # f"{skip_arg}" 111 | # ) 112 | -------------------------------------------------------------------------------- /scripts/sam_encoder_feature.py: -------------------------------------------------------------------------------- 1 | import argparse, cv2, numpy as np, os 2 | from tqdm import tqdm 3 | from segment_anything import build_sam, SamPredictor 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--image_path", type=str, help="Path to the images") 7 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images") 8 | 
parser.add_argument("--feature_path", type=str, help="Path to the output SAM encoder features") 9 | parser.add_argument("--sam_ckpt_path", type=str, help="Path to the SAM model parameters") 10 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 11 | args = parser.parse_args() 12 | 13 | sam = None 14 | 15 | # sam = build_sam(checkpoint=args.sam_ckpt_path) 16 | # sam.to(device='cuda') 17 | # sam_predictor = SamPredictor(sam) 18 | 19 | image_list = os.listdir(args.image_path) 20 | 21 | os.makedirs(args.feature_path, exist_ok=True) 22 | 23 | for image_f in tqdm(image_list, desc='Extracting SAM encoder features'): 24 | image_id = image_f.split('.')[0] 25 | if os.path.exists(f'{args.feature_path}/{image_id}.npy') and args.skip_if_exist: 26 | continue 27 | if sam is None: 28 | sam = build_sam(checkpoint=args.sam_ckpt_path) 29 | sam.to(device='cuda') 30 | sam_predictor = SamPredictor(sam) 31 | img = cv2.imread(f'{args.image_path}/{image_f}') 32 | height_original, width_original = img.shape[:2] 33 | w = int(args.image_width) 34 | h = int(w * height_original / width_original) 35 | img = cv2.resize(img, (w, h)) 36 | sam_predictor.set_image(img) 37 | feature_map = sam_predictor.features[0].permute(1, 2, 0).cpu().numpy() 38 | np.save(f'{args.feature_path}/{image_id}.npy', feature_map) 39 | -------------------------------------------------------------------------------- /scripts/superpoint_projection.py: -------------------------------------------------------------------------------- 1 | import argparse, cv2, trimesh, numpy as np, os, json, torch, pycolmap 2 | from tqdm import tqdm 3 | from pytorch3d.renderer import RasterizationSettings, MeshRasterizer, FoVPerspectiveCameras 4 | from pytorch3d.structures import Meshes 5 | from PIL import Image 6 | 7 | def getProjectionMatrixK(znear, zfar, K, H, W): 8 | P = torch.zeros(4, 4) 9 | z_sign = 1.0 10 | 11 | fx = K[..., 0, 0] 12 | fy = K[..., 1, 1] 13 | cx = K[..., 0, 2] 14 | cy = K[..., 1, 2] 15 | s = K[..., 0, 1] 16 | 17 | P[0, 0] = 2 * fx / H 18 | P[1, 1] = 2 * fy / H 19 | P[0, 1] = 2 * s / W 20 | 21 | P[2, 2] = z_sign * (zfar + znear) / (zfar - znear) 22 | P[2, 3] = -(zfar * znear) / (zfar - znear) 23 | P[3, 2] = z_sign 24 | 25 | P[0, 2] = 1 - 2 * (cx / W) 26 | P[1, 2] = 1 - 2 * (cy / H) 27 | 28 | return P 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--image_path", type=str, help="Path to the images") 32 | parser.add_argument("--image_width", type=float, default=640, help="Width of the resized images") 33 | parser.add_argument("--colmap_path", type=str, help="Path to the colmap SfM results (camera poses and intrinsics)") 34 | parser.add_argument("--mesh_path", type=str, help="Path to the mesh (or point cloud)") 35 | parser.add_argument("--superpoint_path", type=str, help="Path to the superpoint segmentation (.json)") 36 | parser.add_argument("--superpoint_projection_path", type=str, help="Path to the output superpoint projection masks") 37 | parser.add_argument("--skip_if_exist", default=False, action='store_true', help="Whether to skip if target already exists") 38 | args = parser.parse_args() 39 | 40 | image_list = os.listdir(args.image_path) 41 | 42 | m = None 43 | 44 | os.makedirs(args.superpoint_projection_path, exist_ok=True) 45 | 46 | sfm_results = pycolmap.Reconstruction(args.colmap_path) 47 | 48 | for image in tqdm(sfm_results.images.values(), desc='Superpoint projection'): 49 | w2c = np.eye(4) 50 | w2c[:3] = image.cam_from_world.matrix() 51 | c2w = 
np.linalg.inv(w2c) 52 | fx, fy, cx, cy = image.camera.params[:4] 53 | ixt = np.array([[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 54 | height_original, width_original = image.camera.height, image.camera.width 55 | w = int(args.image_width) 56 | h = int(args.image_width * height_original / width_original) 57 | ixt[:2] *= w / width_original 58 | 59 | if os.path.exists(f'{args.superpoint_projection_path}/{image.name.split(".")[0]}.png') and args.skip_if_exist: 60 | continue 61 | 62 | if m is None: 63 | m = trimesh.load(args.mesh_path) 64 | seg = json.load(open(args.superpoint_path)) 65 | seg_indices = np.array(seg['segIndices']) 66 | 67 | mesh = Meshes(verts=[torch.from_numpy(m.vertices).float()], faces=[torch.from_numpy(m.faces)]).cuda() 68 | faces = torch.tensor(m.faces) 69 | 70 | num_seg = seg_indices.max() + 1 71 | random_colors = torch.zeros((num_seg, 3), dtype=torch.uint8) 72 | for i in np.unique(seg_indices): 73 | random_colors[i][0] = (i >> 16) & 255 74 | random_colors[i][1] = (i >> 8) & 255 75 | random_colors[i][2] = i & 255 76 | 77 | K = getProjectionMatrixK(0.0001, 100, ixt, h, w).numpy() 78 | 79 | R, t = c2w[:3, :3].copy(), c2w[:3, 3].copy() 80 | R[:, 0] = -R[:, 0] 81 | R[:, 1] = -R[:, 1] 82 | c2w[:3, :3] = R 83 | c2w[:3, 3] = t 84 | w2c = np.linalg.inv(c2w.copy()) 85 | R, t = w2c[:3, :3].copy(), w2c[:3, 3].copy() 86 | R = R.T 87 | 88 | cameras = FoVPerspectiveCameras(device=torch.device('cuda'), R=R[None].astype(np.float32), T=t[None].astype(np.float32), K=K[None].astype(np.float32)) 89 | raster_settings = RasterizationSettings(image_size=(h, w)) 90 | rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) 91 | fragments = rasterizer(mesh) 92 | 93 | face_ids_per_pixel = fragments.pix_to_face[..., 0].cpu() 94 | 95 | vertice_id = faces[face_ids_per_pixel[0]] 96 | seg_id = torch.from_numpy(seg_indices)[vertice_id] 97 | consistent_mask = (seg_id[..., 0] == seg_id[..., 1]) & (seg_id[..., 1] == seg_id[..., 2]) 98 | edge_mask = ~consistent_mask 99 | image_colors = random_colors[seg_id[..., 0]] 100 | background_color = torch.tensor([255, 255, 255], dtype=torch.uint8) 101 | mask = (face_ids_per_pixel[0] == -1) | edge_mask 102 | image_colors[mask] = background_color 103 | 104 | overseg_vis = image_colors.cpu().numpy() 105 | 106 | Image.fromarray(overseg_vis).save(f'{args.superpoint_projection_path}/{image.name.split(".")[0]}.png') 107 | --------------------------------------------------------------------------------
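For reference, the superpoint projection images written by `superpoint_projection.py` pack each superpoint id into the 24-bit RGB value of a pixel, with white marking background and superpoint-boundary pixels. The sketch below shows the corresponding decoding (the same unpacking that `predict_masks.py` performs); the file path is only an example.

```python
# Decode a superpoint projection PNG back to per-pixel superpoint ids.
# Note: these are the raw segmentator segment ids; predict_masks.py additionally
# remaps them to contiguous indices. The path below is illustrative.
import numpy as np
from PIL import Image

rgb = np.asarray(Image.open("superpoint_projection/0000.png")).astype(np.uint32)
ids = (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]   # 24-bit id per pixel
background = ids == 256 ** 3 - 1                               # white: no face hit or superpoint boundary
ids = ids.astype(np.int64)
ids[background] = -1
print("visible superpoints:", np.unique(ids[ids >= 0]).size)
```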