├── .gitignore
├── requirements.txt
├── eval_affordance_transfer.sh
├── README.md
├── interactive_segmenter.py
├── viewer.py
├── demo.py
└── affordance_transfer_pipeline.py

/.gitignore:
--------------------------------------------------------------------------------
data
results
*.zip
checkpoints
*.pt
*.pth
temp.py

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
gsplat==1.3.0
ultralytics
git+https://github.com/JojiJoseph/pycolmap-scene-manager.git # To avoid conflicts with the official Python bindings of colmap
git+https://github.com/JojiJoseph/segment-anything-2.git # Supports Python 3.8+ and sorts images with non-numeric names
tyro
natsort
git+https://github.com/ultralytics/CLIP.git
opencv-python

--------------------------------------------------------------------------------
/eval_affordance_transfer.sh:
--------------------------------------------------------------------------------
python affordance_transfer_pipeline.py --data-dir "./data/processed_scene_01" --checkpoint "./data/processed_scene_01/ckpts/ckpt_29999_rank0.pt" --labels-dir "./data/affordance_labels" --results-dir "./results/scene_01"
python affordance_transfer_pipeline.py --data-dir "./data/processed_scene_02" --checkpoint "./data/processed_scene_02/ckpts/ckpt_29999_rank0.pt" --labels-dir "./data/affordance_labels" --results-dir "./results/scene_02"
python affordance_transfer_pipeline.py --data-dir "./data/processed_scene_03" --checkpoint "./data/processed_scene_03/ckpts/ckpt_29999_rank0.pt" --labels-dir "./data/affordance_labels" --results-dir "./results/scene_03"

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gradient-Driven 3D Segmentation and Affordance Transfer in Gaussian Splatting Using 2D Masks

This repository contains the code for the paper **Gradient-Driven 3D Segmentation and Affordance Transfer in Gaussian Splatting Using 2D Masks**.

Project page: https://jojijoseph.github.io/3dgs-segmentation
Preprint: https://arxiv.org/abs/2409.11681

**Updates**:

Feb 25, 2025 - Added an interactive segmentation script `interactive_segmenter.py`.

Nov 26, 2024 - Please check out the follow-up work - [Gradient-Weighted Feature Back-Projection: A Fast Alternative to Feature Distillation in 3D Gaussian Splatting](https://github.com/JojiJoseph/3dgs-gradient-backprojection)

Oct 22, 2024 - Added support for anaglyph 3D in `viewer.py`. Press `3` to toggle anaglyph 3D.

Oct 11, 2024 - Added a simple viewer `viewer.py` to see the segmented output.

Sep 24, 2024 - Our poster titled **Segmentation of 3D Gaussians using Masked Gradients**, which corresponds to the **preliminary work**, has been accepted to **SIGGRAPH Asia 2024**.

## Setup

Please install the dependencies listed in `requirements.txt` via `pip install -r requirements.txt`. Download `sam2_hiera_large.pt` from https://huggingface.co/facebook/sam2-hiera-large/tree/main and place it in the `./checkpoints` folder.

Other than that, it's a self-contained repo. Please feel free to raise an issue if you face any problems while running the code.
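Alternatively, the checkpoint can be fetched programmatically; a minimal sketch, assuming `huggingface_hub` is installed (it is not listed in `requirements.txt`):

```python
# Sketch: download the SAM 2 checkpoint into ./checkpoints.
# Repo id and filename follow the Hugging Face link above.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="facebook/sam2-hiera-large",
    filename="sam2_hiera_large.pt",
    local_dir="./checkpoints",
)
```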
## Demo

```bash
python demo.py --help
```

If needed, sample data (chair) can be found [here](https://drive.google.com/file/d/17xugq_6IaZBpm9B9QYU82hcwBelRR4vh/view?usp=sharing). Please create a folder named `data` in the repository root and extract the contents of the zip file into it. Then simply run `python demo.py`.

To see the output after segmentation,
```bash
python viewer.py
```

Please type `python viewer.py --help` to see more options.

Trained Mip-NeRF 360 Gaussian splat models (using [gsplat](https://github.com/nerfstudio-project/gsplat) with data factor = 4) can be found [here](https://drive.google.com/file/d/1ZCTgAE6vZOeUBdR3qPXdSPY01QQBHxeO/view?usp=sharing). Extract them to the `data` folder.

```bash
python demo.py --data-dir data/360_v2/garden/ --checkpoint data/360_v2/garden/ckpts/ckpt_29999_rank0.pt --prompt table --rasterizer gsplat --data-factor 4 --results-dir results/garden
```

https://github.com/user-attachments/assets/62f537ca-87e8-4de8-af5d-150ea22dd1ec
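`demo.py` exports the 3D segmentation to `<results-dir>/mask3d.pth` as a boolean tensor with one entry per Gaussian. A minimal sketch of applying it yourself, assuming `splats` was loaded as in `demo.py`/`viewer.py`:

```python
# Sketch: load the exported 3D mask and keep only the segmented Gaussians.
# Assumes `splats` is the dict returned by load_checkpoint.
import torch

mask3d = torch.load("./results/chair/mask3d.pth")  # bool tensor, one entry per Gaussian
for key in ["means", "features_dc", "features_rest", "scaling", "rotation", "opacity"]:
    splats[key] = splats[key][mask3d]
```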
## Affordance Transfer

```bash
python affordance_transfer_pipeline.py --help
```

Left: Source images, Middle: 2D-2D affordance transfer, Right: 2D-3D affordance transfer

https://github.com/user-attachments/assets/65406bb7-f690-42d5-aca6-59046e08de08

## Affordance Transfer - Evaluation

Download trained scenes from [here](https://drive.google.com/file/d/1-f-rW3U1H5RqdCvp-1BcuSZxrEGc3Rxo/view?usp=sharing). Original scenes (without trained Gaussian Splat models) can be found at https://users.umiacs.umd.edu/~fer/affordance/Affordance.html.

```sh
sh eval_affordance_transfer.sh | tee affordance_transfer.log
```
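The `--labels-dir` folder is expected to contain `.webp` source images paired with labelme-style `.json` annotations whose shapes carry an affordance label and a base64-encoded mask (this is what `load_labels` in `affordance_transfer_pipeline.py` parses). A minimal sketch of decoding one pair, with placeholder file names:

```python
# Sketch: decode one labelme-style annotation pair the way load_labels does.
import base64
import json

import cv2
import numpy as np

annotations = json.load(open("./data/affordance_labels/example.json"))
for shape in annotations["shapes"]:
    label = shape["label"]  # e.g. "grasp", "cut", "scoop", ...
    mask_bytes = base64.b64decode(shape["mask"])
    mask = cv2.imdecode(np.frombuffer(mask_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
    (x0, y0), (x1, y1) = np.array(shape["points"]).astype(int)  # mask bounding corners
    print(label, mask.shape, (x0, y0, x1, y1))
```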
## Some Downstream Applications

Augmented reality.

https://github.com/user-attachments/assets/20ee5c8b-031e-423d-890d-368e1a9c5731

Reorganizing objects in real time.

https://github.com/user-attachments/assets/91cc6ef1-0fd2-44a5-8881-61a042662a95

## Acknowledgements

A big thanks to the following tools/libraries, which were instrumental in this project:

- [gsplat](https://github.com/nerfstudio-project/gsplat): 3DGS rasterizer.
- [SAM 2](https://github.com/facebookresearch/segment-anything-2): To track masks throughout the frames.
- [YOLO-World](https://github.com/AILab-CVC/YOLO-World) via [ultralytics](https://docs.ultralytics.com/models/yolo-world/): To find the initial bounding box.
- [labelme](https://github.com/labelmeai/labelme): To label the source images for affordance transfer.

## Citation
If you find this paper or the code helpful for your work, please consider citing our preprint,
```
@article{joji2024gradient,
    title={Gradient-Driven 3D Segmentation and Affordance Transfer in Gaussian Splatting Using 2D Masks},
    author={Joji Joseph and Bharadwaj Amrutur and Shalabh Bhatnagar},
    journal={arXiv preprint arXiv:2409.11681},
    year={2024},
    url={https://arxiv.org/abs/2409.11681}
}
```

--------------------------------------------------------------------------------
/interactive_segmenter.py:
--------------------------------------------------------------------------------
# Interactive OpenCV-based segmenter: click on a rendered view to prompt SAM,
# accept masks from a few views, then lift them to a 3D Gaussian mask by
# gradient voting. Can be easily customized to different use cases.
import os
from typing import Literal

import cv2
import numpy as np
import torch
import tyro
from gsplat import rasterization
import pycolmap_scene_manager as pycolmap
from utils import (
    load_checkpoint,
    get_rpy_matrix,
    get_viewmat_from_colmap_image,
    prune_by_gradients,
    torch_to_cv,
)

from segment_anything import SamPredictor, sam_model_registry

if not os.path.exists("./checkpoints/sam_vit_h_4b8939.pth"):
    raise ValueError(
        "Please download sam_vit_h_4b8939.pth from https://github.com/facebookresearch/segment-anything and save it in the checkpoints folder"
    )
sam = sam_model_registry["vit_h"](
    checkpoint="./checkpoints/sam_vit_h_4b8939.pth"
).cuda()
predictor = SamPredictor(sam)

device = torch.device("cuda:0")
def main(
    data_dir: str = "./data/chair/",  # colmap path
    checkpoint: str = "./data/chair/checkpoint.pth",  # checkpoint path, can be generated from the original 3DGS repo
    rasterizer: Literal["inria", "gsplat"] = "inria",  # which pipeline produced the checkpoint
    results_dir: str = "./results/chair",
    data_factor: int = 1,
):
    splats = load_checkpoint(
        checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor
    )
    splats = prune_by_gradients(splats)

    means = splats["means"].float().to(device)
    opacities = splats["opacity"].to(device)
    quats = splats["rotation"].to(device)
    scales = splats["scaling"].float()

    opacities = torch.sigmoid(opacities)
    scales = torch.exp(scales)
    colors = torch.cat([splats["features_dc"], splats["features_rest"]], 1)

    K = splats["camera_matrix"].float()
    width = int(K[0, 2] * 2)
    height = int(K[1, 2] * 2)

    viewmat_idx = 0
    colmap_images = list(splats["colmap_project"].images.values())

    cv2.namedWindow("Click and Segment", cv2.WINDOW_NORMAL)

    positive_point_prompts = []
    negative_point_prompts = []
    trigger = False

    mask = None

    def mouse_callback(event, x, y, flags, param):
        nonlocal trigger
        if event == cv2.EVENT_LBUTTONDOWN:
            positive_point_prompts.append((x, y))
            trigger = True
        if event == cv2.EVENT_MBUTTONDOWN:
            negative_point_prompts.append((x, y))
            trigger = True

    cv2.setMouseCallback("Click and Segment", mouse_callback)

    accepted_masks = {}

    while True:
        image = colmap_images[viewmat_idx]
        viewmat = get_viewmat_from_colmap_image(image)
        output, _, _ = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None].to(device),
            K[None].to(device),
            width=width,
            height=height,
            render_mode="RGB+D",
            sh_degree=3,
        )

        output_cv = torch_to_cv(output[0, ..., :3])
        output_cv = np.ascontiguousarray(output_cv)
        predictor.set_image(output_cv[..., ::-1])  # SAM expects RGB
        if trigger:
            points_np = np.array(positive_point_prompts + negative_point_prompts)
            labels_np = np.array(
                [1 for _ in positive_point_prompts]
                + [0 for _ in negative_point_prompts]
            )
            masks, scores, _ = predictor.predict(points_np, labels_np)
            mask = masks[np.argmax(scores)]
            trigger = False
        if mask is not None:
            output_cv[mask] = 0.5 * output_cv[mask] + 0.5 * np.array([0, 0, 255])
        for x, y in positive_point_prompts:
            cv2.circle(output_cv, (x, y), 5, (0, 255, 0), -1)
        for x, y in negative_point_prompts:
            cv2.circle(output_cv, (x, y), 5, (0, 0, 255), -1)

        cv2.putText(output_cv, "n - next view", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1)
        cv2.putText(output_cv, "p - previous view", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1)
        cv2.putText(output_cv, "a - accept mask", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1)
        cv2.putText(output_cv, "q - go to next stage", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1)
        cv2.imshow("Click and Segment", output_cv)
        key = cv2.waitKey(10) & 0xFF
        if key == ord("q"):
            cv2.destroyAllWindows()
            break
        if key == ord("n"):
            viewmat_idx = (viewmat_idx + 1) % len(colmap_images)
            mask = None
            positive_point_prompts = []
            negative_point_prompts = []
        if key == ord("p"):
            viewmat_idx = (viewmat_idx - 1) % len(colmap_images)
            mask = None
            positive_point_prompts = []
            negative_point_prompts = []
        if key == ord("a"):
            accepted_masks[viewmat_idx] = {
                "mask": mask,
                "positive_point_prompts": positive_point_prompts,
                "negative_point_prompts": negative_point_prompts,
            }

        if key in [ord("p"), ord("n")]:
            if viewmat_idx in accepted_masks:
                mask = accepted_masks[viewmat_idx]["mask"]
                positive_point_prompts = accepted_masks[viewmat_idx][
                    "positive_point_prompts"
                ]
                negative_point_prompts = accepted_masks[viewmat_idx][
                    "negative_point_prompts"
                ]
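    # The next stage lifts the accepted 2D masks to a 3D mask by gradient
    # voting. A rendered pixel is a weighted sum over Gaussians,
    # C(p) = sum_i w_i(p) * c_i, with w_i(p) = opacity * transmittance. Because
    # the two-channel pseudo-colors `bins` below are rasterized and the masked
    # sum of channel 0 (plus the unmasked sum of channel 1) is backpropagated,
    # d(target)/d(bins[i, 0]) equals the sum of w_i(p) over masked pixels: each
    # Gaussian's total contribution to the object region, compared against its
    # background contribution accumulated in channel 1.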
    if len(accepted_masks) == 0:
        raise ValueError("Please accept some masks before proceeding")
    votes = torch.zeros((means.shape[0], 2)).to(device)
    bins = torch.zeros((means.shape[0], 2)).to(device)
    bins.requires_grad = True
    for idx in accepted_masks:
        image = colmap_images[idx]
        mask = accepted_masks[idx]["mask"]
        viewmat = get_viewmat_from_colmap_image(image)
        output, _, _ = rasterization(
            means,
            quats,
            scales,
            opacities,
            bins,
            viewmat[None].to(device),
            K[None].to(device),
            width=width,
            height=height,
            render_mode="RGB+D",
            # sh_degree=3,
        )
        mask = torch.from_numpy(mask).float().to(device)
        mask2 = torch.stack([mask, 1 - mask], dim=-1)
        target = mask2 * output[0, ..., :2]
        target = target.sum()
        target.backward()
        votes = votes + bins.grad
        bins.grad.zero_()
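    # Optional sketch (not part of the original script): persist a 3D mask
    # that viewer.py-style tooling could load, assuming a background weight
    # of 1.0.
    # torch.save((votes[:, 0] > votes[:, 1]).cpu(), f"{results_dir}/mask3d.pth")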
    # Show both extraction and deletion based on the mask weights
    cv2.namedWindow("Extraction, Deletion, 2D Mask", cv2.WINDOW_NORMAL)
    cv2.createTrackbar("Background weight", "Extraction, Deletion, 2D Mask", 0, 200, lambda x: None)
    cv2.setTrackbarPos("Background weight", "Extraction, Deletion, 2D Mask", 100)

    while True:
        for image in splats["colmap_project"].images.values():
            viewmat = get_viewmat_from_colmap_image(image)

            background_weight = (
                cv2.getTrackbarPos("Background weight", "Extraction, Deletion, 2D Mask") / 100.0
            )

            mask3d = votes[:, 0] > background_weight * votes[:, 1]
            opacities_extracted = opacities.clone()
            opacities_extracted[~mask3d] = 0.0
            opacities_deleted = opacities.clone()
            opacities_deleted[mask3d] = 0.0
            colors_mask = colors[:, 0].clone()
            colors_mask[mask3d] = 1.0
            colors_mask[~mask3d] = 0.0
            with torch.no_grad():
                output, alphas, meta = rasterization(
                    means, quats, scales, opacities_extracted, colors,
                    viewmat[None].to(device), K[None].to(device),
                    width=width, height=height, render_mode="RGB", sh_degree=3,
                )
                output_cv_extracted = torch_to_cv(output[0])

                output, alphas, meta = rasterization(
                    means, quats, scales, opacities_deleted, colors,
                    viewmat[None].to(device), K[None].to(device),
                    width=width, height=height, render_mode="RGB", sh_degree=3,
                )
                output_cv_deleted = torch_to_cv(output[0])

                output, alphas, meta = rasterization(
                    means, quats, scales, opacities, colors_mask,
                    viewmat[None].to(device), K[None].to(device),
                    width=width, height=height, render_mode="RGB",
                    # sh_degree=3,
                )
                output_cv_mask = torch_to_cv(output[0])

            output_cv = cv2.hconcat(
                [output_cv_extracted, output_cv_deleted, output_cv_mask]
            )
            cv2.imshow("Extraction, Deletion, 2D Mask", output_cv)
            key = cv2.waitKey(10) & 0xFF
            if key == ord("q"):
                break
        if key == ord("q"):
            break


if __name__ == "__main__":
    tyro.cli(main)

--------------------------------------------------------------------------------
/viewer.py:
--------------------------------------------------------------------------------
# Basic OpenCV viewer with sliders for rotation and translation.
# Can be easily customized to different use cases.
import torch
from gsplat import rasterization
import cv2
import tyro
import numpy as np
from typing import Literal
import pycolmap_scene_manager as pycolmap
from scipy.spatial.transform import Rotation as scipyR

device = torch.device("cuda:0")


def get_rpy_matrix(roll, pitch, yaw):
    roll_matrix = np.array(
        [
            [1, 0, 0, 0],
            [0, np.cos(roll), -np.sin(roll), 0],
            [0, np.sin(roll), np.cos(roll), 0],
            [0, 0, 0, 1.0],
        ]
    )
    pitch_matrix = np.array(
        [
            [np.cos(pitch), 0, np.sin(pitch), 0],
            [0, 1, 0, 0],
            [-np.sin(pitch), 0, np.cos(pitch), 0],
            [0, 0, 0, 1.0],
        ]
    )
    yaw_matrix = np.array(
        [
            [np.cos(yaw), -np.sin(yaw), 0, 0],
            [np.sin(yaw), np.cos(yaw), 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1.0],
        ]
    )

    return yaw_matrix @ pitch_matrix @ roll_matrix
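# Illustrative check (commented out): get_rpy_matrix follows scipy's extrinsic
# x-y-z Euler convention, which update_trackbars_from_viewmat below relies on.
# R = get_rpy_matrix(0.1, 0.2, 0.3)[:3, :3]
# assert np.allclose(R, scipyR.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix())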
def _detach_tensors_from_dict(d, inplace=True):
    if not inplace:
        d = d.copy()
    for key in d:
        if isinstance(d[key], torch.Tensor):
            d[key] = d[key].detach()
    return d


def load_checkpoint(checkpoint: str, data_dir: str, rasterizer: Literal["original", "gsplat"] = "original", data_factor: int = 1):

    colmap_project = pycolmap.SceneManager(f"{data_dir}/sparse/0")
    colmap_project.load_cameras()
    colmap_project.load_images()
    colmap_project.load_points3D()
    model = torch.load(checkpoint)  # Make sure it is generated by the original 3DGS repo
    if rasterizer == "original":
        model_params, _ = model
        splats = {
            "active_sh_degree": model_params[0],
            "means": model_params[1],
            "features_dc": model_params[2],
            "features_rest": model_params[3],
            "scaling": model_params[4],
            "rotation": model_params[5],
            "opacity": model_params[6].squeeze(1),
        }
    elif rasterizer == "gsplat":
        model_params = model["splats"]
        splats = {
            "active_sh_degree": 3,
            "means": model_params["means"],
            "features_dc": model_params["sh0"],
            "features_rest": model_params["shN"],
            "scaling": model_params["scales"],
            "rotation": model_params["quats"],
            "opacity": model_params["opacities"],
        }
    else:
        raise ValueError("Invalid rasterizer")

    _detach_tensors_from_dict(splats)

    # Assuming only one camera
    for camera in colmap_project.cameras.values():
        camera_matrix = torch.tensor(
            [
                [camera.fx, 0, camera.cx],
                [0, camera.fy, camera.cy],
                [0, 0, 1],
            ]
        )
        break

    camera_matrix[:2, :3] /= data_factor

    splats["camera_matrix"] = camera_matrix
    splats["colmap_project"] = colmap_project
    splats["colmap_dir"] = data_dir

    return splats
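# Note on checkpoint formats (illustrative): a checkpoint from the original
# 3DGS repo is a (model_params, iteration) tuple, while a gsplat checkpoint is
# a dict with a "splats" key. A quick way to pick the right --rasterizer:
# ckpt = torch.load("./data/chair/checkpoint.pth")
# print("gsplat" if isinstance(ckpt, dict) and "splats" in ckpt else "original")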
def create_checkerboard(width, height, size=64):
    checkerboard = np.zeros((height, width, 3), dtype=np.uint8)
    for y in range(0, height, size):
        for x in range(0, width, size):
            if (x // size + y // size) % 2 == 0:
                checkerboard[y:y + size, x:x + size] = 255
            else:
                checkerboard[y:y + size, x:x + size] = 128
    return checkerboard


def main(data_dir: str = "./data/chair",  # colmap path
         checkpoint: str = "./data/chair/checkpoint.pth",  # checkpoint path, can be generated from the original 3DGS repo
         rasterizer: Literal["original", "gsplat"] = "original",  # which pipeline produced the checkpoint
         mask_path: str = "./results/chair/mask3d.pth",
         apply_mask: bool = True,
         invert: bool = False,
         use_checkerboard_background: bool = True,
         data_factor: int = 1):
    """Program to view the extracted 3D segment.

    Args:
        data_dir: Path to the colmap project.
        checkpoint: Checkpoint path; can be generated from the original 3DGS repo or using gsplat.
        rasterizer: The rasterizer which was used to generate the checkpoint.
        mask_path: Path to the mask file.
        apply_mask: Apply the mask to the splats.
        invert: Invert the mask.
        use_checkerboard_background: Use a checkerboard background.
        data_factor: Factor to scale the resolution down.
    """

    torch.set_default_device("cuda")
    torch.set_grad_enabled(False)

    splats = load_checkpoint(checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor)

    show_anaglyph = False

    means = splats["means"].float()
    opacities = splats["opacity"]
    quats = splats["rotation"]
    scales = splats["scaling"].float()

    opacities = torch.sigmoid(opacities)
    scales = torch.exp(scales)
    colors = torch.cat([splats["features_dc"], splats["features_rest"]], 1)
    if apply_mask:
        mask = torch.load(mask_path)

        if invert:
            mask = ~mask

        means = means[mask]
        opacities = opacities[mask]
        quats = quats[mask]
        scales = scales[mask]
        colors = colors[mask]

    cv2.namedWindow("Viewer", cv2.WINDOW_NORMAL)
    cv2.createTrackbar("Roll", "Viewer", 0, 180, lambda x: None)
    cv2.createTrackbar("Pitch", "Viewer", 0, 180, lambda x: None)
    cv2.createTrackbar("Yaw", "Viewer", 0, 180, lambda x: None)
    cv2.createTrackbar("X", "Viewer", 0, 1000, lambda x: None)
    cv2.createTrackbar("Y", "Viewer", 0, 1000, lambda x: None)
    cv2.createTrackbar("Z", "Viewer", 0, 1000, lambda x: None)
    cv2.createTrackbar("Scaling", "Viewer", 100, 100, lambda x: None)

    cv2.setTrackbarMin("Roll", "Viewer", -180)
    cv2.setTrackbarMax("Roll", "Viewer", 180)
    cv2.setTrackbarMin("Pitch", "Viewer", -180)
    cv2.setTrackbarMax("Pitch", "Viewer", 180)
    cv2.setTrackbarMin("Yaw", "Viewer", -180)
    cv2.setTrackbarMax("Yaw", "Viewer", 180)
    cv2.setTrackbarMin("X", "Viewer", -1000)
    cv2.setTrackbarMax("X", "Viewer", 1000)
    cv2.setTrackbarMin("Y", "Viewer", -1000)
    cv2.setTrackbarMax("Y", "Viewer", 1000)
    cv2.setTrackbarMin("Z", "Viewer", -1000)
    cv2.setTrackbarMax("Z", "Viewer", 1000)
    K = splats["camera_matrix"].float()

    width = int(K[0, 2] * 2)
    height = int(K[1, 2] * 2)

    def update_trackbars_from_viewmat(world_to_camera):
        # if a torch tensor is passed, convert it to numpy
        if isinstance(world_to_camera, torch.Tensor):
            world_to_camera = world_to_camera.cpu().numpy()
        r = scipyR.from_matrix(world_to_camera[:3, :3])
        roll, pitch, yaw = r.as_euler("xyz")
        cv2.setTrackbarPos("Roll", "Viewer", np.rad2deg(roll).astype(int))
        cv2.setTrackbarPos("Pitch", "Viewer", np.rad2deg(pitch).astype(int))
        cv2.setTrackbarPos("Yaw", "Viewer", np.rad2deg(yaw).astype(int))
        cv2.setTrackbarPos("X", "Viewer", int(world_to_camera[0, 3] * 100))
        cv2.setTrackbarPos("Y", "Viewer", int(world_to_camera[1, 3] * 100))
        cv2.setTrackbarPos("Z", "Viewer", int(world_to_camera[2, 3] * 100))

    while True:
        roll = cv2.getTrackbarPos("Roll", "Viewer")
        pitch = cv2.getTrackbarPos("Pitch", "Viewer")
        yaw = cv2.getTrackbarPos("Yaw", "Viewer")

        roll_rad = np.deg2rad(roll)
        pitch_rad = np.deg2rad(pitch)
        yaw_rad = np.deg2rad(yaw)

        viewmat = (
            torch.tensor(get_rpy_matrix(roll_rad, pitch_rad, yaw_rad))
            .float()
            .to(device)
        )
        viewmat[0, 3] = cv2.getTrackbarPos("X", "Viewer") / 100.0
        viewmat[1, 3] = cv2.getTrackbarPos("Y", "Viewer") / 100.0
        viewmat[2, 3] = cv2.getTrackbarPos("Z", "Viewer") / 100.0
        output, alphas, meta = rasterization(
            means,
            quats,
            scales * cv2.getTrackbarPos("Scaling", "Viewer") / 100.0,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width=width,
            height=height,
            sh_degree=3,
        )

        output_cv = torch_to_cv(output[0])
        if use_checkerboard_background:
            alphas = alphas[0].cpu().numpy()
            output_cv = output_cv.astype(float) * alphas + create_checkerboard(width, height).astype(float) * (1 - alphas)
            output_cv = np.clip(output_cv, 0, 255).astype(np.uint8)
        if show_anaglyph:
            left = output_cv.copy()
            left[..., :2] = 0  # keep only the red channel for the left eye
            viewmat[0, 3] -= 0.1  # shift the camera along x for the right-eye view
            output, _, _ = rasterization(
                means,
                quats,
                scales * cv2.getTrackbarPos("Scaling", "Viewer") / 100.0,
                opacities,
                colors,
                viewmat[None],
                K[None],
                width=width,
                height=height,
                sh_degree=3,
            )
            right = torch_to_cv(output[0])
            if use_checkerboard_background:
                right = right.astype(float) * alphas + create_checkerboard(width, height).astype(float) * (1 - alphas)
                right = np.clip(right, 0, 255).astype(np.uint8)
            right[..., -1] = 0  # drop the red channel for the right eye
            output_cv = left + right

        cv2.imshow("Viewer", output_cv)
        key = cv2.waitKey(1)
        if key == ord("q"):
            break
        elif key == ord("3"):
            show_anaglyph = not show_anaglyph
        if key in [ord("w"), ord("a"), ord("s"), ord("d")]:
            if key == ord("w"):
                viewmat[2, 3] -= 0.1
            if key == ord("s"):
                viewmat[2, 3] += 0.1
            if key == ord("a"):
                viewmat[0, 3] += 0.1
            if key == ord("d"):
                viewmat[0, 3] -= 0.1
            update_trackbars_from_viewmat(viewmat)


def torch_to_cv(tensor, permute=False):
    if permute:
        tensor = torch.clamp(tensor.permute(1, 2, 0), 0, 1).cpu().numpy()
    else:
        tensor = torch.clamp(tensor, 0, 1).cpu().numpy()
    return (tensor * 255).astype(np.uint8)[..., ::-1]


if __name__ == "__main__":
    tyro.cli(main)

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
from typing import Literal
import tyro
import os
import torch
import cv2
import imageio  # To generate gifs
import pycolmap_scene_manager as pycolmap
from gsplat import rasterization
from ultralytics import YOLOWorld
from sam2.build_sam import build_sam2_video_predictor
import numpy as np


def torch_to_cv(tensor):
    img_cv = tensor.detach().cpu().numpy()[..., ::-1]
    img_cv = np.clip(img_cv * 255, 0, 255).astype(np.uint8)
    return img_cv


def _detach_tensors_from_dict(d, inplace=True):
    if not inplace:
        d = d.copy()
    for key in d:
        if isinstance(d[key], torch.Tensor):
            d[key] = d[key].detach()
    return d
def load_checkpoint(checkpoint: str, data_dir: str, rasterizer: Literal["original", "gsplat"] = "original", data_factor: int = 1):

    colmap_project = pycolmap.SceneManager(f"{data_dir}/sparse/0")
    colmap_project.load_cameras()
    colmap_project.load_images()
    colmap_project.load_points3D()
    model = torch.load(checkpoint)  # Make sure it is generated by the original 3DGS repo
    if rasterizer == "original":
        model_params, _ = model
        splats = {
            "active_sh_degree": model_params[0],
            "means": model_params[1],
            "features_dc": model_params[2],
            "features_rest": model_params[3],
            "scaling": model_params[4],
            "rotation": model_params[5],
            "opacity": model_params[6].squeeze(1),
        }
    elif rasterizer == "gsplat":
        model_params = model["splats"]
        splats = {
            "active_sh_degree": 3,
            "means": model_params["means"],
            "features_dc": model_params["sh0"],
            "features_rest": model_params["shN"],
            "scaling": model_params["scales"],
            "rotation": model_params["quats"],
            "opacity": model_params["opacities"],
        }
    else:
        raise ValueError("Invalid rasterizer")

    _detach_tensors_from_dict(splats)

    # Assuming only one camera
    for camera in colmap_project.cameras.values():
        camera_matrix = torch.tensor(
            [
                [camera.fx, 0, camera.cx],
                [0, camera.fy, camera.cy],
                [0, 0, 1],
            ]
        )
        break

    camera_matrix[:2, :3] /= data_factor

    splats["camera_matrix"] = camera_matrix
    splats["colmap_project"] = colmap_project
    splats["colmap_dir"] = data_dir

    return splats


def get_viewmat_from_colmap_image(image):
    viewmat = torch.eye(4).float()
    viewmat[:3, :3] = torch.tensor(image.R()).float()
    viewmat[:3, 3] = torch.tensor(image.t).float()
    return viewmat


def create_checkerboard(width, height, size=64):
    checkerboard = np.zeros((height, width, 3), dtype=np.uint8)
    for y in range(0, height, size):
        for x in range(0, width, size):
            if (x // size + y // size) % 2 == 0:
                checkerboard[y:y + size, x:x + size] = 255
            else:
                checkerboard[y:y + size, x:x + size] = 128
    return checkerboard


def render_to_dir(output_dir: str, splats, feedback: bool = False):
    if feedback:
        cv2.destroyAllWindows()
        cv2.namedWindow("Initial Rendering", cv2.WINDOW_NORMAL)
    os.makedirs(output_dir, exist_ok=True)
    colmap_project = splats["colmap_project"]
    frame_idx = 0
    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        image_name = image.name
        image_path = f"{output_dir}/{image_name}"
        means = splats["means"]
        colors_dc = splats["features_dc"]
        colors_rest = splats["features_rest"]
        colors = torch.cat([colors_dc, colors_rest], dim=1)
        opacities = torch.sigmoid(splats["opacity"])
        scales = torch.exp(splats["scaling"])
        quats = splats["rotation"]
        viewmat = get_viewmat_from_colmap_image(image)
        K = splats["camera_matrix"]
        output, _, info = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats=viewmat[None],
            Ks=K[None],
            sh_degree=3,
            width=int(K[0, 2] * 2),
            height=int(K[1, 2] * 2),
        )
        output_cv = torch_to_cv(output[0])
        imageio.imsave(image_path, output_cv[:, :, ::-1])
        if feedback:
            cv2.imshow("Initial Rendering", output_cv)
            cv2.waitKey(1)
        frame_idx += 1
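# prune_by_gradients (below) renders every training view with the color
# parameters marked as leaf tensors and backpropagates a pseudo-loss (the
# mean of the rendered image). A Gaussian whose accumulated color gradient is
# exactly zero never contributed to any pixel in any view, so it can be
# dropped without changing the renders; test_proper_pruning verifies this.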
def prune_by_gradients(splats):
    colmap_project = splats["colmap_project"]
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    colors.requires_grad = True
    gaussian_grads = torch.zeros(colors.shape[0], device=colors.device)
    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)
        output, _, _ = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors[:, 0, :],
            viewmats=viewmat[None],
            Ks=K[None],
            # sh_degree=3,
            width=int(K[0, 2] * 2),
            height=int(K[1, 2] * 2),
        )
        # pseudo_loss = ((output.detach() + 1 - output)**2).mean()
        pseudo_loss = output.mean()
        pseudo_loss.backward()
        gaussian_grads += colors.grad[:, 0].norm(dim=[1])
        colors.grad.zero_()

    mask = gaussian_grads > 0
    print("Total splats", len(gaussian_grads))
    print("Pruned", (~mask).sum(), "splats")
    print("Remaining", mask.sum(), "splats")
    splats = splats.copy()
    splats["means"] = splats["means"][mask]
    splats["features_dc"] = splats["features_dc"][mask]
    splats["features_rest"] = splats["features_rest"][mask]
    splats["scaling"] = splats["scaling"][mask]
    splats["rotation"] = splats["rotation"][mask]
    splats["opacity"] = splats["opacity"][mask]
    return splats, mask


def test_proper_pruning(splats, splats_after_pruning):
    colmap_project = splats["colmap_project"]
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]

    means_pruned = splats_after_pruning["means"]
    colors_dc_pruned = splats_after_pruning["features_dc"]
    colors_rest_pruned = splats_after_pruning["features_rest"]
    colors_pruned = torch.cat([colors_dc_pruned, colors_rest_pruned], dim=1)
    opacities_pruned = torch.sigmoid(splats_after_pruning["opacity"])
    scales_pruned = torch.exp(splats_after_pruning["scaling"])
    quats_pruned = splats_after_pruning["rotation"]

    K = splats["camera_matrix"]
    total_error = 0
    max_pixel_error = 0
    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)
        output, _, _ = rasterization(
            means, quats, scales, opacities, colors,
            viewmats=viewmat[None], Ks=K[None], sh_degree=3,
            width=int(K[0, 2] * 2), height=int(K[1, 2] * 2),
        )

        output_pruned, _, _ = rasterization(
            means_pruned, quats_pruned, scales_pruned, opacities_pruned, colors_pruned,
            viewmats=viewmat[None], Ks=K[None], sh_degree=3,
            width=int(K[0, 2] * 2), height=int(K[1, 2] * 2),
        )

        total_error += torch.abs(output - output_pruned).sum()
        max_pixel_error = max(max_pixel_error, torch.abs(output - output_pruned).max())

    percentage_pruned = (len(splats["means"]) - len(splats_after_pruning["means"])) / len(splats["means"]) * 100

    assert max_pixel_error < 1 / (255 * 2), "Max pixel error should be less than 1/(255*2), safety margin"
    print("Report {}% pruned, max pixel error = {}, total pixel error = {}".format(percentage_pruned, max_pixel_error, total_error))
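# get_mask3d (below) lifts SAM 2's per-view 2D masks to a per-Gaussian vote:
# "gradient" accumulates each Gaussian's rendering weight inside the mask
# minus its weight outside, "binary" counts views where that weight is
# nonzero, and "projection" votes by whether the projected 2D center lands
# inside the mask. Gaussians with a positive total vote form the 3D mask.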
def get_mask3d(splats, prompt, data_dir: str, results_dir: str, show_visual_feedback: bool = False, mask_interval: int = 1, voting_method: Literal["gradient", "binary", "projection"] = "gradient", mask_dir=None):
    checkpoint = "./checkpoints/sam2_hiera_large.pt"
    if not os.path.exists(checkpoint):
        raise RuntimeError("Please download the checkpoint sam2_hiera_large.pt to the checkpoints folder")

    if show_visual_feedback:
        cv2.destroyAllWindows()
        cv2.namedWindow("2D Mask", cv2.WINDOW_NORMAL)
    model_cfg = "sam2_hiera_l.yaml"
    mask_predictor = build_sam2_video_predictor(model_cfg, checkpoint)
    yolo_world = YOLOWorld("yolov8s-worldv2.pt")
    yolo_world.set_classes([prompt])
    colmap_project = splats["colmap_project"]
    first_image_name = sorted(colmap_project.images.values(), key=lambda x: x.name)[0].name
    first_image_path = f"{results_dir}/images/{first_image_name}"
    with torch.autocast("cuda", dtype=torch.bfloat16):
        state = mask_predictor.init_state(f"{results_dir}/images/")

        result = yolo_world(first_image_path)[0]
        box = result.boxes[0].xyxy[0].tolist()

        # add new prompts and instantly get the output on the same frame
        _, object_ids, masks = mask_predictor.add_new_points_or_box(
            state, box=box, frame_idx=0, obj_id=0
        )

        means = splats["means"]
        colors_dc = splats["features_dc"]
        colors_rest = splats["features_rest"]

        colors = colors_dc[:, 0, :] * 0  # Just to show that the gradient (opacity * transmittance) is independent of color. Any value will work.

        opacities = torch.sigmoid(splats["opacity"])
        scales = torch.exp(splats["scaling"])
        quats = splats["rotation"]
        K = splats["camera_matrix"]
        colors.requires_grad = True

        gaussian_grads = torch.zeros(colors.shape[0], device=colors.device)
        mask_dir = f"{results_dir}/masks_with_images"
        mask_bin_dir = f"{results_dir}/masks_bin"
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(mask_bin_dir, exist_ok=True)

        # propagate the prompts to get masklets throughout the video
        frame_idx = 0
        for image, (frame_idx, object_ids, masks) in zip(
            sorted(colmap_project.images.values(), key=lambda x: x.name),
            mask_predictor.propagate_in_video(state),
        ):
            image_name = image.name
            image_path = f"{results_dir}/images/{image_name}"
            mask_path = f"{mask_dir}/{image_name}"
            mask_bin_path = f"{mask_bin_dir}/{image.name}"

            frame = cv2.imread(image_path)
            mask = masks[0, 0].cpu().numpy() >= 0
            mask = mask.astype(float)
            mask = cv2.blur(mask, (7, 7))
            mask = mask > 0
            mask = mask.astype(bool)

            mask_red = np.zeros_like(frame)
            mask_red[:, :, -1][mask] = 255
            mask_bin = mask.astype(np.uint8) * 255
            cv2.imwrite(mask_bin_path, mask_bin)
            output = cv2.addWeighted(frame, 1, mask_red, 0.5, 0)
            cv2.imwrite(mask_path, output)

            frame_idx += 1
            if (frame_idx % mask_interval != 1) and (mask_interval != 1):
                continue
            if show_visual_feedback:
                cv2.imshow("2D Mask", output)
                cv2.waitKey(1)

            viewmat = get_viewmat_from_colmap_image(image)

            width = frame.shape[1]
            height = frame.shape[0]

            output_for_grad, _, meta = rasterization(
                means,
                quats,
                scales,
                opacities,
                colors,
                viewmat[None],
                K[None],
                width=width,
                height=height,
                # sh_degree=3,
            )

            target = output_for_grad[0] * torch.from_numpy(mask)[..., None].cuda().float()
            loss = 1 * target.mean()

            loss.backward(retain_graph=True)
            if voting_method == "gradient":
                mins = torch.min(colors.grad, dim=-1).values
                maxes = torch.max(colors.grad, dim=-1).values
                assert torch.allclose(mins, maxes), "Something is wrong with gradient calculation"
                gaussian_grads += colors.grad.norm(dim=[1])
            elif voting_method == "binary":
                gaussian_grads += 1 * (colors.grad.norm(dim=[1]) > 0)
            elif voting_method == "projection":
                means2d = np.round(meta["means2d"].detach().cpu().numpy()).astype(int)
                means2d_mask = (means2d[:, 0] >= 0) & (means2d[:, 0] < width) & (means2d[:, 1] >= 0) & (means2d[:, 1] < height)
                means2d = means2d[means2d_mask]
                gaussian_ids = meta["gaussian_ids"].detach().cpu().numpy()
                gaussian_ids = gaussian_ids[means2d_mask]

                means2d_mask = mask[means2d[:, 1], means2d[:, 0]]  # Check if the splat is in the mask
                gaussian_grads[torch.from_numpy(gaussian_ids[~means2d_mask]).long()] -= 1
                gaussian_grads[torch.from_numpy(gaussian_ids[means2d_mask]).long()] += 1
            else:
                raise ValueError("Invalid voting method")

            colors.grad.zero_()

            mask_inverted = ~mask
            target = output_for_grad[0] * torch.from_numpy(mask_inverted).cuda()[..., None]
            loss = 1 * target.mean()
            loss.backward(retain_graph=False)

            if voting_method == "gradient":
                gaussian_grads -= colors.grad.norm(dim=[1])
            elif voting_method == "binary":
                gaussian_grads -= 1 * (colors.grad.norm(dim=[1]) > 0)
            elif voting_method == "projection":
                pass
            else:
                raise ValueError("Invalid voting method")
            colors.grad.zero_()

    mask_3d = gaussian_grads > 0
    mask_3d_inverted = gaussian_grads < 0  # We don't need Gaussians without any influence, i.e. gaussian_grads == 0
    return mask_3d, mask_3d_inverted
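# Toy check (illustrative): for a single pixel rendered as a weighted sum of
# per-Gaussian colors, the color gradient recovers each Gaussian's weight,
# which is what the "gradient" voting method accumulates.
# w = torch.tensor([0.7, 0.3])            # opacity * transmittance per Gaussian
# c = torch.zeros(2, requires_grad=True)  # stand-in colors
# (w * c).sum().backward()
# assert torch.allclose(c.grad, w)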
def apply_mask3d(splats, mask3d, mask3d_inverted, results_dir: str):
    if mask3d_inverted is None:
        mask3d_inverted = ~mask3d
    extracted = splats.copy()
    deleted = splats.copy()
    masked = splats.copy()
    extracted["means"] = extracted["means"][mask3d]
    extracted["features_dc"] = extracted["features_dc"][mask3d]
    extracted["features_rest"] = extracted["features_rest"][mask3d]
    extracted["scaling"] = extracted["scaling"][mask3d]
    extracted["rotation"] = extracted["rotation"][mask3d]
    extracted["opacity"] = extracted["opacity"][mask3d]

    deleted["means"] = deleted["means"][mask3d_inverted]
    deleted["features_dc"] = deleted["features_dc"][mask3d_inverted]
    deleted["features_rest"] = deleted["features_rest"][mask3d_inverted]
    deleted["scaling"] = deleted["scaling"][mask3d_inverted]
    deleted["rotation"] = deleted["rotation"][mask3d_inverted]
    deleted["opacity"] = deleted["opacity"][mask3d_inverted]

    masked["features_dc"][mask3d] = 1  # (1 - 0.5) / 0.2820947917738781
    masked["features_dc"][~mask3d] = 0  # (0 - 0.5) / 0.2820947917738781
    masked["features_rest"][~mask3d] = 0

    return extracted, deleted, masked


def render_to_gif(output_path: str, splats, feedback: bool = False, use_checkerboard_background: bool = False, no_sh: bool = False):
    if feedback:
        cv2.destroyAllWindows()
        cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL)
    frames = []
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    if no_sh:
        colors = colors_dc[:, 0, :]
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    aux_dir = output_path + ".images"
    os.makedirs(aux_dir, exist_ok=True)
    for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)
        output, alphas, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width=int(K[0, 2] * 2),
            height=int(K[1, 2] * 2),
            sh_degree=3 if not no_sh else None,
        )
        frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        if use_checkerboard_background:
            checkerboard = create_checkerboard(frame.shape[1], frame.shape[0])
            alphas = alphas[0].detach().cpu().numpy()
            frame = frame * alphas + checkerboard * (1 - alphas)
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        frames.append(frame)
        if feedback:
            cv2.imshow("Rendering", frame[..., ::-1])
            cv2.waitKey(1)
        cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1])
    imageio.mimsave(output_path, frames, fps=10)
    if feedback:
        cv2.destroyAllWindows()


def render_mask_pred(output_dir: str, splats, feedback: bool = False):
    if feedback:
        cv2.destroyAllWindows()
        cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL)
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]

    colors = colors_dc[:, 0, :]
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    aux_dir = output_dir
    os.makedirs(aux_dir, exist_ok=True)
    for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)
        output, alphas, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width=int(K[0, 2] * 2),
            height=int(K[1, 2] * 2),
            sh_degree=None,
        )
        frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        frame = frame > 128
        frame = frame.astype(np.uint8) * 255
        if feedback:
            cv2.imshow("Rendering", frame[..., ::-1])
            cv2.waitKey(1)
        cv2.imwrite(f"{aux_dir}/{image.name}", frame)
    if feedback:
        cv2.destroyAllWindows()
def export_mask(mask3d, prune_mask, results_dir: str):
    if prune_mask is not None:
        mask3d_export = torch.zeros_like(prune_mask).bool()
        mask3d_export[prune_mask] = mask3d
        torch.save(mask3d_export, f"{results_dir}/mask3d.pth")
    else:
        torch.save(mask3d, f"{results_dir}/mask3d.pth")


def main(
    data_dir: str = "./data/chair",  # colmap path
    checkpoint: str = "./data/chair/checkpoint.pth",  # checkpoint path, can be generated from the original 3DGS repo
    prompt: str = "chair",  # prompt
    results_dir: str = "./results/chair",  # output path
    show_visual_feedback: bool = True,  # Will show an OpenCV window
    rasterizer: Literal["original", "gsplat"] = "original",  # which pipeline produced the checkpoint
    data_factor: int = 1,
    mask_interval: int = 1,
    voting_method: Literal["gradient", "binary", "projection"] = "gradient",
):
    """
    Demo program.

    Args:
        data_dir: Path to the colmap project directory
        checkpoint: Path to the checkpoint file
        prompt: Prompt for the mask prediction
        results_dir: Path to the output directory
        show_visual_feedback: Show visual feedback
        rasterizer: Rasterizer used to create the checkpoint
        data_factor: Factor to scale down the resolution of the images
        mask_interval: The interval between images for taking masked gradients
        voting_method: Voting method to generate the 3D mask
    """

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this demo")

    torch.set_default_device("cuda")

    os.makedirs(results_dir, exist_ok=True)
    splats = load_checkpoint(checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor)
    splats_optimized, prune_mask = prune_by_gradients(splats)
    test_proper_pruning(splats, splats_optimized)

    del splats
    splats = splats_optimized

    render_to_dir(f"{results_dir}/images", splats, show_visual_feedback)
    mask3d, mask3d_inverted = get_mask3d(splats, prompt, data_dir, results_dir, show_visual_feedback, mask_interval=mask_interval, voting_method=voting_method)

    export_mask(mask3d, prune_mask, results_dir)

    extracted, deleted, masked = apply_mask3d(splats, mask3d, mask3d_inverted, results_dir)
    # render_mask_pred(f"{results_dir}/mask_bin_pred", masked, show_visual_feedback)
    render_to_gif(f"{results_dir}/extracted.gif", extracted, show_visual_feedback, use_checkerboard_background=True)
    render_to_gif(f"{results_dir}/deleted.gif", deleted, show_visual_feedback)
    render_to_gif(f"{results_dir}/masked.gif", masked, show_visual_feedback, no_sh=True)


if __name__ == "__main__":
    tyro.cli(main)

--------------------------------------------------------------------------------
/affordance_transfer_pipeline.py:
--------------------------------------------------------------------------------
"""
The pipeline is as follows.

1. Load labels into a database
2. Render images to the results folder
3. Calculate affordance maps and save them
4. Render and vote
5. Evaluate the results
"""

from collections import defaultdict
import pickle as pkl
import time
import numpy as np
import tyro
import imageio
import pycolmap_scene_manager as pycolmap
import torch
import os
import faiss
import base64
import cv2
import json
import PIL.Image as Image
from torchvision import transforms as T
from gsplat import rasterization

if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for this script")
torch.set_default_device("cuda")
device = "cuda"

SIZE = 224 * 4
FEATURE_MAP_SIZE = 64
PATCH_SIZE = 14
DIM = 1024

LABEL_TO_IDX = {
    "background": 0,
    "grasp": 1,
    "cut": 2,
    "scoop": 3,
    "contain": 4,
    "pound": 5,
    "support": 6,
    "wrap grasp": 7,
}

IDX_TO_LABEL = ["background", "grasp", "cut", "scoop", "contain", "pound", "support", "wrap grasp"]

transform = T.Compose([T.Resize((SIZE, SIZE)), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

feature_extractor = (
    torch.hub.load("facebookresearch/dinov2:main", "dinov2_vitl14_reg")
    .to(device)
    .eval()
)
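# Sanity check (illustrative): with SIZE = 224 * 4 = 896 and PATCH_SIZE = 14,
# the DINOv2 token grid is 896 / 14 = 64 tokens per side, matching
# FEATURE_MAP_SIZE, and dinov2_vitl14_reg emits 1024-dim patch tokens,
# matching DIM.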
def torch_to_cv(tensor):
    img_cv = tensor.detach().cpu().numpy()[..., ::-1]
    img_cv = np.clip(img_cv * 255, 0, 255).astype(np.uint8)
    return img_cv


def _detach_tensors_from_dict(d, inplace=True):
    if not inplace:
        d = d.copy()
    for key in d:
        if isinstance(d[key], torch.Tensor):
            d[key] = d[key].detach()
    return d


def load_checkpoint(
    checkpoint: str, data_dir: str, rasterizer: str = "original", data_factor: int = 1
):

    colmap_project = pycolmap.SceneManager(f"{data_dir}/sparse/0")
    colmap_project.load_cameras()
    colmap_project.load_images()
    colmap_project.load_points3D()
    model = torch.load(checkpoint)  # Make sure it is generated by the original 3DGS repo
    if rasterizer == "original":
        model_params, _ = model
        splats = {
            "active_sh_degree": model_params[0],
            "means": model_params[1],
            "features_dc": model_params[2],
            "features_rest": model_params[3],
            "scaling": model_params[4],
            "rotation": model_params[5],
            "opacity": model_params[6].squeeze(1),
        }
    elif rasterizer == "gsplat":
        model_params = model["splats"]
        splats = {
            "active_sh_degree": 3,
            "means": model_params["means"],
            "features_dc": model_params["sh0"],
            "features_rest": model_params["shN"],
            "scaling": model_params["scales"],
            "rotation": model_params["quats"],
            "opacity": model_params["opacities"],
        }
    else:
        raise ValueError("Invalid rasterizer")

    _detach_tensors_from_dict(splats)

    # Assuming only one camera
    for camera in colmap_project.cameras.values():
        camera_matrix = torch.tensor(
            [
                [camera.fx, 0, camera.cx],
                [0, camera.fy, camera.cy],
                [0, 0, 1],
            ]
        )
        break

    camera_matrix[:2, :3] /= data_factor

    splats["camera_matrix"] = camera_matrix
    splats["colmap_project"] = colmap_project
    splats["colmap_dir"] = data_dir

    return splats


def get_viewmat_from_colmap_image(image):
    viewmat = torch.eye(4).float()
    viewmat[:3, :3] = torch.tensor(image.R()).float()
    viewmat[:3, 3] = torch.tensor(image.t).float()
    return viewmat


def create_checkerboard(width, height, size=64):
    checkerboard = np.zeros((height, width, 3), dtype=np.uint8)
    for y in range(0, height, size):
        for x in range(0, width, size):
            if (x // size + y // size) % 2 == 0:
                checkerboard[y:y + size, x:x + size] = 255
            else:
                checkerboard[y:y + size, x:x + size] = 128
    return checkerboard


def render_to_dir(output_dir: str, splats, feedback: bool = False):
    if feedback:
        cv2.destroyAllWindows()
        cv2.namedWindow("Initial Rendering", cv2.WINDOW_NORMAL)
    os.makedirs(output_dir, exist_ok=True)
    colmap_project = splats["colmap_project"]
    frame_idx = 0
    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        image_name = image.name
        image_path = f"{output_dir}/{image_name}"
        means = splats["means"]
        colors_dc = splats["features_dc"]
        colors_rest = splats["features_rest"]
        colors = torch.cat([colors_dc, colors_rest], dim=1)
        opacities = torch.sigmoid(splats["opacity"])
        scales = torch.exp(splats["scaling"])
        quats = splats["rotation"]
        viewmat = get_viewmat_from_colmap_image(image)
        K = splats["camera_matrix"]
        output, _, info = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats=viewmat[None],
            Ks=K[None],
            sh_degree=3,
            width=int(K[0, 2] * 2),
            height=int(K[1, 2] * 2),
        )
        output_cv = torch_to_cv(output[0])
        imageio.imsave(image_path, output_cv[:, :, ::-1])
        if feedback:
            cv2.imshow("Initial Rendering", output_cv)
            cv2.waitKey(1)
        frame_idx += 1
def load_labels(labels_dir, results_dir):
    example_images = os.listdir(labels_dir)
    example_images = [img for img in example_images if img.endswith(".webp")]
    features = []
    labels = []
    for example in example_images:
        example_path = os.path.join(labels_dir, example)
        annotations = json.load(open(example_path.replace(".webp", ".json")))

        img_example = cv2.imread(example_path)[..., ::-1]
        img_example = Image.fromarray(img_example)
        img_example_th = transform(img_example).to(device)
        feats = feature_extractor.forward_features(img_example_th[None])[
            "x_norm_patchtokens"
        ][0]
        feats = feats.reshape((FEATURE_MAP_SIZE, FEATURE_MAP_SIZE, -1))
        feats = feats / torch.norm(feats, dim=-1, keepdim=True)
        feats_flatten = feats.detach().cpu().numpy().reshape((-1, DIM))

        background_mask_inverted = np.zeros((FEATURE_MAP_SIZE, FEATURE_MAP_SIZE))

        for shape in annotations["shapes"]:
            shape_label = shape["label"]
            assert shape_label in LABEL_TO_IDX
            label_idx = LABEL_TO_IDX[shape_label]
            mask64 = shape["mask"]
            mask_bytes = base64.b64decode(mask64)
            mask = (
                cv2.imdecode(np.frombuffer(mask_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
                * 255
            )
            boundary_points = np.array(shape["points"]).astype(np.int32)
            blank_image = np.zeros_like(img_example)
            blank_image[
                boundary_points[0][1] : boundary_points[1][1] + 1,
                boundary_points[0][0] : boundary_points[1][0] + 1,
            ] = mask[..., None]
            blank_image = blank_image[..., 0]
            mask_np = cv2.resize(
                blank_image,
                (FEATURE_MAP_SIZE, FEATURE_MAP_SIZE),
                interpolation=cv2.INTER_NEAREST,
            )
            mask_flatten = mask_np.flatten()
            feats_current_mask = feats_flatten[mask_flatten > 0]
            labels_current_mask = np.ones((feats_current_mask.shape[0], 1)) * label_idx

            features = (
                np.concatenate([features, feats_current_mask], axis=0)
                if len(features) > 0
                else feats_current_mask
            )
            labels = (
                np.concatenate([labels, labels_current_mask], axis=0)
                if len(labels) > 0
                else labels_current_mask
            )
            background_mask_inverted = np.logical_or(background_mask_inverted, mask_np)
            cv2.imshow("mask", mask_np)
            cv2.waitKey(1)
        cv2.destroyAllWindows()

        background_mask = ~background_mask_inverted
        features_background = feats_flatten[background_mask.flatten()]
        labels_background = np.zeros((features_background.shape[0], 1))
        features = (
            np.concatenate([features, features_background], axis=0)
            if len(features) > 0
            else features_background
        )
        labels = (
            np.concatenate([labels, labels_background], axis=0)
            if len(labels) > 0
            else labels_background
        )
        cv2.imshow("background_mask", background_mask.astype(np.uint8) * 255)
        cv2.waitKey(1)
    cv2.destroyAllWindows()
    data = {"features": features, "labels": labels}
    os.makedirs(results_dir, exist_ok=True)
    pkl.dump(data, open(os.path.join(results_dir, "features_and_labels.pkl"), "wb"))


def render_images(
    data_dir: str,
    checkpoint: str,
    results_dir: str,
):
    output_dir = os.path.join(results_dir, "images")
    splats = load_checkpoint(checkpoint, data_dir, rasterizer="gsplat")
    render_to_dir(output_dir, splats, feedback=True)
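# render_and_vote (below) repeats the masked-gradient trick per affordance
# class: for each of the 8 labels it backpropagates the mean of the rendered
# image over that label's pixels, so each Gaussian accumulates one vote score
# per class; the argmax over classes is then used to color the splats for
# visualization.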
def render_and_vote(data_dir, checkpoint, results_dir):
    # Reference implementation: one backward pass per label per view.
    # `render_and_vote_fast` below is the optimized variant.
    splats = load_checkpoint(checkpoint, data_dir, rasterizer="gsplat")
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    width = int(K[0, 2] * 2)
    height = int(K[1, 2] * 2)
    colmap_project = splats["colmap_project"]

    affordance_dir = os.path.join(results_dir, "affordance_maps")
    affordance_maps_3d_dir = os.path.join(results_dir, "affordance_map_images_3dgs")
    os.makedirs(affordance_maps_3d_dir, exist_ok=True)

    votes = torch.zeros((8, colors.shape[0])).to(device)  # 8 classes, hardcoded for now

    colors.requires_grad = True

    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)

        output, _, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width=width,
            height=height,
            sh_degree=3,
            backgrounds=torch.tensor([[0.0, 0.0, 0.0]]).to(device),
        )

        label_map_name = image.name.split(".")[0] + ".npy"
        label_map = np.load(os.path.join(affordance_dir, label_map_name))

        # Write the label id into all three channels so that (mask == i)
        # broadcasts against the rendered RGB image below.
        mask = np.zeros((height, width, 3), dtype=np.uint8)
        label_map = cv2.resize(
            label_map, (width, height), interpolation=cv2.INTER_NEAREST
        )

        # Voting
        for i in range(8):
            mask[label_map == i, :] = i

        mask = torch.from_numpy(mask).to(device).float()

        for i in range(8):
            # Mean intensity of the rendering restricted to label-i pixels;
            # its gradient w.r.t. each Gaussian's DC color is that Gaussian's
            # contribution to the label-i region.
            loss = ((mask == i) * output[0]).mean()
            loss.backward(retain_graph=True)
            votes[i] += colors.grad[..., 0, :3].norm(dim=1)
            colors.grad.zero_()

    votes_path = os.path.join(results_dir, "votes.npy")
    votes_np = votes.cpu().numpy()
    np.save(votes_path, votes_np)

    colors_new = colors.detach().clone()
    color_palette = np.array(
        [
            [125, 125, 125],
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 255],
            [255, 255, 0],
            [255, 0, 255],
            [0, 255, 255],
            [128, 0, 0],
        ]
    )

    for i in range(8):
        # Tint each Gaussian with the palette color of its winning label. The
        # offset is scaled by sqrt(4*pi) = 1/C0, the inverse of the SH DC
        # basis constant, so the tint survives spherical-harmonics evaluation.
        mask = torch.argmax(votes, dim=0) == i
        colors_new[mask, 0, :3] = colors_new[mask, 0, :3] * 0.5 + 0.5 * (
            (torch.from_numpy(color_palette[i]) / 255.0).to(device) - 0.5
        ) * np.sqrt(4 * np.pi)
    colors_new[:, 1:, :3] = 0.1 * colors_new[:, 1:, :3]  # Damp view-dependent terms.

    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)

        output, _, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors_new,
            viewmat[None],
            K[None],
            width=width,
            height=height,
            sh_degree=3,
            backgrounds=torch.tensor([[0.0, 0.0, 0.0]]).to(device),
        )

        output_cv = torch_to_cv(output[0, ..., :3].detach())
        cv2.imshow("Mapped affordance regions", output_cv)
        cv2.imwrite(f"{affordance_maps_3d_dir}/{image.name}", output_cv)
        cv2.waitKey(1)
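# The fast variant below folds the eight per-label backward passes into one:
# instead of RGB colors, each Gaussian is given a zero-valued 8-channel
# "color" buffer and the view is rasterized once; the target is an 8-channel
# one-hot encoding of the label map. A single backward pass then deposits in
# channel i of each Gaussian's gradient essentially the same vote (up to a
# constant factor) that the slow version accumulates for label i.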
def render_and_vote_fast(data_dir, checkpoint, results_dir):
    splats = load_checkpoint(checkpoint, data_dir, rasterizer="gsplat")
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    width = int(K[0, 2] * 2)
    height = int(K[1, 2] * 2)
    colmap_project = splats["colmap_project"]

    affordance_dir = os.path.join(results_dir, "affordance_maps")
    affordance_maps_3d_dir = os.path.join(results_dir, "affordance_map_images_3dgs")
    os.makedirs(affordance_maps_3d_dir, exist_ok=True)

    votes = torch.zeros((colors.shape[0], 8)).to(device)  # 8 classes, hardcoded for now

    # Zero-valued 8-channel "colors"; only their gradients are of interest.
    dummy_colors = torch.zeros((len(colors), 8)).to(device)
    dummy_colors.requires_grad = True

    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)

        # Rasterize the 8 label channels directly (no spherical harmonics).
        output, _, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            dummy_colors,
            viewmat[None],
            K[None],
            width=width,
            height=height,
        )

        label_map_name = image.name.split(".")[0] + ".npy"
        label_map = np.load(os.path.join(affordance_dir, label_map_name))

        mask = np.zeros((height, width, 8), dtype=np.float32)
        label_map = cv2.resize(
            label_map, (width, height), interpolation=cv2.INTER_NEAREST
        )

        # Voting: one-hot encode the label map into the 8 mask channels.
        for i in range(8):
            mask[label_map == i, i] = 1

        mask = torch.from_numpy(mask).to(device).float()

        # A single backward pass accumulates the votes for all 8 labels at once.
        loss = (mask * output[0]).mean()
        loss.backward()
        votes += dummy_colors.grad.clone()
        dummy_colors.grad.zero_()

    votes = votes.T  # (8, num_gaussians), matching render_and_vote's layout.
    votes_path = os.path.join(results_dir, "votes.npy")
    votes_np = votes.cpu().numpy()
    np.save(votes_path, votes_np)

    colors_new = colors.clone()
    color_palette = np.array(
        [
            [125, 125, 125],
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 255],
            [255, 255, 0],
            [255, 0, 255],
            [0, 255, 255],
            [128, 0, 0],
        ]
    )

    for i in range(8):
        # Same recoloring as in render_and_vote: blend in the palette color of
        # the winning label, scaled by 1/C0 = sqrt(4*pi) for the SH DC term.
        mask = torch.argmax(votes, dim=0) == i
        colors_new[mask, 0, :3] = colors_new[mask, 0, :3] * 0.5 + 0.5 * (
            (torch.from_numpy(color_palette[i]) / 255.0).to(device) - 0.5
        ) * np.sqrt(4 * np.pi)
    colors_new[:, 1:, :3] = 0.1 * colors_new[:, 1:, :3]  # Damp view-dependent terms.

    for image in sorted(colmap_project.images.values(), key=lambda x: x.name):
        viewmat = get_viewmat_from_colmap_image(image)

        output, _, meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors_new,
            viewmat[None],
            K[None],
            width=width,
            height=height,
            sh_degree=3,
            backgrounds=torch.tensor([[0.0, 0.0, 0.0]]).to(device),
        )

        output_cv = torch_to_cv(output[0, ..., :3].detach())
        cv2.imshow("Mapped affordance regions", output_cv)
        cv2.imwrite(f"{affordance_maps_3d_dir}/{image.name}", output_cv)
        cv2.waitKey(1)
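# 2D-2D affordance transfer: every rendered view is encoded with the same
# DINOv2 feature extractor, and each patch is labeled by a majority vote over
# its k nearest neighbors in the feature bank built by `load_labels`. Because
# the features are L2-normalized, FAISS's inner-product index (IndexFlatIP)
# effectively searches by cosine similarity. The majority vote uses
# np.bincount, e.g. np.bincount(np.array([2, 2, 0, 2, 1])).argmax() == 2.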
def calculate_affordance_map(data_dir, labels_dir, results_dir, k=5):
    def most_frequent(array):
        # Majority vote over the k nearest-neighbor labels.
        return np.bincount(array).argmax()

    images_dir = os.path.join(results_dir, "images")
    affordance_map_dir = os.path.join(results_dir, "affordance_maps")
    affordance_map_images_dir = os.path.join(results_dir, "affordance_map_images")
    os.makedirs(affordance_map_images_dir, exist_ok=True)
    os.makedirs(affordance_map_dir, exist_ok=True)
    with open(os.path.join(results_dir, "features_and_labels.pkl"), "rb") as f:
        features_and_labels = pkl.load(f)

    features = features_and_labels["features"]
    labels = features_and_labels["labels"]
    # FAISS expects contiguous float32 data.
    feature_index = faiss.IndexFlatIP(DIM)
    feature_index.add(np.ascontiguousarray(features, dtype=np.float32))
    for image_name in sorted(os.listdir(images_dir)):
        image_path = os.path.join(images_dir, image_name)
        image_np = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_np)
        image_th = transform(image).to(device)
        feats = feature_extractor.forward_features(image_th[None])["x_norm_patchtokens"][0]
        feats = feats.reshape((FEATURE_MAP_SIZE, FEATURE_MAP_SIZE, -1))
        feats = feats / torch.norm(feats, dim=-1, keepdim=True)
        feats_flatten = feats.detach().cpu().numpy().reshape((-1, DIM))
        D, I = feature_index.search(feats_flatten, k)
        label_set = labels[I].astype(np.uint8)
        label_set = np.apply_along_axis(most_frequent, axis=1, arr=label_set)
        affordance_map = label_set.reshape((FEATURE_MAP_SIZE, FEATURE_MAP_SIZE)).astype(np.uint8)

        # Showing feedback. This palette is in BGR order since it is drawn
        # onto OpenCV images; uint8 so cv2.resize accepts the colored mask.
        color_palette = np.array(
            [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0],
             [0, 255, 255], [255, 0, 255], [255, 255, 0], [0, 0, 128]],
            dtype=np.uint8,
        )

        labels_mask = color_palette[affordance_map.flatten()]
        labels_mask = labels_mask.reshape((FEATURE_MAP_SIZE, FEATURE_MAP_SIZE, 3))
        np.save(os.path.join(affordance_map_dir, image_name.replace(".jpg", ".npy")), affordance_map)
        labels_mask_big = cv2.resize(labels_mask, image.size, interpolation=cv2.INTER_NEAREST)
        img = cv2.imread(image_path)

        img_out = img * 0.5 + 0.5 * labels_mask_big
        img_out = img_out.astype(np.uint8)

        for i in range(8):
            color = color_palette[i].tolist()
            cv2.putText(img_out, IDX_TO_LABEL[i], (10, 10 + 20 * i), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

        cv2.imshow("2D-2D Affordance Transfer", img_out)
        cv2.imwrite(os.path.join(affordance_map_images_dir, image_name), img_out)
        cv2.waitKey(1)

    cv2.destroyAllWindows()


from scipy.io import loadmat
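# Evaluation: for each manually annotated ground-truth view (.mat label files;
# views with gt_type == "automatic" are skipped), the predicted mask P and the
# ground-truth mask G of each class are compared with
#   IoU    = |P ∩ G| / |P ∪ G|
#   recall = |P ∩ G| / |G|
# averaged per class and then macro-averaged over the classes present. The
# 2D-2D transfer is scored from the k-NN affordance maps directly; the 2D-3D
# transfer renders a binary mask per class from the per-Gaussian votes and
# scores that rendering.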
def evaluate_results(data_dir, checkpoint, results_dir):
    gt_dir = os.path.join(data_dir, "gt")
    affordance_maps_dir = os.path.join(results_dir, "affordance_maps")

    # The checkpoint is loaded for the camera poses and Gaussians; the RGB
    # views are not re-rendered here.
    splats = load_checkpoint(checkpoint, data_dir, rasterizer="gsplat")
    colmap_project = splats["colmap_project"]
    label_files = os.listdir(gt_dir)
    label_files = [label_file for label_file in label_files if label_file.endswith("label.mat")]
    label_files.sort()
    mIoU = defaultdict(list)    # per-class lists of IoU values
    recall = defaultdict(list)  # per-class lists of recall values
    assert len(label_files) == len(colmap_project.images)

    print("\n\nEvaluating 2D-2D affordance transfer.")
    for image, label_file in zip(sorted(colmap_project.images.values(), key=lambda x: x.name), label_files):
        gt_path = os.path.join(gt_dir, label_file)
        gt = loadmat(gt_path)
        # Only manually annotated ground truth is evaluated.
        if gt["gt_type"] == "automatic":
            continue
        gt_label = gt["gt_label"]

        affordance_map_name = image.name.replace(".jpg", ".npy")
        affordance_map_path = os.path.join(affordance_maps_dir, affordance_map_name)
        affordance_map = np.load(affordance_map_path)
        affordance_map = cv2.resize(affordance_map, (gt_label.shape[1], gt_label.shape[0]), interpolation=cv2.INTER_NEAREST)

        for i in range(1, 8):
            gt_mask = gt_label == i
            affordance_mask = affordance_map == i
            intersection = np.logical_and(gt_mask, affordance_mask).sum()
            union = np.logical_or(gt_mask, affordance_mask).sum()
            if union == 0:
                continue
            mIoU[i].append(intersection / union)

            if gt_mask.sum() == 0:
                continue
            recall[i].append(intersection / gt_mask.sum())

    # Macro-average over classes. This assumes every view of a scene shares
    # the same set of ground-truth classes, so the class count is taken from
    # the last label map.
    res = 0
    for i in range(1, 8):
        if len(mIoU[i]) == 0:
            continue
        res += np.mean(mIoU[i])
    res /= (np.unique(gt_label).size - 1)
    print(f"mIoU: {res}")

    res_recall = 0
    for i in range(1, 8):
        if len(recall[i]) == 0:
            continue
        res_recall += np.mean(recall[i])
    res_recall /= (np.unique(gt_label).size - 1)
    print(f"Recall: {res_recall}")

    print("\nEvaluating 2D-3D affordance transfer.")
    mIoU = defaultdict(list)
    recall = defaultdict(list)
    means = splats["means"]
    colors_dc = splats["features_dc"]
    colors_rest = splats["features_rest"]
    colors = torch.cat([colors_dc, colors_rest], dim=1)
    opacities = torch.sigmoid(splats["opacity"])
    scales = torch.exp(splats["scaling"])
    quats = splats["rotation"]
    K = splats["camera_matrix"]
    width = int(K[0, 2] * 2)
    height = int(K[1, 2] * 2)

    votes_path = os.path.join(results_dir, "votes.npy")
    votes = np.load(votes_path)
    votes = torch.from_numpy(votes).to(device)
    for image, label_file in zip(sorted(colmap_project.images.values(), key=lambda x: x.name), label_files):
        gt_path = os.path.join(gt_dir, label_file)
        gt = loadmat(gt_path)
        if gt["gt_type"] == "automatic":
            continue
        gt_label = gt["gt_label"]

        viewmat = get_viewmat_from_colmap_image(image)

        for i in range(1, 8):
            # Render a binary mask for class i: Gaussians whose argmax vote is
            # class i are painted white, everything else black.
            colors_new = colors.clone()[:, 0]
            mask = torch.argmax(votes, dim=0) == i
            colors_new[mask] = 1.0
            colors_new[~mask] = 0.0

            output, _, meta = rasterization(
                means,
                quats,
                scales,
                opacities,
                colors_new,
                viewmat[None],
                K[None],
                width=width,
                height=height,
                backgrounds=torch.tensor([[0.0, 0.0, 0.0]]).to(device),
            )
            output_cv = torch_to_cv(output[0, ..., :3].detach())
            cv2.imshow("Mapped affordance regions", output_cv)
            cv2.waitKey(1)
            gt_mask = gt_label == i
            affordance_mask = output_cv > 64  # threshold the soft rendering
            affordance_mask = affordance_mask[..., 0]
            intersection = np.logical_and(gt_mask, affordance_mask).sum()
            union = np.logical_or(gt_mask, affordance_mask).sum()
            if union == 0:
                continue
            mIoU[i].append(intersection / union)
            if gt_mask.sum() == 0:
                continue
            recall[i].append(intersection / gt_mask.sum())
    res = 0
    for i in range(1, 8):
        if len(mIoU[i]) == 0:
            continue
        res += np.mean(mIoU[i])
    res /= (np.unique(gt_label).size - 1)
    print(f"mIoU: {res}")

    res_recall = 0
    for i in range(1, 8):
        if len(recall[i]) == 0:
            continue
        res_recall += np.mean(recall[i])
    res_recall /= (np.unique(gt_label).size - 1)
    print(f"Recall: {res_recall}")
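# Example invocation, matching the defaults below (tyro exposes the arguments
# as kebab-case flags):
#   python affordance_transfer_pipeline.py \
#       --data-dir ./data/processed_scene_01 \
#       --checkpoint ./data/processed_scene_01/ckpts/ckpt_29999_rank0.pt \
#       --labels-dir ./data/affordance_labels \
#       --results-dir ./results/scene_01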
def main(
    data_dir: str = "./data/processed_scene_01",
    checkpoint: str = "./data/processed_scene_01/ckpts/ckpt_29999_rank0.pt",
    labels_dir: str = "./data/affordance_labels",
    results_dir: str = "./results/scene_01",
):
    # Load labels to database
    load_labels(labels_dir, results_dir)

    # Render images to the results folder
    render_images(data_dir, checkpoint, results_dir)

    # Calculate affordance map and save
    calculate_affordance_map(data_dir, labels_dir, results_dir)

    # Render and vote
    render_and_vote_fast(data_dir, checkpoint, results_dir)

    # Evaluate the results. Comment out if you don't want to evaluate
    evaluate_results(data_dir, checkpoint, results_dir)


if __name__ == "__main__":
    tyro.cli(main)
--------------------------------------------------------------------------------