├── f3dgs ├── datasets │ ├── __init__.py │ ├── download_dataset.py │ ├── normalize.py │ ├── traj.py │ └── colmap.py └── utils_simple_trainer.py ├── .gitignore ├── requirements.txt ├── affordance_transfer ├── README.md └── affordance.sh ├── train_compression_encoder_decoder.py ├── README.md ├── visualize_pca.py ├── visualize_pca_f3dgs.py ├── backproject_compressed.py ├── segment.py ├── segment_f3dgs.py ├── segment_compressed.py ├── backproject.py ├── utils.py ├── click_and_segment.py ├── viewer.py └── viewer_with_llm.py /f3dgs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | results 3 | checkpoints 4 | .ckpt 5 | .pt 6 | .pth 7 | temp_* 8 | 9 | symlink_*.py 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gsplat==1.4.0 2 | numpy==1.24.4 3 | git+https://github.com/JojiJoseph/pycolmap-scene-manager.git # To not conflict with python binding of colmap 4 | tyro==0.9.2 5 | git+https://github.com/ultralytics/CLIP.git 6 | opencv-python==4.10.0.84 7 | git+https://github.com/krrish94/lseg-minimal.git 8 | imageio==2.35.1 9 | -------------------------------------------------------------------------------- /affordance_transfer/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | The code needs to be refactored and is not documented. 4 | 5 | To run, 6 | 7 | 1. Move these scripts one folder up 8 | 2. Download the affordance annotation and dataset from https://github.com/JojiJoseph/3dgs-gradient-segmentation 9 | 3. Back-project dino features 10 | 4. Run `sh affordance.sh` 11 | 12 | Please do raise an issue if something is not working. 
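A rough illustration of the kind of per-Gaussian feature matching the transfer relies on is sketched below. `demo_affordance_transfer.py` itself is not included in this listing, so this is *not* its actual implementation; the file names (in particular the `affordance_mask.pt` annotation) are placeholders that follow the naming conventions of the other scripts in this repository, and the threshold is arbitrary.

```python
# Illustrative sketch only -- not the implementation of demo_affordance_transfer.py.
import torch
import torch.nn.functional as F

# Back-projected per-Gaussian DINO features for a source and a target scene
source_feats = torch.load("results/ps01/features_dino.pt")    # (N_src, D)
target_feats = torch.load("results/ps02/features_dino.pt")    # (N_tgt, D)
# Hypothetical per-Gaussian affordance annotation on the source scene
source_mask = torch.load("results/ps01/affordance_mask.pt")   # (N_src,) bool

# Mean feature of the annotated source Gaussians, used as the query
query = F.normalize(source_feats[source_mask].mean(dim=0, keepdim=True), dim=-1)

# Cosine similarity against every target Gaussian, thresholded into a 3D mask
sim = F.normalize(target_feats, dim=-1) @ query.T              # (N_tgt, 1)
target_mask = sim.squeeze(-1) > 0.6                            # arbitrary threshold
print(f"{target_mask.sum().item()} / {target_mask.numel()} Gaussians selected")
```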
13 | -------------------------------------------------------------------------------- /affordance_transfer/affordance.sh: -------------------------------------------------------------------------------- 1 | python demo_affordance_transfer.py --data-dir data/processed_scene_01/ --checkpoint data/processed_scene_01/ckpts/ckpt_29999_rank0.pt --results-dir results/ps01 --rasterizer gsplat 2 | python demo_affordance_transfer.py --data-dir data/processed_scene_02/ --checkpoint data/processed_scene_02/ckpts/ckpt_29999_rank0.pt --results-dir results/ps02 --rasterizer gsplat 3 | python demo_affordance_transfer.py --data-dir data/processed_scene_03/ --checkpoint data/processed_scene_03/ckpts/ckpt_29999_rank0.pt --results-dir results/ps03 --rasterizer gsplat 4 | -------------------------------------------------------------------------------- /f3dgs/datasets/download_dataset.py: -------------------------------------------------------------------------------- 1 | """Script to download benchmark dataset(s)""" 2 | 3 | import os 4 | import subprocess 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Literal 8 | 9 | import tyro 10 | 11 | # dataset names 12 | dataset_names = Literal["mipnerf360"] 13 | 14 | # dataset urls 15 | urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"} 16 | 17 | # rename maps 18 | dataset_rename_map = {"mipnerf360": "360_v2"} 19 | 20 | 21 | @dataclass 22 | class DownloadData: 23 | dataset: dataset_names = "mipnerf360" 24 | save_dir: Path = Path(os.getcwd() + "/data") 25 | 26 | def main(self): 27 | self.save_dir.mkdir(parents=True, exist_ok=True) 28 | self.dataset_download(self.dataset) 29 | 30 | def dataset_download(self, dataset: dataset_names): 31 | (self.save_dir / dataset_rename_map[dataset]).mkdir(parents=True, exist_ok=True) 32 | 33 | file_name = Path(urls[dataset]).name 34 | 35 | # download 36 | download_command = [ 37 | "wget", 38 | "-P", 39 | str(self.save_dir / dataset_rename_map[dataset]), 40 | urls[dataset], 41 | ] 42 | try: 43 | subprocess.run(download_command, check=True) 44 | print("File downloaded successfully.") 45 | except subprocess.CalledProcessError as e: 46 | print(f"Error downloading file: {e}") 47 | 48 | # if .zip 49 | if Path(urls[dataset]).suffix == ".zip": 50 | extract_command = [ 51 | "unzip", 52 | self.save_dir / dataset_rename_map[dataset] / file_name, 53 | "-d", 54 | self.save_dir / dataset_rename_map[dataset], 55 | ] 56 | # if .tar 57 | else: 58 | extract_command = [ 59 | "tar", 60 | "-xvzf", 61 | self.save_dir / dataset_rename_map[dataset] / file_name, 62 | "-C", 63 | self.save_dir / dataset_rename_map[dataset], 64 | ] 65 | 66 | # extract 67 | try: 68 | subprocess.run(extract_command, check=True) 69 | os.remove(self.save_dir / dataset_rename_map[dataset] / file_name) 70 | print("Extraction complete.") 71 | except subprocess.CalledProcessError as e: 72 | print(f"Extraction failed: {e}") 73 | 74 | 75 | if __name__ == "__main__": 76 | tyro.cli(DownloadData).main() 77 | -------------------------------------------------------------------------------- /train_compression_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | import pandas as pd 4 | import torch.nn as nn 5 | import time 6 | import torch.nn.functional as F 7 | 8 | from lseg import LSegNet 9 | import os 10 | 11 | csv_file = './objectInfo150.csv' 12 | if os.path.exists(csv_file): 13 | df = pd.read_csv(csv_file) 14 | else: 15 | raise 
FileNotFoundError("objectInfo150.csv does not exist. Please download from https://github.com/CSAILVision/sceneparsing/blob/master/objectInfo150.csv") 16 | 17 | labels = df['Name'].values 18 | new_labels = [] 19 | for label in labels: 20 | new_labels.extend(list(label.split(";"))) 21 | labels = new_labels 22 | 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | 25 | net = LSegNet( 26 | backbone="clip_vitl16_384", 27 | features=256, 28 | crop_size=480, 29 | arch_option=0, 30 | block_depth=0, 31 | activation="lrelu", 32 | ) 33 | # Load pre-trained weights 34 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt", map_location=device)) 35 | net.eval() 36 | net.to(device) 37 | 38 | clip_text_encoder = net.clip_pretrained.encode_text 39 | 40 | prompts = labels 41 | 42 | prompt = clip.tokenize(prompts) 43 | prompt = prompt.cuda() 44 | 45 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 46 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 47 | text_feat_norm = text_feat_norm.float().to(device) 48 | print(text_feat_norm.shape) 49 | 50 | 51 | class EncoderDecoder(nn.Module): 52 | def __init__(self): 53 | super(EncoderDecoder, self).__init__() 54 | 55 | # Same thing can be realized by linear layer or 1x1 conv layer 56 | # Defines the layer as a matrix to avoid reshaping 57 | # input dim shape (any_number_of_dimensions..., 512) 58 | self.encoder = nn.Parameter(torch.randn(512, 16)) 59 | self.decoder = nn.Parameter(torch.randn(16, 512)) 60 | 61 | 62 | def forward(self, x): 63 | x = x @ self.encoder 64 | y = x @ self.decoder 65 | return x, y 66 | 67 | 68 | model = EncoderDecoder().to(device) 69 | opt = torch.optim.Adam(model.parameters(), lr=1e-4) 70 | text_feat_norm = text_feat_norm.detach() 71 | 72 | 73 | 74 | 75 | def latent_cosine_preservation_loss(z, x): 76 | z_norm = F.normalize(z, dim=1) 77 | x_norm = F.normalize(x, dim=1) 78 | 79 | cosine_z = z_norm @ z_norm.T 80 | cosine_x = x_norm @ x_norm.T 81 | 82 | return F.mse_loss(cosine_z, cosine_x) 83 | 84 | 85 | t1 = time.time() 86 | for i in range(100000): 87 | x, y = model(text_feat_norm) 88 | y = torch.nn.functional.normalize(y, dim=1) 89 | loss1 = torch.nn.functional.mse_loss(text_feat_norm, y) 90 | loss2 = latent_cosine_preservation_loss(x, text_feat_norm) 91 | loss = loss1 + loss2 92 | loss.backward() 93 | opt.step() 94 | opt.zero_grad() 95 | if i % 1000 == 0: 96 | print(loss.item()) 97 | 98 | 99 | t2 = time.time() 100 | print("Time taken for training encoder decoder model: ", t2 - t1) 101 | 102 | torch.save(model.state_dict(), "encoder_decoder.ckpt") 103 | 104 | 105 | # How to use 106 | # model = EncoderDecoder() 107 | # compressed = uncompressed @ model.encoder 108 | # uncompressed = compressed @ model.decoder 109 | # uncompressed tensor should have the last dimension as 512 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gradient-Weighted Feature Back-Projection: A Fast Alternative to Feature Distillation in 3D Gaussian Splatting 2 | 3 | This repository contains the code for the **SIGGRAPH Asia 2025** paper **Gradient-Weighted Feature Back-Projection: A Fast Alternative to Feature Distillation in 3D Gaussian Splatting**. 
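(Referring back to `train_compression_encoder_decoder.py` above.) Its closing "How to use" comment can be made concrete as in the sketch below; the `EncoderDecoder` definition is simply repeated, exactly as `backproject_compressed.py` does, and `encoder_decoder.ckpt` is the checkpoint name that the training script saves. This is only a usage sketch, not one of the pipeline scripts.

```python
import torch
import torch.nn as nn

class EncoderDecoder(nn.Module):
    # Same 512 <-> 16 linear compressor as in train_compression_encoder_decoder.py
    def __init__(self):
        super().__init__()
        self.encoder = nn.Parameter(torch.randn(512, 16))
        self.decoder = nn.Parameter(torch.randn(16, 512))

    def forward(self, x):
        x = x @ self.encoder
        return x, x @ self.decoder

model = EncoderDecoder().cuda()
model.load_state_dict(torch.load("./encoder_decoder.ckpt"))

# Any tensor whose last dimension is 512 (e.g. normalized LSeg/CLIP features)
feats = torch.nn.functional.normalize(torch.randn(1000, 512).cuda(), dim=-1)
compressed = feats @ model.encoder      # (1000, 512) -> (1000, 16)
restored = compressed @ model.decoder   # (1000, 16)  -> (1000, 512)
```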
4 | 5 | 6 | 7 | Project page: https://jojijoseph.github.io/3dgs-backprojection 8 | 9 | [Paper](https://dl.acm.org/doi/10.1145/3757377.3763926) 10 | 11 | 12 | ## Setup 13 | 14 | Please install the dependencies listed in `requirements.txt` via `pip install -r requirements.txt`. Download `lseg_minimal_e200.ckpt` from https://mitprod-my.sharepoint.com/:u:/g/personal/jkrishna_mit_edu/EVlP4Ggf3OlMgDACHNVYuIYBZ4JNi5nJCQA1kXM-_nrB3w?e=XnPT39 and place it in the `./checkpoints` folder. 15 | 16 | Other than that, it's a self-contained repo. Please feel free to raise an issue if you face any problems while running the code. 17 | 18 | ## Demo 19 | 20 | 21 | 22 | https://github.com/user-attachments/assets/1aecd2d1-8e16-499e-98ce-a1667be5114d 23 | 24 | Left: Original rendering, Middle: Extraction, Right: Deletion 25 | 26 | Sample data (garden) can be found [here](https://drive.google.com/file/d/1cEPby9zWgG40dJ4eRiHu15Jdg7FgvTdG/view?usp=sharing). Please create a folder named `data` in the root folder and extract the contents of the zip file into it. 27 | 28 | **Backprojection** 29 | 30 | To backproject the features, run 31 | 32 | ```bash 33 | python backproject.py --help 34 | ``` 35 | 36 | **Segmentation** 37 | 38 | Once backprojection is complete, run the following to see the segmentation results. 39 | 40 | ```bash 41 | python segment.py --help 42 | ``` 43 | 44 | 45 | Trained Mip-NeRF 360 Gaussian splat models (using [gsplat](https://github.com/nerfstudio-project/gsplat) with data factor = 4) can be found [here](https://drive.google.com/file/d/1ZCTgAE6vZOeUBdR3qPXdSPY01QQBHxeO/view?usp=sharing). Extract them to the `data` folder. 46 | 47 | 48 | **Application - Click and Segment** 49 | 50 | 51 | 52 | https://github.com/user-attachments/assets/3f1c797f-db29-416f-8917-9be7885231b5 53 | 54 | 55 | 56 | ```bash 57 | python click_and_segment.py 58 | ``` 59 | 60 | Click the left mouse button to select positive visual prompts and the middle button to select negative visual prompts. Use `ctrl+lbutton` and `ctrl+mbutton` to remove selected prompts. 61 | 62 | **Application - Editing with LLM** 63 | 64 | ```bash 65 | python viewer_with_llm.py --checkpoint data/garden/ckpts/ckpt_29999_rank0.pt --data-dir data/garden --lseg-checkpoint results/garden/features_lseg.pt 66 | ``` 67 | 68 | Press the backtick key (`) to start prompting. At present, only a single query is supported at a time. Queries can change the view, segment an object, or change its color. 69 | 70 | 71 | 72 | https://github.com/user-attachments/assets/126583ab-1f6f-4cc3-ab60-21453a7f3f5a 73 | 74 | 75 | 76 | ## Acknowledgements 77 | 78 | A big thanks to the following tools/libraries, which were instrumental in this project: 79 | 80 | - [gsplat](https://github.com/nerfstudio-project/gsplat): 3DGS rasterizer. 81 | - [LSeg](https://github.com/isl-org/lang-seg) and [LSeg Minimal](https://github.com/krrish94/lseg-minimal): to generate the features to be backprojected.
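**Programmatic usage (sketch)**

If you want to query the back-projected features outside the provided scripts, the core text-query step looks like the snippet below. It is condensed from `segment.py` (shown further below in this listing); the paths are the repo defaults and the prompts are just examples.

```python
import clip
import torch
from lseg import LSegNet

# Per-Gaussian LSeg features produced by the backprojection step
feats = torch.nn.functional.normalize(torch.load("results/garden/features_lseg.pt"), dim=1)

net = LSegNet(backbone="clip_vitl16_384", features=256, crop_size=480,
              arch_option=0, block_depth=0, activation="lrelu")
net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt"))
net.eval().cuda()

# One positive prompt followed by negative prompts, as in segment.py
tokens = clip.tokenize(["Table", "Vase", "Other"]).cuda()
text_feat = torch.nn.functional.normalize(net.clip_pretrained.encode_text(tokens), dim=1)

score = feats @ text_feat.float().T                 # (N, 3) cosine similarities
mask3d = score[:, 0] > score[:, 1:].max(dim=1)[0]   # positive prompt beats every negative
```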
82 | 83 | 84 | ## Citation 85 | If you find this paper or the code helpful for your work, please consider citing our work, 86 | ``` 87 | @misc{joseph2024gradientweightedfeaturebackprojectionfast, 88 | title={Gradient-Weighted Feature Back-Projection: A Fast Alternative to Feature Distillation in 3D Gaussian Splatting}, 89 | author={Joji Joseph and Bharadwaj Amrutur and Shalabh Bhatnagar}, 90 | year={2024}, 91 | eprint={2411.15193}, 92 | archivePrefix={arXiv}, 93 | primaryClass={cs.CV}, 94 | url={https://arxiv.org/abs/2411.15193}, 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /f3dgs/datasets/normalize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): 5 | """ 6 | reference: nerf-factory 7 | Get a similarity transform to normalize dataset 8 | from c2w (OpenCV convention) cameras 9 | :param c2w: (N, 4) 10 | :return T (4,4) , scale (float) 11 | """ 12 | t = c2w[:, :3, 3] 13 | R = c2w[:, :3, :3] 14 | 15 | # (1) Rotate the world so that z+ is the up axis 16 | # we estimate the up axis by averaging the camera up axes 17 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 18 | world_up = np.mean(ups, axis=0) 19 | world_up /= np.linalg.norm(world_up) 20 | 21 | up_camspace = np.array([0.0, -1.0, 0.0]) 22 | c = (up_camspace * world_up).sum() 23 | cross = np.cross(world_up, up_camspace) 24 | skew = np.array( 25 | [ 26 | [0.0, -cross[2], cross[1]], 27 | [cross[2], 0.0, -cross[0]], 28 | [-cross[1], cross[0], 0.0], 29 | ] 30 | ) 31 | if c > -1: 32 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 33 | else: 34 | # In the unlikely case the original data has y+ up axis, 35 | # rotate 180-deg about x axis 36 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 37 | 38 | # R_align = np.eye(3) # DEBUG 39 | R = R_align @ R 40 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 41 | t = (R_align @ t[..., None])[..., 0] 42 | 43 | # (2) Recenter the scene. 44 | if center_method == "focus": 45 | # find the closest point to the origin for each camera's center ray 46 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 47 | translate = -np.median(nearest, axis=0) 48 | elif center_method == "poses": 49 | # use center of the camera positions 50 | translate = -np.median(t, axis=0) 51 | else: 52 | raise ValueError(f"Unknown center_method {center_method}") 53 | 54 | transform = np.eye(4) 55 | transform[:3, 3] = translate 56 | transform[:3, :3] = R_align 57 | 58 | # (3) Rescale the scene using camera distances 59 | scale_fn = np.max if strict_scaling else np.median 60 | scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1)) 61 | transform[:3, :] *= scale 62 | 63 | return transform 64 | 65 | 66 | def align_principle_axes(point_cloud): 67 | # Compute centroid 68 | centroid = np.median(point_cloud, axis=0) 69 | 70 | # Translate point cloud to centroid 71 | translated_point_cloud = point_cloud - centroid 72 | 73 | # Compute covariance matrix 74 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False) 75 | 76 | # Compute eigenvectors and eigenvalues 77 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) 78 | 79 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis 80 | # is the principal axis with the smallest eigenvalue. 
81 | sort_indices = eigenvalues.argsort()[::-1] 82 | eigenvectors = eigenvectors[:, sort_indices] 83 | 84 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is 85 | # negative, then we need to flip the sign of one of the eigenvectors. 86 | if np.linalg.det(eigenvectors) < 0: 87 | eigenvectors[:, 0] *= -1 88 | 89 | # Create rotation matrix 90 | rotation_matrix = eigenvectors.T 91 | 92 | # Create SE(3) matrix (4x4 transformation matrix) 93 | transform = np.eye(4) 94 | transform[:3, :3] = rotation_matrix 95 | transform[:3, 3] = -rotation_matrix @ centroid 96 | 97 | return transform 98 | 99 | 100 | def transform_points(matrix, points): 101 | """Transform points using an SE(3) matrix. 102 | 103 | Args: 104 | matrix: 4x4 SE(3) matrix 105 | points: Nx3 array of points 106 | 107 | Returns: 108 | Nx3 array of transformed points 109 | """ 110 | assert matrix.shape == (4, 4) 111 | assert len(points.shape) == 2 and points.shape[1] == 3 112 | return points @ matrix[:3, :3].T + matrix[:3, 3] 113 | 114 | 115 | def transform_cameras(matrix, camtoworlds): 116 | """Transform cameras using an SE(3) matrix. 117 | 118 | Args: 119 | matrix: 4x4 SE(3) matrix 120 | camtoworlds: Nx4x4 array of camera-to-world matrices 121 | 122 | Returns: 123 | Nx4x4 array of transformed camera-to-world matrices 124 | """ 125 | assert matrix.shape == (4, 4) 126 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) 127 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) 128 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) 129 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] 130 | return camtoworlds 131 | 132 | 133 | def normalize(camtoworlds, points=None): 134 | T1 = similarity_from_cameras(camtoworlds) 135 | camtoworlds = transform_cameras(T1, camtoworlds) 136 | if points is not None: 137 | points = transform_points(T1, points) 138 | T2 = align_principle_axes(points) 139 | camtoworlds = transform_cameras(T2, camtoworlds) 140 | points = transform_points(T2, points) 141 | return camtoworlds, points, T2 @ T1 142 | else: 143 | return camtoworlds, T1 144 | -------------------------------------------------------------------------------- /visualize_pca.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Literal 3 | import tyro 4 | import os 5 | import torch 6 | import cv2 7 | import imageio # To generate gifs 8 | import pycolmap_scene_manager as pycolmap 9 | from gsplat import rasterization 10 | import numpy as np 11 | import clip 12 | import matplotlib 13 | from sklearn.decomposition import PCA 14 | 15 | matplotlib.use("TkAgg") 16 | 17 | from lseg import LSegNet 18 | from utils import ( 19 | prune_by_gradients, 20 | test_proper_pruning, 21 | get_viewmat_from_colmap_image, 22 | load_checkpoint, 23 | torch_to_cv, 24 | ) 25 | 26 | 27 | def render_pca( 28 | splats, 29 | features, 30 | output_path, 31 | pca_on_gaussians=True, 32 | scale=1.0, 33 | feedback=True, 34 | ): 35 | if feedback: 36 | cv2.destroyAllWindows() 37 | cv2.namedWindow("PCA", cv2.WINDOW_NORMAL) 38 | frames = [] 39 | means = splats["means"] 40 | colors_dc = splats["features_dc"] 41 | colors_rest = splats["features_rest"] 42 | colors = torch.cat([colors_dc, colors_rest], dim=1) 43 | opacities = torch.sigmoid(splats["opacity"]) 44 | scales = torch.exp(splats["scaling"]) 45 | quats = splats["rotation"] 46 | K = splats["camera_matrix"] 47 | aux_dir = output_path + ".images" 48 | os.makedirs(aux_dir, 
exist_ok=True) 49 | 50 | pca = PCA(n_components=3) 51 | features_pca = pca.fit_transform(features.detach().cpu().numpy()) 52 | feats_min = np.min(features_pca, axis=(0, 1)) 53 | feats_max = np.max(features_pca, axis=(0, 1)) 54 | features_pca = (features_pca - feats_min) / (feats_max - feats_min) 55 | features_pca = torch.tensor(features_pca).float().cuda() 56 | if pca_on_gaussians: 57 | for image in sorted( 58 | splats["colmap_project"].images.values(), key=lambda x: x.name 59 | ): 60 | viewmat = get_viewmat_from_colmap_image(image) 61 | features_rendered, alphas, meta = rasterization( 62 | means, 63 | quats, 64 | scales * scale, 65 | opacities, 66 | features_pca, 67 | viewmats=viewmat[None], 68 | Ks=K[None], 69 | width=K[0, 2] * 2, 70 | height=K[1, 2] * 2, 71 | # sh_degree=3, 72 | ) 73 | features_rendered = features_rendered[0] 74 | frame = torch_to_cv(features_rendered) 75 | frame = np.clip(frame, 0, 255).astype(np.uint8) 76 | frames.append(frame) 77 | if feedback: 78 | cv2.imshow("PCA", frame[..., ::-1]) 79 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 80 | cv2.waitKey(1) 81 | else: 82 | for image in sorted( 83 | splats["colmap_project"].images.values(), key=lambda x: x.name 84 | ): 85 | viewmat = get_viewmat_from_colmap_image(image) 86 | features_rendered, alphas, meta = rasterization( 87 | means, 88 | quats, 89 | scales * scale, 90 | opacities, 91 | features, 92 | viewmats=viewmat[None], 93 | Ks=K[None], 94 | width=K[0, 2] * 2, 95 | height=K[1, 2] * 2, 96 | # sh_degree=3, 97 | ) 98 | features_rendered = features_rendered[0] 99 | h, w, c = features_rendered.shape 100 | features_rendered = ( 101 | features_rendered.reshape(h * w, c).detach().cpu().numpy() 102 | ) 103 | features_rendered = pca.transform(features_rendered) 104 | features_rendered = features_rendered.reshape(h, w, 3) 105 | features_rendered = (features_rendered - feats_min) / ( 106 | feats_max - feats_min 107 | ) 108 | frame = (features_rendered * 255).astype(np.uint8) 109 | frames.append(frame[..., ::-1]) 110 | if feedback: 111 | cv2.imshow("PCA", frame) 112 | cv2.imwrite(f"{aux_dir}/{image.name}", frame) 113 | cv2.waitKey(1) 114 | imageio.mimsave(output_path, frames, fps=10, loop=0) 115 | if feedback: 116 | cv2.destroyAllWindows() 117 | 118 | 119 | def main( 120 | data_dir: str = "./data/garden", # colmap path 121 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 122 | results_dir: str = "./results/garden", # output path 123 | rasterizer: Literal[ 124 | "inria", "gsplat" 125 | ] = "gsplat", # Original or gsplat for checkpoints 126 | data_factor: int = 4, 127 | show_visual_feedback: bool = True, 128 | feature: Literal["lseg", "dino"] = "lseg", 129 | ): 130 | 131 | if not torch.cuda.is_available(): 132 | raise RuntimeError("CUDA is required for this demo") 133 | 134 | torch.set_default_device("cuda") 135 | 136 | os.makedirs(results_dir, exist_ok=True) 137 | splats = load_checkpoint( 138 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 139 | ) 140 | splats_optimized = prune_by_gradients(splats) 141 | test_proper_pruning(splats, splats_optimized) 142 | splats = splats_optimized 143 | if feature == "lseg": 144 | features = torch.load(f"{results_dir}/features_lseg.pt") 145 | elif feature == "dino": 146 | features = torch.load(f"{results_dir}/features_dino.pt") 147 | 148 | render_pca( 149 | splats, 150 | features, 151 | f"{results_dir}/pca_gaussians_{feature}.gif", 152 | pca_on_gaussians=True, 153 | scale=0.20, 154 | 
feedback=show_visual_feedback, 155 | ) 156 | 157 | render_pca( 158 | splats, 159 | features, 160 | f"{results_dir}/pca_renderings_{feature}.gif", 161 | pca_on_gaussians=False, 162 | feedback=show_visual_feedback, 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | tyro.cli(main) 168 | -------------------------------------------------------------------------------- /visualize_pca_f3dgs.py: -------------------------------------------------------------------------------- 1 | # TODO: Clean up 2 | from copy import deepcopy 3 | from typing import Literal 4 | import tyro 5 | import os 6 | import torch 7 | import cv2 8 | import imageio # To generate gifs 9 | import pycolmap_scene_manager as pycolmap 10 | from gsplat import rasterization 11 | import numpy as np 12 | import clip 13 | import matplotlib 14 | from sklearn.decomposition import PCA 15 | 16 | matplotlib.use("TkAgg") 17 | 18 | from lseg import LSegNet 19 | from utils import ( 20 | prune_by_gradients, 21 | test_proper_pruning, 22 | get_viewmat_from_colmap_image, 23 | load_checkpoint, 24 | load_checkpoint_f3dgs, 25 | torch_to_cv, 26 | ) 27 | 28 | 29 | def render_pca_f3dgs( 30 | splats, 31 | # features, 32 | output_path, 33 | pca_on_gaussians=True, 34 | scale=1.0, 35 | feedback=True, 36 | ): 37 | if feedback: 38 | cv2.destroyAllWindows() 39 | cv2.namedWindow("PCA", cv2.WINDOW_NORMAL) 40 | frames = [] 41 | means = splats["means"] 42 | colors_dc = splats["features_dc"] 43 | colors_rest = splats["features_rest"] 44 | colors = torch.cat([colors_dc, colors_rest], dim=1) 45 | opacities = torch.sigmoid(splats["opacity"]) 46 | scales = torch.exp(splats["scaling"]) 47 | features = splats["features"] 48 | conv = splats["conv"] 49 | quats = splats["rotation"] 50 | K = splats["camera_matrix"] 51 | aux_dir = output_path + ".images" 52 | os.makedirs(aux_dir, exist_ok=True) 53 | 54 | features_expanded = features @ conv 55 | 56 | features_expanded = torch.nn.functional.normalize(features_expanded, dim=1) 57 | 58 | pca = PCA(n_components=3) 59 | features_pca = pca.fit_transform((features_expanded).detach().cpu().numpy()) 60 | feats_min = np.min(features_pca, axis=(0, 1)) 61 | feats_max = np.max(features_pca, axis=(0, 1)) 62 | features_pca = (features_pca - feats_min) / (feats_max - feats_min) 63 | features_pca = torch.tensor(features_pca).float().cuda() 64 | if pca_on_gaussians: 65 | for image in sorted( 66 | splats["colmap_project"].images.values(), key=lambda x: x.name 67 | ): 68 | viewmat = get_viewmat_from_colmap_image(image) 69 | print(means.shape, features_pca.shape) 70 | features_rendered, alphas, meta = rasterization( 71 | means, 72 | quats, 73 | scales * scale, 74 | opacities, 75 | features_pca, 76 | viewmats=viewmat[None], 77 | Ks=K[None], 78 | width=K[0, 2] * 2, 79 | height=K[1, 2] * 2, 80 | # sh_degree=3, 81 | ) 82 | features_rendered = features_rendered[0] 83 | frame = torch_to_cv(features_rendered) 84 | frame = np.clip(frame, 0, 255).astype(np.uint8) 85 | frames.append(frame) 86 | if feedback: 87 | cv2.imshow("PCA", frame[..., ::-1]) 88 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 89 | cv2.waitKey(1) 90 | else: 91 | for image in sorted( 92 | splats["colmap_project"].images.values(), key=lambda x: x.name 93 | ): 94 | viewmat = get_viewmat_from_colmap_image(image) 95 | features_rendered, alphas, meta = rasterization( 96 | means, 97 | quats, 98 | scales * scale, 99 | opacities, 100 | features, 101 | viewmats=viewmat[None], 102 | Ks=K[None], 103 | width=K[0, 2] * 2, 104 | height=K[1, 2] * 2, 105 | # sh_degree=3, 106 | ) 107 | 
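# The feature image rendered just above is still in the compressed per-Gaussian
# feature space; the lines below expand it to the full embedding space with
# `conv` and reuse the PCA basis (and the feats_min/feats_max range) fitted
# earlier on the expanded per-Gaussian features, so the per-frame colors stay
# comparable with the Gaussian-space PCA renderings.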
features_rendered = features_rendered[0] @ conv 108 | h, w, c = features_rendered.shape 109 | features_rendered = ( 110 | features_rendered.reshape(h * w, c).detach().cpu().numpy() 111 | ) 112 | features_rendered = pca.transform(features_rendered) 113 | features_rendered = features_rendered.reshape(h, w, 3) 114 | features_rendered = (features_rendered - feats_min) / ( 115 | feats_max - feats_min 116 | ) 117 | frame = (features_rendered * 255).astype(np.uint8) 118 | frames.append(frame[..., ::-1]) 119 | if feedback: 120 | cv2.imshow("PCA", frame) 121 | cv2.imwrite(f"{aux_dir}/{image.name}", frame) 122 | cv2.waitKey(1) 123 | imageio.mimsave(output_path, frames, fps=10, loop=0) 124 | if feedback: 125 | cv2.destroyAllWindows() 126 | 127 | 128 | def main( 129 | data_dir: str = "./data/garden", # colmap path 130 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 131 | results_dir: str = "./results/garden", # output path 132 | rasterizer: Literal[ 133 | "inria", "gsplat" 134 | ] = "gsplat", # Original or gsplat for checkpoints 135 | data_factor: int = 4, 136 | show_visual_feedback: bool = True, 137 | feature: Literal["lseg", "dino"] = "lseg", 138 | ): 139 | 140 | if not torch.cuda.is_available(): 141 | raise RuntimeError("CUDA is required for this demo") 142 | 143 | torch.set_default_device("cuda") 144 | 145 | os.makedirs(results_dir, exist_ok=True) 146 | splats = load_checkpoint_f3dgs( 147 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 148 | ) 149 | splats_optimized = prune_by_gradients(splats) 150 | test_proper_pruning(splats, splats_optimized) 151 | splats = splats_optimized 152 | 153 | render_pca_f3dgs( 154 | splats, 155 | f"{results_dir}/pca_f3dgs_gaussians_{feature}.gif", 156 | pca_on_gaussians=True, 157 | scale=0.2, 158 | feedback=show_visual_feedback, 159 | ) 160 | 161 | render_pca_f3dgs( 162 | splats, 163 | f"{results_dir}/pca_f3dgs_frames_{feature}.gif", 164 | pca_on_gaussians=False, 165 | feedback=show_visual_feedback, 166 | ) 167 | 168 | 169 | if __name__ == "__main__": 170 | tyro.cli(main) 171 | -------------------------------------------------------------------------------- /backproject_compressed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | from typing import Literal 5 | import torch 6 | import tyro 7 | from gsplat import rasterization 8 | import matplotlib 9 | 10 | matplotlib.use("TkAgg") # To avoid conflict with cv2 11 | from tqdm import tqdm 12 | from lseg import LSegNet 13 | import torch.nn as nn 14 | 15 | from utils import ( 16 | get_viewmat_from_colmap_image, 17 | load_checkpoint, 18 | prune_by_gradients, 19 | test_proper_pruning, 20 | ) 21 | 22 | 23 | class EncoderDecoder(nn.Module): 24 | def __init__(self): 25 | super(EncoderDecoder, self).__init__() 26 | self.encoder = nn.Parameter(torch.randn(512, 16)) 27 | self.decoder = nn.Parameter(torch.randn(16, 512)) 28 | 29 | def forward(self, x): 30 | x = x @ self.encoder 31 | y = x @ self.decoder 32 | return x, y 33 | 34 | 35 | encoder_decoder = EncoderDecoder().to("cuda") 36 | encoder_decoder.load_state_dict(torch.load("./encoder_decoder.ckpt")) 37 | 38 | 39 | def create_feature_field_lseg(splats, batch_size=1, use_cpu=False): 40 | device = "cpu" if use_cpu else "cuda" 41 | 42 | net = LSegNet( 43 | backbone="clip_vitl16_384", 44 | features=256, 45 | crop_size=480, 46 | arch_option=0, 47 | block_depth=0, 48 | activation="lrelu", 49 | ) 50 | # Load 
pre-trained weights 51 | net.load_state_dict( 52 | torch.load("./checkpoints/lseg_minimal_e200.ckpt", map_location=device) 53 | ) 54 | net.eval() 55 | net.to(device) 56 | 57 | means = splats["means"] 58 | colors_dc = splats["features_dc"] 59 | colors_rest = splats["features_rest"] 60 | colors_all = torch.cat([colors_dc, colors_rest], dim=1) 61 | 62 | colors = colors_dc[:, 0, :] # * 0 63 | colors_0 = colors_dc[:, 0, :] * 0 64 | colors.to(device) 65 | colors_0.to(device) 66 | 67 | colmap_project = splats["colmap_project"] 68 | 69 | opacities = torch.sigmoid(splats["opacity"]) 70 | scales = torch.exp(splats["scaling"]) 71 | quats = splats["rotation"] 72 | K = splats["camera_matrix"] 73 | colors.requires_grad = True 74 | colors_0.requires_grad = True 75 | 76 | gaussian_features = torch.zeros(colors.shape[0], 16, device=colors.device) 77 | gaussian_denoms = torch.ones(colors.shape[0], device=colors.device) * 1e-12 78 | 79 | t1 = time.time() 80 | 81 | colors_feats = torch.zeros( 82 | colors.shape[0], 16, device=colors.device, requires_grad=True 83 | ) 84 | colors_feats_0 = torch.zeros( 85 | colors.shape[0], 3, device=colors.device, requires_grad=True 86 | ) 87 | 88 | images = sorted(colmap_project.images.values(), key=lambda x: x.name) 89 | 90 | for batch_start in tqdm( 91 | range(0, len(images), batch_size), 92 | desc="Feature backprojection (batches)", 93 | ): 94 | batch = images[batch_start : batch_start + batch_size] 95 | for image in batch: 96 | viewmat = get_viewmat_from_colmap_image(image) 97 | 98 | width = int(K[0, 2] * 2) 99 | height = int(K[1, 2] * 2) 100 | 101 | with torch.no_grad(): 102 | output, _, meta = rasterization( 103 | means, 104 | quats, 105 | scales, 106 | opacities, 107 | colors_all, 108 | viewmat[None], 109 | K[None], 110 | width=width, 111 | height=height, 112 | sh_degree=3, 113 | ) 114 | 115 | output = torch.nn.functional.interpolate( 116 | output.permute(0, 3, 1, 2).to(device), 117 | size=(480, 480), 118 | mode="bilinear", 119 | ) 120 | output.to(device) 121 | feats = net.forward(output) 122 | feats = torch.nn.functional.normalize(feats, dim=1) 123 | feats = torch.nn.functional.interpolate( 124 | feats, size=(height, width), mode="bilinear" 125 | )[0] 126 | feats = feats.permute(1, 2, 0) 127 | feats = feats @ encoder_decoder.encoder # 512 -> 16 128 | 129 | output_for_grad, _, meta = rasterization( 130 | means, 131 | quats, 132 | scales, 133 | opacities, 134 | colors_feats, 135 | viewmat[None], 136 | K[None], 137 | width=width, 138 | height=height, 139 | ) 140 | 141 | target = (output_for_grad[0].to(device) * feats).sum() 142 | target.to(device) 143 | target.backward() 144 | colors_feats_copy = colors_feats.grad.clone() 145 | colors_feats.grad.zero_() 146 | 147 | output_for_grad, _, meta = rasterization( 148 | means, 149 | quats, 150 | scales, 151 | opacities, 152 | colors_feats_0, 153 | viewmat[None], 154 | K[None], 155 | width=width, 156 | height=height, 157 | ) 158 | 159 | target_0 = (output_for_grad[0]).sum() 160 | target_0.to(device) 161 | target_0.backward() 162 | 163 | gaussian_features += colors_feats_copy 164 | gaussian_denoms += colors_feats_0.grad[:, 0] 165 | colors_feats_0.grad.zero_() 166 | 167 | # Clean up unused variables and free GPU memory 168 | del ( 169 | viewmat, 170 | meta, 171 | _, 172 | output, 173 | feats, 174 | output_for_grad, 175 | colors_feats_copy, 176 | target, 177 | target_0, 178 | ) 179 | torch.cuda.empty_cache() 180 | gaussian_features = gaussian_features / gaussian_denoms[..., None] 181 | gaussian_features = gaussian_features / 
gaussian_features.norm(dim=-1, keepdim=True) 182 | # Replace nan values with 0 183 | gaussian_features[torch.isnan(gaussian_features)] = 0 184 | t2 = time.time() 185 | print("Time taken for feature backprojection", t2 - t1) 186 | return gaussian_features 187 | 188 | 189 | def main( 190 | data_dir: str = "./data/garden", # colmap path 191 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 192 | results_dir: str = "./results/garden", # output path 193 | rasterizer: Literal[ 194 | "inria", "gsplat" 195 | ] = "gsplat", # Original or GSplat for checkpoints 196 | data_factor: int = 4, 197 | feature_field_batch_count: int = 1, # Number of batches to process for feature field 198 | run_feature_field_on_cpu: bool = False, # Run feature field on CPU 199 | ): 200 | 201 | if not torch.cuda.is_available(): 202 | raise RuntimeError("CUDA is required for this demo") 203 | 204 | torch.set_default_device("cuda") 205 | 206 | os.makedirs(results_dir, exist_ok=True) 207 | splats = load_checkpoint( 208 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 209 | ) 210 | splats_optimized = prune_by_gradients(splats) 211 | test_proper_pruning(splats, splats_optimized) 212 | splats = splats_optimized 213 | features = create_feature_field_lseg( 214 | splats, feature_field_batch_count, run_feature_field_on_cpu 215 | ) 216 | print(features.shape) 217 | torch.save(features, f"{results_dir}/features_lseg_compressed.pt") 218 | 219 | 220 | if __name__ == "__main__": 221 | tyro.cli(main) 222 | -------------------------------------------------------------------------------- /f3dgs/utils_simple_trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from sklearn.neighbors import NearestNeighbors 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | import matplotlib.pyplot as plt 9 | from matplotlib import colormaps 10 | 11 | 12 | class CameraOptModule(torch.nn.Module): 13 | """Camera pose optimization module.""" 14 | 15 | def __init__(self, n: int): 16 | super().__init__() 17 | # Delta positions (3D) + Delta rotations (6D) 18 | self.embeds = torch.nn.Embedding(n, 9) 19 | # Identity rotation in 6D representation 20 | self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])) 21 | 22 | def zero_init(self): 23 | torch.nn.init.zeros_(self.embeds.weight) 24 | 25 | def random_init(self, std: float): 26 | torch.nn.init.normal_(self.embeds.weight, std=std) 27 | 28 | def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor: 29 | """Adjust camera pose based on deltas. 
30 | 31 | Args: 32 | camtoworlds: (..., 4, 4) 33 | embed_ids: (...,) 34 | 35 | Returns: 36 | updated camtoworlds: (..., 4, 4) 37 | """ 38 | assert camtoworlds.shape[:-2] == embed_ids.shape 39 | batch_dims = camtoworlds.shape[:-2] 40 | pose_deltas = self.embeds(embed_ids) # (..., 9) 41 | dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:] 42 | rot = rotation_6d_to_matrix( 43 | drot + self.identity.expand(*batch_dims, -1) 44 | ) # (..., 3, 3) 45 | transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_dims, 1, 1)) 46 | transform[..., :3, :3] = rot 47 | transform[..., :3, 3] = dx 48 | return torch.matmul(camtoworlds, transform) 49 | 50 | 51 | class AppearanceOptModule(torch.nn.Module): 52 | """Appearance optimization module.""" 53 | 54 | def __init__( 55 | self, 56 | n: int, 57 | feature_dim: int, 58 | embed_dim: int = 16, 59 | sh_degree: int = 3, 60 | mlp_width: int = 64, 61 | mlp_depth: int = 2, 62 | ): 63 | super().__init__() 64 | self.embed_dim = embed_dim 65 | self.sh_degree = sh_degree 66 | self.embeds = torch.nn.Embedding(n, embed_dim) 67 | layers = [] 68 | layers.append( 69 | torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width) 70 | ) 71 | layers.append(torch.nn.ReLU(inplace=True)) 72 | for _ in range(mlp_depth - 1): 73 | layers.append(torch.nn.Linear(mlp_width, mlp_width)) 74 | layers.append(torch.nn.ReLU(inplace=True)) 75 | layers.append(torch.nn.Linear(mlp_width, 3)) 76 | self.color_head = torch.nn.Sequential(*layers) 77 | 78 | def forward( 79 | self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int 80 | ) -> Tensor: 81 | """Adjust appearance based on embeddings. 82 | 83 | Args: 84 | features: (N, feature_dim) 85 | embed_ids: (C,) 86 | dirs: (C, N, 3) 87 | 88 | Returns: 89 | colors: (C, N, 3) 90 | """ 91 | from gsplat.cuda._torch_impl import _eval_sh_bases_fast 92 | 93 | C, N = dirs.shape[:2] 94 | # Camera embeddings 95 | if embed_ids is None: 96 | embeds = torch.zeros(C, self.embed_dim, device=features.device) 97 | else: 98 | embeds = self.embeds(embed_ids) # [C, D2] 99 | embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2] 100 | # GS features 101 | features = features[None, :, :].expand(C, -1, -1) # [C, N, D1] 102 | # View directions 103 | dirs = F.normalize(dirs, dim=-1) # [C, N, 3] 104 | num_bases_to_use = (sh_degree + 1) ** 2 105 | num_bases = (self.sh_degree + 1) ** 2 106 | sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K] 107 | sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs) 108 | # Get colors 109 | if self.embed_dim > 0: 110 | h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K] 111 | else: 112 | h = torch.cat([features, sh_bases], dim=-1) 113 | colors = self.color_head(h) 114 | return colors 115 | 116 | 117 | def rotation_6d_to_matrix(d6: Tensor) -> Tensor: 118 | """ 119 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 120 | using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d. 121 | Args: 122 | d6: 6D rotation representation, of size (*, 6) 123 | 124 | Returns: 125 | batch of rotation matrices of size (*, 3, 3) 126 | 127 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 128 | On the Continuity of Rotation Representations in Neural Networks. 129 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
130 | Retrieved from http://arxiv.org/abs/1812.07035 131 | """ 132 | 133 | a1, a2 = d6[..., :3], d6[..., 3:] 134 | b1 = F.normalize(a1, dim=-1) 135 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 136 | b2 = F.normalize(b2, dim=-1) 137 | b3 = torch.cross(b1, b2, dim=-1) 138 | return torch.stack((b1, b2, b3), dim=-2) 139 | 140 | 141 | def knn(x: Tensor, K: int = 4) -> Tensor: 142 | x_np = x.cpu().numpy() 143 | model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np) 144 | distances, _ = model.kneighbors(x_np) 145 | return torch.from_numpy(distances).to(x) 146 | 147 | 148 | def rgb_to_sh(rgb: Tensor) -> Tensor: 149 | C0 = 0.28209479177387814 150 | return (rgb - 0.5) / C0 151 | 152 | 153 | def set_random_seed(seed: int): 154 | random.seed(seed) 155 | np.random.seed(seed) 156 | torch.manual_seed(seed) 157 | 158 | 159 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163 160 | def colormap(img, cmap="jet"): 161 | W, H = img.shape[:2] 162 | dpi = 300 163 | fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi) 164 | im = ax.imshow(img, cmap=cmap) 165 | ax.set_axis_off() 166 | fig.colorbar(im, ax=ax) 167 | fig.tight_layout() 168 | fig.canvas.draw() 169 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 170 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 171 | img = torch.from_numpy(data).float().permute(2, 0, 1) 172 | plt.close() 173 | return img 174 | 175 | 176 | def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor: 177 | """Convert single channel to a color img. 178 | 179 | Args: 180 | img (torch.Tensor): (..., 1) float32 single channel image. 181 | colormap (str): Colormap for img. 182 | 183 | Returns: 184 | (..., 3) colored img with colors in [0, 1]. 185 | """ 186 | img = torch.nan_to_num(img, 0) 187 | if colormap == "gray": 188 | return img.repeat(1, 1, 3) 189 | img_long = (img * 255).long() 190 | img_long_min = torch.min(img_long) 191 | img_long_max = torch.max(img_long) 192 | assert img_long_min >= 0, f"the min value is {img_long_min}" 193 | assert img_long_max <= 255, f"the max value is {img_long_max}" 194 | return torch.tensor( 195 | colormaps[colormap].colors, # type: ignore 196 | device=img.device, 197 | )[img_long[..., 0]] 198 | 199 | 200 | def apply_depth_colormap( 201 | depth: torch.Tensor, 202 | acc: torch.Tensor = None, 203 | near_plane: float = None, 204 | far_plane: float = None, 205 | ) -> torch.Tensor: 206 | """Converts a depth image to color for easier analysis. 207 | 208 | Args: 209 | depth (torch.Tensor): (..., 1) float32 depth. 210 | acc (torch.Tensor | None): (..., 1) optional accumulation mask. 211 | near_plane: Closest depth to consider. If None, use min image value. 212 | far_plane: Furthest depth to consider. If None, use max image value. 213 | 214 | Returns: 215 | (..., 3) colored depth image with colors in [0, 1]. 
216 | """ 217 | near_plane = near_plane or float(torch.min(depth)) 218 | far_plane = far_plane or float(torch.max(depth)) 219 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) 220 | depth = torch.clip(depth, 0.0, 1.0) 221 | img = apply_float_colormap(depth, colormap="turbo") 222 | if acc is not None: 223 | img = img * acc + (1.0 - acc) 224 | return img -------------------------------------------------------------------------------- /f3dgs/datasets/traj.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code borrowed from 3 | 4 | https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py 5 | """ 6 | 7 | import numpy as np 8 | import scipy 9 | 10 | 11 | def normalize(x: np.ndarray) -> np.ndarray: 12 | """Normalization helper function.""" 13 | return x / np.linalg.norm(x) 14 | 15 | 16 | def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray: 17 | """Construct lookat view matrix.""" 18 | vec2 = normalize(lookdir) 19 | vec0 = normalize(np.cross(up, vec2)) 20 | vec1 = normalize(np.cross(vec2, vec0)) 21 | m = np.stack([vec0, vec1, vec2, position], axis=1) 22 | return m 23 | 24 | 25 | def focus_point_fn(poses: np.ndarray) -> np.ndarray: 26 | """Calculate nearest point to all focal axes in poses.""" 27 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 28 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 29 | mt_m = np.transpose(m, [0, 2, 1]) @ m 30 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 31 | return focus_pt 32 | 33 | 34 | def generate_ellipse_path_z( 35 | poses: np.ndarray, 36 | n_frames: int = 120, 37 | # const_speed: bool = True, 38 | variation: float = 0.0, 39 | phase: float = 0.0, 40 | height: float = 0.0, 41 | ) -> np.ndarray: 42 | """Generate an elliptical render path based on the given poses.""" 43 | # Calculate the focal point for the path (cameras point toward this). 44 | center = focus_point_fn(poses) 45 | # Path height sits at z=height (in middle of zero-mean capture pattern). 46 | offset = np.array([center[0], center[1], height]) 47 | 48 | # Calculate scaling for ellipse axes based on input camera positions. 49 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 50 | # Use ellipse that is symmetric about the focal point in xy. 51 | low = -sc + offset 52 | high = sc + offset 53 | # Optional height variation need not be symmetric 54 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 55 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 56 | 57 | def get_positions(theta): 58 | # Interpolate between bounds with trig functions to get ellipse in x-y. 59 | # Optionally also interpolate in z to change camera height along path. 60 | return np.stack( 61 | [ 62 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), 63 | low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5), 64 | variation 65 | * ( 66 | z_low[2] 67 | + (z_high - z_low)[2] 68 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) 69 | ) 70 | + height, 71 | ], 72 | -1, 73 | ) 74 | 75 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) 76 | positions = get_positions(theta) 77 | 78 | # if const_speed: 79 | # # Resample theta angles so that the velocity is closer to constant. 
80 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 81 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 82 | # positions = get_positions(theta) 83 | 84 | # Throw away duplicated last position. 85 | positions = positions[:-1] 86 | 87 | # Set path's up vector to axis closest to average of input pose up vectors. 88 | avg_up = poses[:, :3, 1].mean(0) 89 | avg_up = avg_up / np.linalg.norm(avg_up) 90 | ind_up = np.argmax(np.abs(avg_up)) 91 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 92 | 93 | return np.stack([viewmatrix(p - center, up, p) for p in positions]) 94 | 95 | 96 | def generate_ellipse_path_y( 97 | poses: np.ndarray, 98 | n_frames: int = 120, 99 | # const_speed: bool = True, 100 | variation: float = 0.0, 101 | phase: float = 0.0, 102 | height: float = 0.0, 103 | ) -> np.ndarray: 104 | """Generate an elliptical render path based on the given poses.""" 105 | # Calculate the focal point for the path (cameras point toward this). 106 | center = focus_point_fn(poses) 107 | # Path height sits at y=height (in middle of zero-mean capture pattern). 108 | offset = np.array([center[0], height, center[2]]) 109 | 110 | # Calculate scaling for ellipse axes based on input camera positions. 111 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 112 | # Use ellipse that is symmetric about the focal point in xy. 113 | low = -sc + offset 114 | high = sc + offset 115 | # Optional height variation need not be symmetric 116 | y_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 117 | y_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 118 | 119 | def get_positions(theta): 120 | # Interpolate between bounds with trig functions to get ellipse in x-z. 121 | # Optionally also interpolate in y to change camera height along path. 122 | return np.stack( 123 | [ 124 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), 125 | variation 126 | * ( 127 | y_low[1] 128 | + (y_high - y_low)[1] 129 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) 130 | ) 131 | + height, 132 | low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5), 133 | ], 134 | -1, 135 | ) 136 | 137 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) 138 | positions = get_positions(theta) 139 | 140 | # if const_speed: 141 | # # Resample theta angles so that the velocity is closer to constant. 142 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 143 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 144 | # positions = get_positions(theta) 145 | 146 | # Throw away duplicated last position. 147 | positions = positions[:-1] 148 | 149 | # Set path's up vector to axis closest to average of input pose up vectors. 150 | avg_up = poses[:, :3, 1].mean(0) 151 | avg_up = avg_up / np.linalg.norm(avg_up) 152 | ind_up = np.argmax(np.abs(avg_up)) 153 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 154 | 155 | return np.stack([viewmatrix(p - center, up, p) for p in positions]) 156 | 157 | 158 | def generate_interpolated_path( 159 | poses: np.ndarray, 160 | n_interp: int, 161 | spline_degree: int = 5, 162 | smoothness: float = 0.03, 163 | rot_weight: float = 0.1, 164 | ): 165 | """Creates a smooth spline path between input keyframe camera poses. 166 | 167 | Spline is calculated with poses in format (position, lookat-point, up-point). 168 | 169 | Args: 170 | poses: (n, 3, 4) array of input pose keyframes. 171 | n_interp: returned path will have n_interp * (n - 1) total poses. 172 | spline_degree: polynomial degree of B-spline. 
173 | smoothness: parameter for spline smoothing, 0 forces exact interpolation. 174 | rot_weight: relative weighting of rotation/translation in spline solve. 175 | 176 | Returns: 177 | Array of new camera poses with shape (n_interp * (n - 1), 3, 4). 178 | """ 179 | 180 | def poses_to_points(poses, dist): 181 | """Converts from pose matrices to (position, lookat, up) format.""" 182 | pos = poses[:, :3, -1] 183 | lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] 184 | up = poses[:, :3, -1] + dist * poses[:, :3, 1] 185 | return np.stack([pos, lookat, up], 1) 186 | 187 | def points_to_poses(points): 188 | """Converts from (position, lookat, up) format to pose matrices.""" 189 | return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) 190 | 191 | def interp(points, n, k, s): 192 | """Runs multidimensional B-spline interpolation on the input points.""" 193 | sh = points.shape 194 | pts = np.reshape(points, (sh[0], -1)) 195 | k = min(k, sh[0] - 1) 196 | tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) 197 | u = np.linspace(0, 1, n, endpoint=False) 198 | new_points = np.array(scipy.interpolate.splev(u, tck)) 199 | new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) 200 | return new_points 201 | 202 | points = poses_to_points(poses, dist=rot_weight) 203 | new_points = interp( 204 | points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness 205 | ) 206 | return points_to_poses(new_points) 207 | -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Literal 3 | import tyro 4 | import os 5 | import torch 6 | import cv2 7 | import imageio # To generate gifs 8 | import pycolmap_scene_manager as pycolmap 9 | from gsplat import rasterization 10 | import numpy as np 11 | import clip 12 | import matplotlib 13 | 14 | matplotlib.use("TkAgg") 15 | 16 | from lseg import LSegNet 17 | from utils import ( 18 | create_checkerboard, 19 | prune_by_gradients, 20 | test_proper_pruning, 21 | get_viewmat_from_colmap_image, 22 | load_checkpoint, 23 | ) 24 | 25 | 26 | def get_mask3d_lseg(splats, features, prompt, neg_prompt, threshold=None): 27 | 28 | net = LSegNet( 29 | backbone="clip_vitl16_384", 30 | features=256, 31 | crop_size=480, 32 | arch_option=0, 33 | block_depth=0, 34 | activation="lrelu", 35 | ) 36 | # Load pre-trained weights 37 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 38 | net.eval() 39 | net.cuda() 40 | 41 | # Preprocess the text prompt 42 | clip_text_encoder = net.clip_pretrained.encode_text 43 | 44 | pos_prompt_length = len(prompt.split(";")) 45 | 46 | prompts = prompt.split(";") + neg_prompt.split(";") 47 | 48 | prompt = clip.tokenize(prompts) 49 | prompt = prompt.cuda() 50 | 51 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 52 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 53 | 54 | features = torch.nn.functional.normalize(features, dim=1) 55 | score = features @ text_feat_norm.float().T 56 | mask_3d = score[:, :pos_prompt_length].max(dim=1)[0] > score[:, pos_prompt_length:].max(dim=1)[0] 57 | if threshold is not None: 58 | mask_3d = mask_3d & (score[:, 0] > threshold) 59 | mask_3d_inv = ~mask_3d 60 | 61 | return mask_3d, mask_3d_inv 62 | 63 | 64 | def apply_mask3d(splats, mask3d, mask3d_inverted): 65 | if mask3d_inverted == None: 66 | mask3d_inverted = ~mask3d 67 | extracted = deepcopy(splats) 68 | deleted = 
deepcopy(splats) 69 | masked = deepcopy(splats) 70 | extracted["means"] = extracted["means"][mask3d] 71 | extracted["features_dc"] = extracted["features_dc"][mask3d] 72 | extracted["features_rest"] = extracted["features_rest"][mask3d] 73 | extracted["scaling"] = extracted["scaling"][mask3d] 74 | extracted["rotation"] = extracted["rotation"][mask3d] 75 | extracted["opacity"] = extracted["opacity"][mask3d] 76 | 77 | deleted["means"] = deleted["means"][mask3d_inverted] 78 | deleted["features_dc"] = deleted["features_dc"][mask3d_inverted] 79 | deleted["features_rest"] = deleted["features_rest"][mask3d_inverted] 80 | deleted["scaling"] = deleted["scaling"][mask3d_inverted] 81 | deleted["rotation"] = deleted["rotation"][mask3d_inverted] 82 | deleted["opacity"] = deleted["opacity"][mask3d_inverted] 83 | 84 | masked["features_dc"][mask3d] = 1 # (1 - 0.5) / 0.2820947917738781 85 | masked["features_dc"][~mask3d] = 0 # (0 - 0.5) / 0.2820947917738781 86 | masked["features_rest"][~mask3d] = 0 87 | 88 | return extracted, deleted, masked 89 | 90 | 91 | def render_to_gif( 92 | output_path: str, 93 | splats, 94 | feedback: bool = False, 95 | use_checkerboard_background: bool = False, 96 | no_sh: bool = False, 97 | ): 98 | if feedback: 99 | cv2.destroyAllWindows() 100 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 101 | frames = [] 102 | means = splats["means"] 103 | colors_dc = splats["features_dc"] 104 | colors_rest = splats["features_rest"] 105 | colors = torch.cat([colors_dc, colors_rest], dim=1) 106 | if no_sh == True: 107 | colors = colors_dc[:, 0, :] 108 | opacities = torch.sigmoid(splats["opacity"]) 109 | scales = torch.exp(splats["scaling"]) 110 | quats = splats["rotation"] 111 | K = splats["camera_matrix"] 112 | aux_dir = output_path + ".images" 113 | os.makedirs(aux_dir, exist_ok=True) 114 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 115 | viewmat = get_viewmat_from_colmap_image(image) 116 | output, alphas, meta = rasterization( 117 | means, 118 | quats, 119 | scales, 120 | opacities, 121 | colors, 122 | viewmat[None], 123 | K[None], 124 | width=K[0, 2] * 2, 125 | height=K[1, 2] * 2, 126 | sh_degree=3 if not no_sh else None, 127 | ) 128 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 129 | if use_checkerboard_background: 130 | checkerboard = create_checkerboard(frame.shape[1], frame.shape[0]) 131 | alphas = alphas[0].detach().cpu().numpy() 132 | frame = frame * alphas + checkerboard * (1 - alphas) 133 | frame = np.clip(frame, 0, 255).astype(np.uint8) 134 | frames.append(frame) 135 | if feedback: 136 | cv2.imshow("Rendering", frame[..., ::-1]) 137 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 138 | cv2.waitKey(1) 139 | if output_path is not None: 140 | imageio.mimsave(output_path, frames, fps=10, loop=0) 141 | if feedback: 142 | cv2.destroyAllWindows() 143 | 144 | 145 | def render_mask_2d_to_gif( 146 | splats, 147 | features, 148 | prompt, 149 | neg_prompt, 150 | output_path: str, 151 | feedback: bool = False, 152 | ): 153 | if feedback: 154 | cv2.destroyAllWindows() 155 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 156 | frames = [] 157 | means = splats["means"] 158 | colors_dc = splats["features_dc"] 159 | colors_rest = splats["features_rest"] 160 | colors = torch.cat([colors_dc, colors_rest], dim=1) 161 | opacities = torch.sigmoid(splats["opacity"]) 162 | scales = torch.exp(splats["scaling"]) 163 | quats = splats["rotation"] 164 | K = splats["camera_matrix"] 165 | aux_dir = output_path + ".images" 
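# Rest of render_mask_2d_to_gif: encode the positive/negative text prompts with
# CLIP's text encoder, then for every COLMAP view rasterize both the RGB splats
# and the per-Gaussian feature field, mark each pixel where its best positive
# score beats its best negative score, tint/dim the RGB frame accordingly, and
# write the frames to a GIF (plus per-view images in the .images folder).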
166 | os.makedirs(aux_dir, exist_ok=True) 167 | 168 | net = LSegNet( 169 | backbone="clip_vitl16_384", 170 | features=256, 171 | crop_size=480, 172 | arch_option=0, 173 | block_depth=0, 174 | activation="lrelu", 175 | ) 176 | # Load pre-trained weights 177 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 178 | net.eval() 179 | net.cuda() 180 | 181 | # Preprocess the text prompt 182 | clip_text_encoder = net.clip_pretrained.encode_text 183 | pos_prompt_length = len(prompt.split(";")) 184 | 185 | prompts = prompt.split(";") + neg_prompt.split(";") 186 | 187 | prompt = clip.tokenize(prompts) 188 | prompt = prompt.cuda() 189 | 190 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 191 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 192 | 193 | # features = torch.nn.functional.normalize(features, dim=1) 194 | 195 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 196 | viewmat = get_viewmat_from_colmap_image(image) 197 | output, alphas, meta = rasterization( 198 | means, 199 | quats, 200 | scales, 201 | opacities, 202 | colors, 203 | viewmats=viewmat[None], 204 | Ks=K[None], 205 | width=K[0, 2] * 2, 206 | height=K[1, 2] * 2, 207 | sh_degree=3, 208 | ) 209 | feats_rendered, _, _ = rasterization( 210 | means, 211 | quats, 212 | scales, 213 | opacities, 214 | features, 215 | viewmats=viewmat[None], 216 | Ks=K[None], 217 | width=K[0, 2] * 2, 218 | height=K[1, 2] * 2, 219 | # sh_degree=3, 220 | ) 221 | feats_rendered = feats_rendered[0] 222 | feats_rendered = torch.nn.functional.normalize(feats_rendered, dim=-1) 223 | score = feats_rendered @ text_feat_norm.float().T 224 | mask2d = score[..., :pos_prompt_length].max(dim=2)[0] > score[..., pos_prompt_length:].max(dim=2)[0] 225 | # print(mask2d.shape) 226 | mask2d = mask2d[..., None].detach().cpu().numpy() 227 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 228 | frame = frame * ( 229 | 0.75 + 0.25 * mask2d * np.array([255, 0, 0]) + (1 - mask2d) * 0.25 230 | ) 231 | frame = np.clip(frame, 0, 255).astype(np.uint8) 232 | frames.append(frame) 233 | if feedback: 234 | cv2.imshow("Rendering", frame[..., ::-1]) 235 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 236 | cv2.waitKey(1) 237 | if output_path is not None: 238 | imageio.mimsave(output_path, frames, fps=10, loop=0) 239 | if feedback: 240 | cv2.destroyAllWindows() 241 | 242 | 243 | def save_to_ckpt( 244 | output_path: str, 245 | splats, 246 | ): 247 | # Save Torch Checkpoint 248 | checkpoint_data = { 249 | "splats": { 250 | "means": splats["means"], 251 | "quats": splats["rotation"], 252 | "scales": splats["scaling"], 253 | "opacities": splats["opacity"], 254 | "sh0": splats["features_dc"], 255 | "shN": splats["features_rest"], 256 | } 257 | } 258 | torch.save(checkpoint_data, output_path) 259 | 260 | 261 | def main( 262 | data_dir: str = "./data/garden", # colmap path 263 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 264 | results_dir: str = "./results/garden", # output path 265 | rasterizer: Literal[ 266 | "inria", "gsplat" 267 | ] = "gsplat", # Original or gsplat for checkpoints 268 | prompt: str = "Table", 269 | neg_prompt: str = "Vase;Other", 270 | data_factor: int = 4, 271 | show_visual_feedback: bool = True, 272 | export_checkpoint: bool = False, 273 | ): 274 | 275 | if not torch.cuda.is_available(): 276 | raise RuntimeError("CUDA is required for this demo") 277 | 278 | 
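# Remaining pipeline of main(): load the checkpoint and COLMAP project, prune
# Gaussians that receive no gradient from any training view, load the
# back-projected LSeg features, build a 3D mask from the positive/negative
# prompts, then render the 2D mask overlay, the extracted object (over a
# checkerboard background), and the scene with the object deleted to GIFs,
# optionally exporting the extracted/deleted splats as .pt checkpoints.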
torch.set_default_device("cuda") 279 | 280 | os.makedirs(results_dir, exist_ok=True) 281 | splats = load_checkpoint( 282 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 283 | ) 284 | splats_optimized = prune_by_gradients(splats) 285 | test_proper_pruning(splats, splats_optimized) 286 | splats = splats_optimized 287 | features = torch.load(f"{results_dir}/features_lseg.pt") 288 | mask3d, mask3d_inv = get_mask3d_lseg(splats, features, prompt, neg_prompt) 289 | extracted, deleted, masked = apply_mask3d(splats, mask3d, mask3d_inv) 290 | 291 | render_mask_2d_to_gif( 292 | splats, 293 | features, 294 | prompt, 295 | neg_prompt, 296 | f"{results_dir}/mask2d.gif", 297 | show_visual_feedback, 298 | ) 299 | 300 | render_to_gif( 301 | f"{results_dir}/extracted.gif", 302 | extracted, 303 | show_visual_feedback, 304 | use_checkerboard_background=True, 305 | ) 306 | render_to_gif(f"{results_dir}/deleted.gif", deleted, show_visual_feedback) 307 | 308 | if export_checkpoint: 309 | save_to_ckpt(f"{results_dir}/extracted.pt", extracted) 310 | save_to_ckpt(f"{results_dir}/deleted.pt", deleted) 311 | 312 | 313 | if __name__ == "__main__": 314 | tyro.cli(main) 315 | -------------------------------------------------------------------------------- /segment_f3dgs.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Literal 3 | import tyro 4 | import os 5 | import torch 6 | import cv2 7 | import imageio # To generate gifs 8 | import pycolmap_scene_manager as pycolmap 9 | from gsplat import rasterization 10 | import numpy as np 11 | import clip 12 | import matplotlib 13 | 14 | matplotlib.use("TkAgg") 15 | 16 | from lseg import LSegNet 17 | from utils import ( 18 | create_checkerboard, 19 | prune_by_gradients, 20 | test_proper_pruning, 21 | get_viewmat_from_colmap_image, 22 | load_checkpoint_f3dgs, 23 | ) 24 | 25 | 26 | def get_mask3d_lseg(splats, features, prompt, neg_prompt, threshold=None): 27 | 28 | net = LSegNet( 29 | backbone="clip_vitl16_384", 30 | features=256, 31 | crop_size=480, 32 | arch_option=0, 33 | block_depth=0, 34 | activation="lrelu", 35 | ) 36 | # Load pre-trained weights 37 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 38 | net.eval() 39 | net.cuda() 40 | 41 | # Preprocess the text prompt 42 | clip_text_encoder = net.clip_pretrained.encode_text 43 | 44 | pos_prompt_length = len(prompt.split(";")) 45 | 46 | prompts = prompt.split(";") + neg_prompt.split(";") 47 | 48 | prompt = clip.tokenize(prompts) 49 | prompt = prompt.cuda() 50 | 51 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 52 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 53 | 54 | features = torch.nn.functional.normalize(features, dim=1) 55 | score = features @ text_feat_norm.float().T 56 | mask_3d = score[:, :pos_prompt_length].max(dim=1)[0] > score[:, pos_prompt_length:].max(dim=1)[0] 57 | if threshold is not None: 58 | mask_3d = mask_3d & (score[:, 0] > threshold) 59 | mask_3d_inv = ~mask_3d 60 | 61 | return mask_3d, mask_3d_inv 62 | 63 | 64 | def apply_mask3d(splats, mask3d, mask3d_inverted): 65 | if mask3d_inverted == None: 66 | mask3d_inverted = ~mask3d 67 | extracted = deepcopy(splats) 68 | deleted = deepcopy(splats) 69 | masked = deepcopy(splats) 70 | extracted["means"] = extracted["means"][mask3d] 71 | extracted["features_dc"] = extracted["features_dc"][mask3d] 72 | extracted["features_rest"] = extracted["features_rest"][mask3d] 73 | 
extracted["scaling"] = extracted["scaling"][mask3d] 74 | extracted["rotation"] = extracted["rotation"][mask3d] 75 | extracted["opacity"] = extracted["opacity"][mask3d] 76 | 77 | deleted["means"] = deleted["means"][mask3d_inverted] 78 | deleted["features_dc"] = deleted["features_dc"][mask3d_inverted] 79 | deleted["features_rest"] = deleted["features_rest"][mask3d_inverted] 80 | deleted["scaling"] = deleted["scaling"][mask3d_inverted] 81 | deleted["rotation"] = deleted["rotation"][mask3d_inverted] 82 | deleted["opacity"] = deleted["opacity"][mask3d_inverted] 83 | 84 | masked["features_dc"][mask3d] = 1 # (1 - 0.5) / 0.2820947917738781 85 | masked["features_dc"][~mask3d] = 0 # (0 - 0.5) / 0.2820947917738781 86 | masked["features_rest"][~mask3d] = 0 87 | 88 | return extracted, deleted, masked 89 | 90 | 91 | def render_to_gif_f3dgs( 92 | output_path: str, 93 | splats, 94 | feedback: bool = False, 95 | use_checkerboard_background: bool = False, 96 | no_sh: bool = False, 97 | ): 98 | if feedback: 99 | cv2.destroyAllWindows() 100 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 101 | frames = [] 102 | means = splats["means"] 103 | colors_dc = splats["features_dc"] 104 | colors_rest = splats["features_rest"] 105 | colors = torch.cat([colors_dc, colors_rest], dim=1) 106 | if no_sh == True: 107 | colors = colors_dc[:, 0, :] 108 | opacities = torch.sigmoid(splats["opacity"]) 109 | scales = torch.exp(splats["scaling"]) 110 | quats = splats["rotation"] 111 | K = splats["camera_matrix"] 112 | aux_dir = output_path + ".images" 113 | os.makedirs(aux_dir, exist_ok=True) 114 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 115 | viewmat = get_viewmat_from_colmap_image(image) 116 | output, alphas, meta = rasterization( 117 | means, 118 | quats, 119 | scales, 120 | opacities, 121 | colors, 122 | viewmat[None], 123 | K[None], 124 | width=K[0, 2] * 2, 125 | height=K[1, 2] * 2, 126 | sh_degree=3 if not no_sh else None, 127 | ) 128 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 129 | if use_checkerboard_background: 130 | checkerboard = create_checkerboard(frame.shape[1], frame.shape[0]) 131 | alphas = alphas[0].detach().cpu().numpy() 132 | frame = frame * alphas + checkerboard * (1 - alphas) 133 | frame = np.clip(frame, 0, 255).astype(np.uint8) 134 | frames.append(frame) 135 | if feedback: 136 | cv2.imshow("Rendering", frame[..., ::-1]) 137 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 138 | cv2.waitKey(1) 139 | imageio.mimsave(output_path, frames, fps=10, loop=0) 140 | if feedback: 141 | cv2.destroyAllWindows() 142 | 143 | 144 | def render_mask_2d_to_gif_f3dgs( 145 | splats, 146 | prompt, 147 | neg_prompt, 148 | output_path: str, 149 | feedback: bool = False, 150 | ): 151 | if feedback: 152 | cv2.destroyAllWindows() 153 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 154 | frames = [] 155 | means = splats["means"] 156 | colors_dc = splats["features_dc"] 157 | colors_rest = splats["features_rest"] 158 | colors = torch.cat([colors_dc, colors_rest], dim=1) 159 | opacities = torch.sigmoid(splats["opacity"]) 160 | scales = torch.exp(splats["scaling"]) 161 | quats = splats["rotation"] 162 | K = splats["camera_matrix"] 163 | aux_dir = output_path + ".images" 164 | os.makedirs(aux_dir, exist_ok=True) 165 | 166 | features = splats["features"] 167 | conv = splats["conv"] 168 | 169 | net = LSegNet( 170 | backbone="clip_vitl16_384", 171 | features=256, 172 | crop_size=480, 173 | arch_option=0, 174 | block_depth=0, 175 | 
activation="lrelu", 176 | ) 177 | # Load pre-trained weights 178 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 179 | net.eval() 180 | net.cuda() 181 | 182 | # Preprocess the text prompt 183 | clip_text_encoder = net.clip_pretrained.encode_text 184 | 185 | pos_prompt_length = len(prompt.split(";")) 186 | 187 | prompts = prompt.split(";") + neg_prompt.split(";") 188 | 189 | prompt = clip.tokenize(prompts) 190 | prompt = prompt.cuda() 191 | 192 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 193 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 194 | 195 | # features = torch.nn.functional.normalize(features, dim=1) 196 | 197 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 198 | viewmat = get_viewmat_from_colmap_image(image) 199 | output, alphas, meta = rasterization( 200 | means, 201 | quats, 202 | scales, 203 | opacities, 204 | colors, 205 | viewmats=viewmat[None], 206 | Ks=K[None], 207 | width=K[0, 2] * 2, 208 | height=K[1, 2] * 2, 209 | sh_degree=3, 210 | ) 211 | feats_rendered, _, _ = rasterization( 212 | means, 213 | quats, 214 | scales, 215 | opacities, 216 | features, 217 | viewmats=viewmat[None], 218 | Ks=K[None], 219 | width=K[0, 2] * 2, 220 | height=K[1, 2] * 2, 221 | # sh_degree=3, 222 | ) 223 | feats_rendered = feats_rendered[0] @ conv 224 | feats_rendered = torch.nn.functional.normalize(feats_rendered, dim=-1) 225 | score = feats_rendered @ text_feat_norm.float().T 226 | mask2d = score[..., :pos_prompt_length].max(dim=2)[0] > score[..., pos_prompt_length:].max(dim=2)[0] 227 | # print(mask2d.shape) 228 | mask2d = mask2d[..., None].detach().cpu().numpy() 229 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 230 | frame = frame * ( 231 | 0.75 + 0.25 * mask2d * np.array([255, 0, 0]) + (1 - mask2d) * 0.25 232 | ) 233 | frame = np.clip(frame, 0, 255).astype(np.uint8) 234 | frames.append(frame) 235 | if feedback: 236 | cv2.imshow("Rendering", frame[..., ::-1]) 237 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 238 | cv2.waitKey(1) 239 | imageio.mimsave(output_path, frames, fps=10, loop=0) 240 | if feedback: 241 | cv2.destroyAllWindows() 242 | 243 | 244 | def save_to_ckpt( 245 | output_path: str, 246 | splats, 247 | ): 248 | # Save Torch Checkpoint 249 | checkpoint_data = { 250 | "splats": { 251 | "means": splats["means"], 252 | "quats": splats["rotation"], 253 | "scales": splats["scaling"], 254 | "opacities": splats["opacity"], 255 | "sh0": splats["features_dc"], 256 | "shN": splats["features_rest"], 257 | } 258 | } 259 | torch.save(checkpoint_data, output_path) 260 | 261 | 262 | def main( 263 | data_dir: str = "./data/garden", # colmap path 264 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 265 | results_dir: str = "./results/garden", # output path 266 | rasterizer: Literal[ 267 | "inria", "gsplat" 268 | ] = "gsplat", # Original or gsplat for checkpoints 269 | prompt: str = "Table", 270 | neg_prompt: str = "Vase;Other", 271 | data_factor: int = 4, 272 | show_visual_feedback: bool = True, 273 | export_checkpoint: bool = False, 274 | ): 275 | 276 | if not torch.cuda.is_available(): 277 | raise RuntimeError("CUDA is required for this demo") 278 | 279 | torch.set_default_device("cuda") 280 | 281 | os.makedirs(results_dir, exist_ok=True) 282 | splats = load_checkpoint_f3dgs( 283 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 284 | ) 285 | 
splats_optimized = prune_by_gradients(splats) 286 | test_proper_pruning(splats, splats_optimized) 287 | splats = splats_optimized 288 | # features = torch.load(f"{results_dir}/features_lseg.pt") 289 | features = splats["features"] @ splats["conv"] 290 | mask3d, mask3d_inv = get_mask3d_lseg(splats, features, prompt, neg_prompt) 291 | extracted, deleted, masked = apply_mask3d(splats, mask3d, mask3d_inv) 292 | 293 | render_mask_2d_to_gif_f3dgs( 294 | splats, 295 | prompt, 296 | neg_prompt, 297 | f"{results_dir}/mask2d_f3dgs.gif", 298 | show_visual_feedback, 299 | ) 300 | 301 | render_to_gif_f3dgs( 302 | f"{results_dir}/extracted_f3dgs.gif", 303 | extracted, 304 | show_visual_feedback, 305 | use_checkerboard_background=True, 306 | ) 307 | render_to_gif_f3dgs(f"{results_dir}/deleted_f3dgs.gif", deleted, show_visual_feedback) 308 | 309 | if export_checkpoint: 310 | save_to_ckpt(f"{results_dir}/extracted_f3dgs.pt", extracted) 311 | save_to_ckpt(f"{results_dir}/deleted_f3dgs.pt", deleted) 312 | 313 | 314 | if __name__ == "__main__": 315 | tyro.cli(main) 316 | -------------------------------------------------------------------------------- /segment_compressed.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Literal 3 | import tyro 4 | import os 5 | import torch 6 | import cv2 7 | import imageio # To generate gifs 8 | import pycolmap_scene_manager as pycolmap 9 | from gsplat import rasterization 10 | import numpy as np 11 | import clip 12 | import matplotlib 13 | 14 | matplotlib.use("TkAgg") 15 | 16 | from lseg import LSegNet 17 | import torch.nn as nn 18 | 19 | from utils import ( 20 | get_viewmat_from_colmap_image, 21 | create_checkerboard, 22 | load_checkpoint, 23 | prune_by_gradients, 24 | test_proper_pruning, 25 | ) 26 | 27 | 28 | class EncoderDecoder(nn.Module): 29 | def __init__(self): 30 | super(EncoderDecoder, self).__init__() 31 | self.encoder = nn.Parameter(torch.randn(512, 16)) 32 | self.decoder = nn.Parameter(torch.randn(16, 512)) 33 | 34 | def forward(self, x): 35 | x = x @ self.encoder 36 | y = x @ self.decoder 37 | return x, y 38 | 39 | 40 | encoder_decoder = EncoderDecoder().to("cuda") 41 | 42 | encoder_decoder.load_state_dict(torch.load("./encoder_decoder.ckpt")) 43 | 44 | 45 | def get_mask3d_lseg(splats, features, prompt, neg_prompt, threshold=None): 46 | 47 | net = LSegNet( 48 | backbone="clip_vitl16_384", 49 | features=256, 50 | crop_size=480, 51 | arch_option=0, 52 | block_depth=0, 53 | activation="lrelu", 54 | ) 55 | # Load pre-trained weights 56 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 57 | net.eval() 58 | net.cuda() 59 | 60 | # Preprocess the text prompt 61 | clip_text_encoder = net.clip_pretrained.encode_text 62 | 63 | pos_prompt_length = len(prompt.split(";")) 64 | 65 | prompts = prompt.split(";") + neg_prompt.split(";") 66 | 67 | prompt = clip.tokenize(prompts) 68 | prompt = prompt.cuda() 69 | 70 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 71 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 72 | text_feat_norm = text_feat_norm.float().to("cuda") 73 | text_feat_norm = text_feat_norm @ encoder_decoder.encoder # 512 -> 16 74 | text_feat_norm = torch.nn.functional.normalize(text_feat_norm, dim=1) 75 | # features = features @ encoder_decoder.decoder 76 | features = torch.nn.functional.normalize(features, dim=1) 77 | print(features.shape, text_feat_norm.shape) 78 | score = features @ text_feat_norm.float().T 79 | 
mask_3d = score[:, :pos_prompt_length].max(dim=1)[0] > score[:, pos_prompt_length:].max(dim=1)[0] 80 | if threshold is not None: 81 | mask_3d = mask_3d & (score[:, 0] > threshold) 82 | mask_3d_inv = ~mask_3d 83 | 84 | return mask_3d, mask_3d_inv 85 | 86 | 87 | def render_mask_2d_to_gif( 88 | splats, 89 | features, 90 | prompt, 91 | neg_prompt, 92 | output_path: str, 93 | feedback: bool = False, 94 | ): 95 | if feedback: 96 | cv2.destroyAllWindows() 97 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 98 | frames = [] 99 | means = splats["means"] 100 | colors_dc = splats["features_dc"] 101 | colors_rest = splats["features_rest"] 102 | colors = torch.cat([colors_dc, colors_rest], dim=1) 103 | opacities = torch.sigmoid(splats["opacity"]) 104 | scales = torch.exp(splats["scaling"]) 105 | quats = splats["rotation"] 106 | K = splats["camera_matrix"] 107 | aux_dir = output_path + ".images" 108 | os.makedirs(aux_dir, exist_ok=True) 109 | 110 | net = LSegNet( 111 | backbone="clip_vitl16_384", 112 | features=256, 113 | crop_size=480, 114 | arch_option=0, 115 | block_depth=0, 116 | activation="lrelu", 117 | ) 118 | # Load pre-trained weights 119 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 120 | net.eval() 121 | net.cuda() 122 | 123 | # Preprocess the text prompt 124 | clip_text_encoder = net.clip_pretrained.encode_text 125 | 126 | pos_prompt_length = len(prompt.split(";")) 127 | 128 | prompts = prompt.split(";") + neg_prompt.split(";") 129 | 130 | prompt = clip.tokenize(prompts) 131 | prompt = prompt.cuda() 132 | 133 | text_feat = clip_text_encoder(prompt) # N, 512, N - number of prompts 134 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1).float() 135 | text_feat_norm = text_feat_norm @ encoder_decoder.encoder # 512 -> 16 136 | text_feat_norm = torch.nn.functional.normalize(text_feat_norm, dim=1) 137 | 138 | # features = torch.nn.functional.normalize(features, dim=1) 139 | 140 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 141 | viewmat = get_viewmat_from_colmap_image(image) 142 | output, alphas, meta = rasterization( 143 | means, 144 | quats, 145 | scales, 146 | opacities, 147 | colors, 148 | viewmats=viewmat[None], 149 | Ks=K[None], 150 | width=K[0, 2] * 2, 151 | height=K[1, 2] * 2, 152 | sh_degree=3, 153 | ) 154 | feats_rendered, _, _ = rasterization( 155 | means, 156 | quats, 157 | scales, 158 | opacities, 159 | features, 160 | viewmats=viewmat[None], 161 | Ks=K[None], 162 | width=K[0, 2] * 2, 163 | height=K[1, 2] * 2, 164 | # sh_degree=3, 165 | ) 166 | feats_rendered = feats_rendered[0] 167 | feats_rendered = torch.nn.functional.normalize(feats_rendered, dim=-1) 168 | score = feats_rendered @ text_feat_norm.float().T 169 | mask2d = score[..., :pos_prompt_length].max(dim=2)[0] > score[..., pos_prompt_length:].max(dim=2)[0] 170 | # print(mask2d.shape) 171 | mask2d = mask2d[..., None].detach().cpu().numpy() 172 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 173 | frame = frame * ( 174 | 0.75 + 0.25 * mask2d * np.array([255, 0, 0]) + (1 - mask2d) * 0.25 175 | ) 176 | frame = np.clip(frame, 0, 255).astype(np.uint8) 177 | frames.append(frame) 178 | if feedback: 179 | cv2.imshow("Rendering", frame[..., ::-1]) 180 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 181 | cv2.waitKey(1) 182 | imageio.mimsave(output_path, frames, fps=10, loop=0) 183 | if feedback: 184 | cv2.destroyAllWindows() 185 | 186 | 187 | def apply_mask3d(splats, mask3d, mask3d_inverted): 188 | if 
mask3d_inverted == None: 189 | mask3d_inverted = ~mask3d 190 | extracted = deepcopy(splats) 191 | deleted = deepcopy(splats) 192 | masked = deepcopy(splats) 193 | extracted["means"] = extracted["means"][mask3d] 194 | extracted["features_dc"] = extracted["features_dc"][mask3d] 195 | extracted["features_rest"] = extracted["features_rest"][mask3d] 196 | extracted["scaling"] = extracted["scaling"][mask3d] 197 | extracted["rotation"] = extracted["rotation"][mask3d] 198 | extracted["opacity"] = extracted["opacity"][mask3d] 199 | 200 | deleted["means"] = deleted["means"][mask3d_inverted] 201 | deleted["features_dc"] = deleted["features_dc"][mask3d_inverted] 202 | deleted["features_rest"] = deleted["features_rest"][mask3d_inverted] 203 | deleted["scaling"] = deleted["scaling"][mask3d_inverted] 204 | deleted["rotation"] = deleted["rotation"][mask3d_inverted] 205 | deleted["opacity"] = deleted["opacity"][mask3d_inverted] 206 | 207 | masked["features_dc"][mask3d] = 1 # (1 - 0.5) / 0.2820947917738781 208 | masked["features_dc"][~mask3d] = 0 # (0 - 0.5) / 0.2820947917738781 209 | masked["features_rest"][~mask3d] = 0 210 | 211 | return extracted, deleted, masked 212 | 213 | 214 | def render_to_gif( 215 | output_path: str, 216 | splats, 217 | feedback: bool = False, 218 | use_checkerboard_background: bool = False, 219 | no_sh: bool = False, 220 | ): 221 | if feedback: 222 | cv2.destroyAllWindows() 223 | cv2.namedWindow("Rendering", cv2.WINDOW_NORMAL) 224 | frames = [] 225 | means = splats["means"] 226 | colors_dc = splats["features_dc"] 227 | colors_rest = splats["features_rest"] 228 | colors = torch.cat([colors_dc, colors_rest], dim=1) 229 | if no_sh == True: 230 | colors = colors_dc[:, 0, :] 231 | opacities = torch.sigmoid(splats["opacity"]) 232 | scales = torch.exp(splats["scaling"]) 233 | quats = splats["rotation"] 234 | K = splats["camera_matrix"] 235 | aux_dir = output_path + ".images" 236 | os.makedirs(aux_dir, exist_ok=True) 237 | for image in sorted(splats["colmap_project"].images.values(), key=lambda x: x.name): 238 | viewmat = get_viewmat_from_colmap_image(image) 239 | output, alphas, meta = rasterization( 240 | means, 241 | quats, 242 | scales, 243 | opacities, 244 | colors, 245 | viewmat[None], 246 | K[None], 247 | width=K[0, 2] * 2, 248 | height=K[1, 2] * 2, 249 | sh_degree=3 if not no_sh else None, 250 | ) 251 | frame = np.clip(output[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) 252 | if use_checkerboard_background: 253 | checkerboard = create_checkerboard(frame.shape[1], frame.shape[0]) 254 | alphas = alphas[0].detach().cpu().numpy() 255 | frame = frame * alphas + checkerboard * (1 - alphas) 256 | frame = np.clip(frame, 0, 255).astype(np.uint8) 257 | frames.append(frame) 258 | if feedback: 259 | cv2.imshow("Rendering", frame[..., ::-1]) 260 | cv2.imwrite(f"{aux_dir}/{image.name}", frame[..., ::-1]) 261 | cv2.waitKey(1) 262 | # imageio.mimsave(output_path, frames, fps=10, loop=0) 263 | if feedback: 264 | cv2.destroyAllWindows() 265 | 266 | 267 | def main( 268 | data_dir: str = "./data/garden", # colmap path 269 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 270 | results_dir: str = "./results/garden", # output path 271 | rasterizer: Literal[ 272 | "inria", "gsplat" 273 | ] = "gsplat", # Original or gsplat for checkpoints 274 | prompt: str = "Table", 275 | neg_prompt: str = "Vase;Other", 276 | data_factor: int = 4, 277 | show_visual_feedback: bool = True, 278 | ): 279 | 280 | if not 
torch.cuda.is_available(): 281 | raise RuntimeError("CUDA is required for this demo") 282 | 283 | torch.set_default_device("cuda") 284 | 285 | os.makedirs(results_dir, exist_ok=True) 286 | splats = load_checkpoint( 287 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 288 | ) 289 | splats_optimized = prune_by_gradients(splats) 290 | test_proper_pruning(splats, splats_optimized) 291 | splats = splats_optimized 292 | features = torch.load(f"{results_dir}/features_lseg_compressed.pt") # 293 | mask3d, mask3d_inv = get_mask3d_lseg(splats, features, prompt, neg_prompt) 294 | extracted, deleted, masked = apply_mask3d(splats, mask3d, mask3d_inv) 295 | 296 | render_mask_2d_to_gif( 297 | splats, 298 | features, 299 | prompt, 300 | neg_prompt, 301 | f"{results_dir}/mask2d.gif", 302 | show_visual_feedback, 303 | ) 304 | 305 | render_to_gif( 306 | f"{results_dir}/extracted.gif", 307 | extracted, 308 | show_visual_feedback, 309 | use_checkerboard_background=True, 310 | ) 311 | render_to_gif(f"{results_dir}/deleted.gif", deleted, show_visual_feedback) 312 | 313 | 314 | if __name__ == "__main__": 315 | tyro.cli(main) 316 | -------------------------------------------------------------------------------- /backproject.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | from typing import Literal 5 | import torch 6 | import tyro 7 | from gsplat import rasterization 8 | import pycolmap_scene_manager as pycolmap 9 | import numpy as np 10 | import matplotlib 11 | 12 | matplotlib.use("TkAgg") # To avoid conflict with cv2 13 | from tqdm import tqdm 14 | from lseg import LSegNet 15 | 16 | 17 | from utils import ( 18 | load_checkpoint, 19 | get_viewmat_from_colmap_image, 20 | prune_by_gradients, 21 | test_proper_pruning, 22 | ) 23 | 24 | 25 | def create_feature_field_lseg(splats, batch_size=1, use_cpu=False): 26 | device = "cpu" if use_cpu else "cuda" 27 | 28 | net = LSegNet( 29 | backbone="clip_vitl16_384", 30 | features=256, 31 | crop_size=480, 32 | arch_option=0, 33 | block_depth=0, 34 | activation="lrelu", 35 | ) 36 | # Load pre-trained weights 37 | net.load_state_dict( 38 | torch.load("./checkpoints/lseg_minimal_e200.ckpt", map_location=device) 39 | ) 40 | net.eval() 41 | net.to(device) 42 | 43 | means = splats["means"] 44 | colors_dc = splats["features_dc"] 45 | colors_rest = splats["features_rest"] 46 | colors_all = torch.cat([colors_dc, colors_rest], dim=1) 47 | 48 | colors = colors_dc[:, 0, :] # * 0 49 | colors_0 = colors_dc[:, 0, :] * 0 50 | colors.to(device) 51 | colors_0.to(device) 52 | 53 | colmap_project = splats["colmap_project"] 54 | 55 | opacities = torch.sigmoid(splats["opacity"]) 56 | scales = torch.exp(splats["scaling"]) 57 | quats = splats["rotation"] 58 | K = splats["camera_matrix"] 59 | colors.requires_grad = True 60 | colors_0.requires_grad = True 61 | 62 | gaussian_features = torch.zeros(colors.shape[0], 512, device=colors.device) 63 | gaussian_denoms = torch.ones(colors.shape[0], device=colors.device) * 1e-12 64 | 65 | t1 = time.time() 66 | 67 | colors_feats = torch.zeros( 68 | colors.shape[0], 512, device=colors.device, requires_grad=True 69 | ) 70 | colors_feats_0 = torch.zeros( 71 | colors.shape[0], 3, device=colors.device, requires_grad=True 72 | ) 73 | 74 | images = sorted(colmap_project.images.values(), key=lambda x: x.name) 75 | # batch_size = math.ceil(len(images) / batch_count) if batch_count > 0 else 1 76 | 77 | for batch_start in tqdm( 78 | range(0, len(images), batch_size), 79 
| desc="Feature backprojection (batches)", 80 | ): 81 | batch = images[batch_start : batch_start + batch_size] 82 | for image in batch: 83 | viewmat = get_viewmat_from_colmap_image(image) 84 | 85 | width = int(K[0, 2] * 2) 86 | height = int(K[1, 2] * 2) 87 | 88 | with torch.no_grad(): 89 | output, _, meta = rasterization( 90 | means, 91 | quats, 92 | scales, 93 | opacities, 94 | colors_all, 95 | viewmat[None], 96 | K[None], 97 | width=width, 98 | height=height, 99 | sh_degree=3, 100 | ) 101 | 102 | output = torch.nn.functional.interpolate( 103 | output.permute(0, 3, 1, 2).to(device), 104 | size=(480, 480), 105 | mode="bilinear", 106 | ) 107 | output.to(device) 108 | feats = net.forward(output) 109 | feats = torch.nn.functional.normalize(feats, dim=1) 110 | feats = torch.nn.functional.interpolate( 111 | feats, size=(height, width), mode="bilinear" 112 | )[0] 113 | feats = feats.permute(1, 2, 0) 114 | 115 | output_for_grad, _, meta = rasterization( 116 | means, 117 | quats, 118 | scales, 119 | opacities, 120 | colors_feats, 121 | viewmat[None], 122 | K[None], 123 | width=width, 124 | height=height, 125 | ) 126 | 127 | target = (output_for_grad[0].to(device) * feats).sum() 128 | target.to(device) 129 | target.backward() 130 | colors_feats_copy = colors_feats.grad.clone() 131 | colors_feats.grad.zero_() 132 | 133 | output_for_grad, _, meta = rasterization( 134 | means, 135 | quats, 136 | scales, 137 | opacities, 138 | colors_feats_0, 139 | viewmat[None], 140 | K[None], 141 | width=width, 142 | height=height, 143 | ) 144 | 145 | target_0 = (output_for_grad[0]).sum() 146 | target_0.to(device) 147 | target_0.backward() 148 | 149 | gaussian_features += colors_feats_copy 150 | gaussian_denoms += colors_feats_0.grad[:, 0] 151 | colors_feats_0.grad.zero_() 152 | 153 | # Clean up unused variables and free GPU memory 154 | del ( 155 | viewmat, 156 | meta, 157 | _, 158 | output, 159 | feats, 160 | output_for_grad, 161 | colors_feats_copy, 162 | target, 163 | target_0, 164 | ) 165 | torch.cuda.empty_cache() 166 | gaussian_features = gaussian_features / gaussian_denoms[..., None] 167 | gaussian_features = gaussian_features / gaussian_features.norm(dim=-1, keepdim=True) 168 | # Replace nan values with 0 169 | gaussian_features[torch.isnan(gaussian_features)] = 0 170 | t2 = time.time() 171 | print("Time taken for feature backprojection", t2 - t1) 172 | return gaussian_features 173 | 174 | 175 | def create_feature_field_dino(splats): 176 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 177 | feature_extractor = ( 178 | torch.hub.load("facebookresearch/dinov2:main", "dinov2_vitl14_reg") 179 | .to(device) 180 | .eval() 181 | ) 182 | 183 | dinov2_vits14_reg = feature_extractor 184 | 185 | means = splats["means"] 186 | colors_dc = splats["features_dc"] 187 | colors_rest = splats["features_rest"] 188 | colors_all = torch.cat([colors_dc, colors_rest], dim=1) 189 | 190 | colors = colors_dc[:, 0, :] # * 0 191 | colors_0 = colors_dc[:, 0, :] * 0 192 | colmap_project = splats["colmap_project"] 193 | 194 | opacities = torch.sigmoid(splats["opacity"]) 195 | scales = torch.exp(splats["scaling"]) 196 | quats = splats["rotation"] 197 | K = splats["camera_matrix"] 198 | colors.requires_grad = True 199 | colors_0.requires_grad = True 200 | 201 | DIM = 1024 202 | 203 | gaussian_features = torch.zeros(colors.shape[0], DIM, device=colors.device) 204 | gaussian_denoms = torch.ones(colors.shape[0], device=colors.device) * 1e-12 205 | 206 | t1 = time.time() 207 | 208 | colors_feats = 
torch.zeros(colors.shape[0], 1024, device=colors.device) 209 | colors_feats.requires_grad = True 210 | colors_feats_0 = torch.zeros(colors.shape[0], 3, device=colors.device) 211 | colors_feats_0.requires_grad = True 212 | 213 | print("Distilling features...") 214 | for image in tqdm(sorted(colmap_project.images.values(), key=lambda x: x.name)): 215 | 216 | image_name = image.name # .split(".")[0] + ".jpg" 217 | 218 | viewmat = get_viewmat_from_colmap_image(image) 219 | 220 | width = int(K[0, 2] * 2) 221 | height = int(K[1, 2] * 2) 222 | with torch.no_grad(): 223 | output, _, meta = rasterization( 224 | means, 225 | quats, 226 | scales, 227 | opacities, 228 | colors_all, 229 | viewmat[None], 230 | K[None], 231 | width=width, 232 | height=height, 233 | sh_degree=3, 234 | ) 235 | 236 | output = torch.nn.functional.interpolate( 237 | output.permute(0, 3, 1, 2).cuda(), 238 | size=(224 * 4, 224 * 4), 239 | mode="bilinear", 240 | align_corners=False, 241 | ) 242 | feats = dinov2_vits14_reg.forward_features(output)["x_norm_patchtokens"] 243 | feats = feats[0].reshape((16 * 4, 16 * 4, DIM)) 244 | feats = torch.nn.functional.interpolate( 245 | feats.unsqueeze(0).permute(0, 3, 1, 2), 246 | size=(height, width), 247 | mode="nearest", 248 | )[0] 249 | feats = feats.permute(1, 2, 0) 250 | 251 | output_for_grad, _, meta = rasterization( 252 | means, 253 | quats, 254 | scales, 255 | opacities, 256 | colors_feats, 257 | viewmat[None], 258 | K[None], 259 | width=width, 260 | height=height, 261 | ) 262 | 263 | target = (output_for_grad[0] * feats).mean() 264 | 265 | target.backward() 266 | 267 | colors_feats_copy = colors_feats.grad.clone() 268 | 269 | colors_feats.grad.zero_() 270 | 271 | output_for_grad, _, meta = rasterization( 272 | means, 273 | quats, 274 | scales, 275 | opacities, 276 | colors_feats_0, 277 | viewmat[None], 278 | K[None], 279 | width=width, 280 | height=height, 281 | ) 282 | 283 | target_0 = (output_for_grad[0]).mean() 284 | 285 | target_0.backward() 286 | 287 | gaussian_features += colors_feats_copy # / (colors_feats_0.grad[:,0:1]+1e-12) 288 | gaussian_denoms += colors_feats_0.grad[:, 0] 289 | colors_feats_0.grad.zero_() 290 | print(gaussian_features.shape, gaussian_denoms.shape) 291 | gaussian_features = gaussian_features / gaussian_denoms[..., None] 292 | gaussian_features = gaussian_features / gaussian_features.norm(dim=-1, keepdim=True) 293 | # Replace nan values with 0 294 | print("NaN features", torch.isnan(gaussian_features).sum()) 295 | gaussian_features[torch.isnan(gaussian_features)] = 0 296 | t2 = time.time() 297 | print("Time taken for feature distillation", t2 - t1) 298 | return gaussian_features 299 | 300 | 301 | def main( 302 | data_dir: str = "./data/garden", # colmap path 303 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 304 | results_dir: str = "./results/garden", # output path 305 | rasterizer: Literal[ 306 | "inria", "gsplat" 307 | ] = "gsplat", # Original or GSplat for checkpoints 308 | data_factor: int = 4, 309 | feature_field_batch_count: int = 1, # Number of batches to process for feature field 310 | run_feature_field_on_cpu: bool = False, # Run feature field on CPU 311 | feature: Literal["lseg", "dino"] = "lseg", # Feature field type 312 | ): 313 | 314 | if not torch.cuda.is_available(): 315 | raise RuntimeError("CUDA is required for this demo") 316 | 317 | torch.set_default_device("cuda") 318 | 319 | os.makedirs(results_dir, exist_ok=True) 320 | splats = load_checkpoint( 321 | checkpoint, 
data_dir, rasterizer=rasterizer, data_factor=data_factor 322 | ) 323 | splats_optimized = prune_by_gradients(splats) 324 | test_proper_pruning(splats, splats_optimized) 325 | splats = splats_optimized 326 | if feature == "lseg": 327 | features = create_feature_field_lseg( 328 | splats, feature_field_batch_count, run_feature_field_on_cpu 329 | ) 330 | torch.save(features, f"{results_dir}/features_lseg.pt") 331 | elif feature == "dino": 332 | features = create_feature_field_dino(splats) 333 | print("Features.shape", features.shape) 334 | torch.save(features, f"{results_dir}/features_dino.pt") 335 | else: 336 | raise ValueError("Invalid field type") 337 | 338 | 339 | if __name__ == "__main__": 340 | tyro.cli(main) 341 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pycolmap_scene_manager as pycolmap 3 | from typing import Literal, Optional 4 | import numpy as np 5 | from gsplat import rasterization 6 | import warnings 7 | from plyfile import PlyData 8 | 9 | 10 | 11 | def _detach_tensors_from_dict(d, inplace=True): 12 | if not inplace: 13 | d = d.copy() 14 | for key in d: 15 | if isinstance(d[key], torch.Tensor): 16 | d[key] = d[key].detach() 17 | return d 18 | 19 | 20 | def load_checkpoint( 21 | checkpoint: str, 22 | data_dir: str, 23 | format: Literal["inria", "gsplat", "ply"] = "gsplat", 24 | data_factor: int = 1, 25 | rasterizer: Optional[Literal["inria", "gsplat"]] = None, 26 | ): 27 | 28 | colmap_project = pycolmap.SceneManager(f"{data_dir}/sparse/0") 29 | colmap_project.load_cameras() 30 | colmap_project.load_images() 31 | colmap_project.load_points3D() 32 | if format in ["inria", "gsplat"]: 33 | model = torch.load(checkpoint,weights_only=False) # Make sure it is generated by 3DGS original repo 34 | 35 | if format is None: 36 | if rasterizer is None: 37 | raise ValueError("Must specify format or rasterizer") 38 | else: 39 | format = rasterizer 40 | if rasterizer is not None: 41 | format = rasterizer 42 | warnings.warn( 43 | "`rasterizer` is deprecated. 
Use `format` instead.", DeprecationWarning 44 | ) 45 | if format == "inria": 46 | model_params, _ = model 47 | splats = { 48 | "active_sh_degree": model_params[0], 49 | "means": model_params[1], 50 | "features_dc": model_params[2], 51 | "features_rest": model_params[3], 52 | "scaling": model_params[4], 53 | "rotation": model_params[5], 54 | "opacity": model_params[6].squeeze(1), 55 | } 56 | elif format == "gsplat": 57 | 58 | model_params = model["splats"] 59 | splats = { 60 | "active_sh_degree": 3, 61 | "means": model_params["means"], 62 | "features_dc": model_params["sh0"], 63 | "features_rest": model_params["shN"], 64 | "scaling": model_params["scales"], 65 | "rotation": model_params["quats"], 66 | "opacity": model_params["opacities"], 67 | } 68 | elif format == "ply": 69 | plydata = PlyData.read(checkpoint) 70 | vertex = plydata['vertex'].data 71 | 72 | def to_tensor(name, dtype=torch.float32): 73 | return torch.tensor(np.stack([v[name] for v in vertex]), dtype=dtype) 74 | 75 | splats = { 76 | "active_sh_degree": 3, 77 | "means": torch.stack([to_tensor("x"), to_tensor("y"), to_tensor("z")], dim=1), 78 | "features_dc": torch.stack([to_tensor("f_dc_0"), to_tensor("f_dc_1"), to_tensor("f_dc_2")], dim=1).reshape((-1,1,3)), 79 | "features_rest": torch.stack( 80 | [to_tensor(f"f_rest_{i}") for i in range(45)], dim=1 81 | ).reshape((-1,15,3)), 82 | "scaling": torch.stack([to_tensor(f"scale_{i}") for i in range(3)], dim=1), 83 | "rotation": torch.stack([to_tensor(f"rot_{i}") for i in range(4)], dim=1), 84 | "opacity": to_tensor("opacity"), 85 | } 86 | 87 | else: 88 | raise ValueError("Invalid Gaussian splatting format") 89 | 90 | _detach_tensors_from_dict(splats) 91 | 92 | # Assuming only one camera 93 | for camera in colmap_project.cameras.values(): 94 | camera_matrix = torch.tensor( 95 | [ 96 | [camera.fx, 0, camera.cx], 97 | [0, camera.fy, camera.cy], 98 | [0, 0, 1], 99 | ] 100 | ) 101 | break 102 | 103 | camera_matrix[:2, :3] /= data_factor 104 | 105 | splats["camera_matrix"] = camera_matrix 106 | splats["colmap_project"] = colmap_project 107 | splats["colmap_dir"] = data_dir 108 | 109 | return splats 110 | 111 | def load_checkpoint_f3dgs( 112 | checkpoint: str, 113 | data_dir: str, 114 | format: Literal["inria", "gsplat"] = "gsplat", 115 | data_factor: int = 1, 116 | rasterizer: Optional[Literal["inria", "gsplat"]] = None, 117 | ): 118 | # Currently supports only gsplat format 119 | colmap_project = pycolmap.SceneManager(f"{data_dir}/sparse/0") 120 | colmap_project.load_cameras() 121 | colmap_project.load_images() 122 | colmap_project.load_points3D() 123 | model = torch.load(checkpoint) # Make sure it is generated by 3DGS original repo 124 | 125 | if format is None: 126 | if rasterizer is None: 127 | raise ValueError("Must specify format or rasterizer") 128 | else: 129 | format = rasterizer 130 | if rasterizer is not None: 131 | format = rasterizer 132 | warnings.warn( 133 | "`rasterizer` is deprecated. 
Use `format` instead.", DeprecationWarning 134 | ) 135 | if format == "inria": 136 | model_params, _ = model 137 | splats = { 138 | "active_sh_degree": model_params[0], 139 | "means": model_params[1], 140 | "features_dc": model_params[2], 141 | "features_rest": model_params[3], 142 | "scaling": model_params[4], 143 | "rotation": model_params[5], 144 | "opacity": model_params[6].squeeze(1), 145 | } 146 | elif format == "gsplat": 147 | 148 | model_params = model["splats"] 149 | splats = { 150 | "active_sh_degree": 3, 151 | "means": model_params["means"], 152 | "features_dc": model_params["sh0"], 153 | "features_rest": model_params["shN"], 154 | "scaling": model_params["scales"], 155 | "rotation": model_params["quats"], 156 | "opacity": model_params["opacities"], 157 | "conv": model_params["conv"], 158 | "features": model_params["features"], 159 | } 160 | else: 161 | raise ValueError("Invalid Gaussian splatting format") 162 | 163 | _detach_tensors_from_dict(splats) 164 | 165 | # Assuming only one camera 166 | for camera in colmap_project.cameras.values(): 167 | camera_matrix = torch.tensor( 168 | [ 169 | [camera.fx, 0, camera.cx], 170 | [0, camera.fy, camera.cy], 171 | [0, 0, 1], 172 | ] 173 | ) 174 | break 175 | 176 | camera_matrix[:2, :3] /= data_factor 177 | 178 | splats["camera_matrix"] = camera_matrix 179 | splats["colmap_project"] = colmap_project 180 | splats["colmap_dir"] = data_dir 181 | 182 | return splats 183 | 184 | 185 | def get_rpy_matrix(roll, pitch, yaw): 186 | roll_matrix = np.array( 187 | [ 188 | [1, 0, 0, 0], 189 | [0, np.cos(roll), -np.sin(roll), 0], 190 | [0, np.sin(roll), np.cos(roll), 0], 191 | [0, 0, 0, 1.0], 192 | ] 193 | ) 194 | 195 | pitch_matrix = np.array( 196 | [ 197 | [np.cos(pitch), 0, np.sin(pitch), 0], 198 | [0, 1, 0, 0], 199 | [-np.sin(pitch), 0, np.cos(pitch), 0], 200 | [0, 0, 0, 1.0], 201 | ] 202 | ) 203 | yaw_matrix = np.array( 204 | [ 205 | [np.cos(yaw), -np.sin(yaw), 0, 0], 206 | [np.sin(yaw), np.cos(yaw), 0, 0], 207 | [0, 0, 1, 0], 208 | [0, 0, 0, 1.0], 209 | ] 210 | ) 211 | 212 | return yaw_matrix @ pitch_matrix @ roll_matrix 213 | 214 | 215 | def get_viewmat_from_colmap_image(image): 216 | viewmat = torch.eye(4).float() # .to(device) 217 | viewmat[:3, :3] = torch.tensor(image.R()).float() # .to(device) 218 | viewmat[:3, 3] = torch.tensor(image.t).float() # .to(device) 219 | return viewmat 220 | 221 | 222 | def prune_by_gradients(splats): 223 | colmap_project = splats["colmap_project"] 224 | frame_idx = 0 225 | means = splats["means"] 226 | colors_dc = splats["features_dc"] 227 | colors_rest = splats["features_rest"] 228 | colors = torch.cat([colors_dc, colors_rest], dim=1) 229 | opacities = torch.sigmoid(splats["opacity"]) 230 | scales = torch.exp(splats["scaling"]) 231 | quats = splats["rotation"] 232 | 233 | K = splats["camera_matrix"] 234 | colors.requires_grad = True 235 | gaussian_grads = torch.zeros(colors.shape[0], device=colors.device) 236 | for image in sorted(colmap_project.images.values(), key=lambda x: x.name): 237 | viewmat = get_viewmat_from_colmap_image(image) 238 | output, _, _ = rasterization( 239 | means, 240 | quats, 241 | scales, 242 | opacities, 243 | colors[:, 0, :], 244 | viewmats=viewmat[None], 245 | Ks=K[None], 246 | # sh_degree=3, 247 | width=K[0, 2] * 2, 248 | height=K[1, 2] * 2, 249 | ) 250 | frame_idx += 1 251 | pseudo_loss = ((output.detach() + 1 - output) ** 2).mean() 252 | pseudo_loss.backward() 253 | # print(colors.grad.shape) 254 | gaussian_grads += (colors.grad[:, 0]).norm(dim=[1]) 255 | colors.grad.zero_() 256 | 
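    # Splats whose color parameters received no gradient from any training view were never
    # rendered into any image, so they can be dropped without changing the renders
    # (test_proper_pruning below verifies this).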
257 | mask = gaussian_grads > 0 258 | print("Total splats", len(gaussian_grads)) 259 | print("Pruned", (~mask).sum(), "splats") 260 | print("Remaining", mask.sum(), "splats") 261 | splats = splats.copy() 262 | splats["means"] = splats["means"][mask] 263 | splats["features_dc"] = splats["features_dc"][mask] 264 | splats["features_rest"] = splats["features_rest"][mask] 265 | splats["scaling"] = splats["scaling"][mask] 266 | splats["rotation"] = splats["rotation"][mask] 267 | splats["opacity"] = splats["opacity"][mask] 268 | if "features" in splats: 269 | splats["features"] = splats["features"][mask] 270 | 271 | return splats 272 | 273 | 274 | def create_checkerboard(width, height, size=64): 275 | checkerboard = np.zeros((height, width, 3), dtype=np.uint8) 276 | for y in range(0, height, size): 277 | for x in range(0, width, size): 278 | if (x // size + y // size) % 2 == 0: 279 | checkerboard[y : y + size, x : x + size] = 255 280 | else: 281 | checkerboard[y : y + size, x : x + size] = 128 282 | return checkerboard 283 | 284 | 285 | def torch_to_cv(tensor, permute=False): 286 | if permute: 287 | tensor = torch.clamp(tensor.permute(1, 2, 0), 0, 1).cpu().numpy() 288 | else: 289 | tensor = torch.clamp(tensor, 0, 1).cpu().numpy() 290 | return (tensor * 255).astype(np.uint8)[..., ::-1] 291 | 292 | def test_proper_pruning(splats, splats_after_pruning): 293 | colmap_project = splats["colmap_project"] 294 | frame_idx = 0 295 | means = splats["means"] 296 | colors_dc = splats["features_dc"] 297 | colors_rest = splats["features_rest"] 298 | colors = torch.cat([colors_dc, colors_rest], dim=1) 299 | opacities = torch.sigmoid(splats["opacity"]) 300 | scales = torch.exp(splats["scaling"]) 301 | quats = splats["rotation"] 302 | 303 | means_pruned = splats_after_pruning["means"] 304 | colors_dc_pruned = splats_after_pruning["features_dc"] 305 | colors_rest_pruned = splats_after_pruning["features_rest"] 306 | colors_pruned = torch.cat([colors_dc_pruned, colors_rest_pruned], dim=1) 307 | opacities_pruned = torch.sigmoid(splats_after_pruning["opacity"]) 308 | scales_pruned = torch.exp(splats_after_pruning["scaling"]) 309 | quats_pruned = splats_after_pruning["rotation"] 310 | 311 | K = splats["camera_matrix"] 312 | total_error = 0 313 | max_pixel_error = 0 314 | for image in sorted(colmap_project.images.values(), key=lambda x: x.name): 315 | viewmat = get_viewmat_from_colmap_image(image) 316 | output, _, _ = rasterization( 317 | means, 318 | quats, 319 | scales, 320 | opacities, 321 | colors, 322 | viewmats=viewmat[None], 323 | Ks=K[None], 324 | sh_degree=3, 325 | width=K[0, 2] * 2, 326 | height=K[1, 2] * 2, 327 | ) 328 | 329 | output_pruned, _, _ = rasterization( 330 | means_pruned, 331 | quats_pruned, 332 | scales_pruned, 333 | opacities_pruned, 334 | colors_pruned, 335 | viewmats=viewmat[None], 336 | Ks=K[None], 337 | sh_degree=3, 338 | width=K[0, 2] * 2, 339 | height=K[1, 2] * 2, 340 | ) 341 | 342 | total_error += torch.abs((output - output_pruned)).sum() 343 | max_pixel_error = max( 344 | max_pixel_error, torch.abs((output - output_pruned)).max() 345 | ) 346 | 347 | percentage_pruned = ( 348 | (len(splats["means"]) - len(splats_after_pruning["means"])) 349 | / len(splats["means"]) 350 | * 100 351 | ) 352 | 353 | assert max_pixel_error < 1 / ( 354 | 255 * 2 355 | ), "Max pixel error should be less than 1/(255*2), safety margin" 356 | print( 357 | "Report {}% pruned, max pixel error = {}, total pixel error = {}".format( 358 | percentage_pruned, max_pixel_error, total_error 359 | ) 360 | ) 361 | 
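# Usage sketch (illustration only; the paths and data factor below are assumptions taken from
# the script defaults, not part of this module). The segmentation and backprojection scripts
# combine the helpers above roughly as:
#
#   splats = load_checkpoint("./data/garden/ckpts/ckpt_29999_rank0.pt", "./data/garden",
#                            format="gsplat", data_factor=4)
#   pruned = prune_by_gradients(splats)        # drop splats that never contribute to a render
#   test_proper_pruning(splats, pruned)        # sanity check: renders must stay identical
#   viewmat = get_viewmat_from_colmap_image(
#       next(iter(pruned["colmap_project"].images.values()))
#   )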
-------------------------------------------------------------------------------- /click_and_segment.py: -------------------------------------------------------------------------------- 1 | # Basic OpenCV viewer with sliders for rotation and translation. 2 | # Can be easily customizable to different use cases. 3 | import torch 4 | from gsplat import rasterization 5 | import cv2 6 | import tyro 7 | import numpy as np 8 | import json 9 | from typing import Literal 10 | import pycolmap_scene_manager as pycolmap 11 | import clip 12 | from lseg import LSegNet 13 | from utils import ( 14 | load_checkpoint, 15 | get_rpy_matrix, 16 | get_viewmat_from_colmap_image, 17 | prune_by_gradients, 18 | torch_to_cv, 19 | ) 20 | 21 | device = torch.device("cuda:0") 22 | 23 | 24 | def calculate_3d_to_2d(viewmat, fx, fy, cx, cy, position_homo): 25 | x, y, z, _ = position_homo 26 | x = x.item() 27 | y = y.item() 28 | z = z.item() 29 | x1 = viewmat[0, 0] * x + viewmat[0, 1] * y + viewmat[0, 2] * z + viewmat[0, 3] 30 | y1 = viewmat[1, 0] * x + viewmat[1, 1] * y + viewmat[1, 2] * z + viewmat[1, 3] 31 | z1 = viewmat[2, 0] * x + viewmat[2, 1] * y + viewmat[2, 2] * z + viewmat[2, 3] 32 | x = x1 * fx + cx * z1 33 | y = y1 * fy + cy * z1 34 | return int(x / z1), int(y / z1) 35 | 36 | 37 | class UIManager: 38 | def __init__(self, window_name: str): 39 | """ 40 | Manages OpenCV UI components like sliders and mouse callbacks. 41 | 42 | Args: 43 | window_name (str): Name of the OpenCV window. 44 | """ 45 | self.window_name = window_name 46 | self.params = { 47 | "Roll": 0, 48 | "Pitch": 0, 49 | "Yaw": 0, 50 | "X": 0, 51 | "Y": 0, 52 | "Z": 0, 53 | "Scaling": 100, 54 | } 55 | # self.positive_prompt_locations = [] 56 | # self.negative_prompt_locations = [] 57 | self._trigger_callback = lambda: None 58 | self._setup_ui() 59 | 60 | def _setup_ui(self): 61 | """ 62 | Sets up sliders and mouse callbacks for the OpenCV window. 63 | """ 64 | sliders = [ 65 | ("Roll", -180, 0, 180), 66 | ("Pitch", -180, 0, 180), 67 | ("Yaw", -180, 0, 180), 68 | ("X", -1000, 0, 1000), 69 | ("Y", -1000, 0, 1000), 70 | ("Z", -1000, 0, 1000), 71 | ("Scaling", 0, 100, 200), 72 | ] 73 | for slider_name, min_val, default_val, max_val in sliders: 74 | cv2.createTrackbar( 75 | slider_name, 76 | self.window_name, 77 | default_val, 78 | max_val, 79 | self._on_slider_change, 80 | ) 81 | cv2.setTrackbarMin(slider_name, self.window_name, min_val) 82 | 83 | cv2.setMouseCallback(self.window_name, self._on_mouse_event) 84 | 85 | def _on_slider_change(self, value): 86 | """ 87 | Callback for slider changes. 88 | """ 89 | for param in self.params: 90 | self.params[param] = cv2.getTrackbarPos(param, self.window_name) 91 | 92 | def _on_mouse_event(self, event, x, y, flags, param): 93 | """ 94 | Callback for mouse events. 95 | 96 | Args: 97 | event (int): OpenCV mouse event type. 98 | x (int): X-coordinate of the mouse event. 99 | y (int): Y-coordinate of the mouse event. 100 | flags (int): Event flags. 101 | param (Any): Additional parameters. 102 | """ 103 | ctrl_pressed = flags & cv2.EVENT_FLAG_CTRLKEY 104 | trigger = False 105 | xy = None 106 | if event == cv2.EVENT_LBUTTONDOWN: 107 | xy = x, y 108 | trigger = True 109 | elif event == cv2.EVENT_MBUTTONDOWN: 110 | xy = x, y 111 | trigger = True 112 | if trigger: 113 | self._trigger_callback(xy, event, ctrl_pressed) 114 | 115 | def _remove_prompt(self, locations, x, y): 116 | """ 117 | Removes a prompt close to the specified location. 118 | 119 | Args: 120 | locations (list): List of existing prompt locations. 
121 | x (int): X-coordinate. 122 | y (int): Y-coordinate. 123 | """ 124 | del_idx = None 125 | for i, (x_i, y_i) in enumerate(locations): 126 | if abs(x_i - x) < 40 and abs(y_i - y) < 40: 127 | del_idx = i 128 | break 129 | if del_idx is not None: 130 | del locations[del_idx] 131 | 132 | def get_params(self): 133 | """ 134 | Returns the current slider values. 135 | 136 | Returns: 137 | dict: Dictionary of slider names and their values. 138 | """ 139 | return self.params 140 | 141 | def set_trigger_callback(self, callback): 142 | """ 143 | Sets the trigger callback function. 144 | 145 | Args: 146 | callback (function): The callback function. 147 | """ 148 | self._trigger_callback = callback 149 | 150 | 151 | def main( 152 | data_dir: str = "./data/garden", # colmap path 153 | checkpoint: str = "./data/garden/ckpts/ckpt_29999_rank0.pt", # checkpoint path, can generate from original 3DGS repo 154 | rasterizer: Literal[ 155 | "inria", "gsplat" 156 | ] = "gsplat", # Original or GSplat for checkpoints 157 | results_dir: str = "./results/garden", 158 | data_factor: int = 4, 159 | ): 160 | 161 | torch.set_default_device("cuda") 162 | 163 | splats = load_checkpoint( 164 | checkpoint, data_dir, rasterizer=rasterizer, data_factor=data_factor 165 | ) 166 | splats = prune_by_gradients(splats) 167 | torch.set_grad_enabled(False) 168 | 169 | means = splats["means"].float() 170 | opacities = splats["opacity"] 171 | quats = splats["rotation"] 172 | scales = splats["scaling"].float() 173 | 174 | opacities = torch.sigmoid(opacities) 175 | scales = torch.exp(scales) 176 | colors = torch.cat([splats["features_dc"], splats["features_rest"]], 1) 177 | features = torch.load(f"{results_dir}/features_lseg.pt") 178 | 179 | K = splats["camera_matrix"].float() 180 | 181 | width = int(K[0, 2] * 2) 182 | height = int(K[1, 2] * 2) 183 | 184 | cv2.namedWindow("Click and Segment", cv2.WINDOW_NORMAL) 185 | ui_manager = UIManager("Click and Segment") 186 | 187 | net = LSegNet( 188 | backbone="clip_vitl16_384", 189 | features=256, 190 | crop_size=480, 191 | arch_option=0, 192 | block_depth=0, 193 | activation="lrelu", 194 | ) 195 | # Load pre-trained weights 196 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 197 | net.eval() 198 | net.cuda() 199 | 200 | # Preprocess the text prompt 201 | clip_text_encoder = net.clip_pretrained.encode_text 202 | 203 | other_prompt = clip.tokenize(["other"]) 204 | other_prompt = other_prompt.cuda() 205 | other_prompt = clip_text_encoder(other_prompt) # N, 512, N - number of prompts 206 | other_prompt = torch.nn.functional.normalize(other_prompt, dim=1).float() 207 | 208 | mask_3d = None 209 | 210 | positions_3d_positives = [] 211 | positions_3d_negatives = [] 212 | 213 | positive_prompts = torch.zeros(0, 512).to(device) 214 | negative_prompts = other_prompt.to(device) 215 | 216 | def trigger_callback(xy, event, ctrl_pressed): 217 | if xy[0] >= width or xy[1] >= height: 218 | return 219 | params = ui_manager.get_params() 220 | 221 | nonlocal positive_prompts 222 | nonlocal negative_prompts 223 | 224 | roll = params["Roll"] 225 | pitch = params["Pitch"] 226 | yaw = params["Yaw"] 227 | 228 | roll_rad = np.deg2rad(roll) 229 | pitch_rad = np.deg2rad(pitch) 230 | yaw_rad = np.deg2rad(yaw) 231 | 232 | viewmat = ( 233 | torch.tensor(get_rpy_matrix(roll_rad, pitch_rad, yaw_rad)) 234 | .float() 235 | .to(device) 236 | ) 237 | viewmat[0, 3] = params["X"] / 100.0 238 | viewmat[1, 3] = params["Y"] / 100.0 239 | viewmat[2, 3] = params["Z"] / 100.0 240 | scaling = params["Scaling"] 
/ 100.0 241 | output, alphas, meta = rasterization( 242 | means, 243 | quats, 244 | scales * scaling, 245 | opacities, 246 | features, 247 | viewmat[None], 248 | K[None], 249 | width=width, 250 | height=height, 251 | render_mode="RGB+D", 252 | ) 253 | 254 | output, depth = output[0, ..., :512], output[0, ..., 512] 255 | 256 | fx = K[0, 0] 257 | fy = K[1, 1] 258 | cx = K[0, 2] 259 | cy = K[1, 2] 260 | if not ctrl_pressed: 261 | if xy is not None: 262 | Z = depth[xy[1], xy[0]] 263 | XY = ( 264 | torch.tensor([(xy[0] - cx) / fx * Z, (xy[1] - cy) / fy * Z, Z, 1.0]) 265 | .float() 266 | .to(device) 267 | ) 268 | XY = XY.reshape(4, 1) 269 | XY_world = torch.inverse(viewmat) @ XY 270 | if event == cv2.EVENT_LBUTTONDOWN: 271 | positions_3d_positives.append(XY_world.cpu().numpy()) 272 | elif event == cv2.EVENT_MBUTTONDOWN: 273 | positions_3d_negatives.append(XY_world.cpu().numpy()) 274 | output = torch.nn.functional.normalize(output, dim=-1) 275 | 276 | positive_2d_position = [] 277 | negative_2d_position = [] 278 | 279 | for pos in positions_3d_positives: 280 | x, y = calculate_3d_to_2d(viewmat, fx, fy, cx, cy, pos) 281 | positive_2d_position.append((x, y)) 282 | 283 | for pos in positions_3d_negatives: 284 | x, y = calculate_3d_to_2d(viewmat, fx, fy, cx, cy, pos) 285 | negative_2d_position.append((x, y)) 286 | 287 | if not ctrl_pressed and event == cv2.EVENT_LBUTTONDOWN: 288 | positive_prompts = torch.cat([positive_prompts, output[xy[1], xy[0]][None]]) 289 | if not ctrl_pressed and event == cv2.EVENT_MBUTTONDOWN: 290 | negative_prompts = torch.cat([negative_prompts, output[xy[1], xy[0]][None]]) 291 | if ctrl_pressed and event == cv2.EVENT_LBUTTONDOWN: 292 | del_idx = None 293 | for i, (x_i, y_i) in enumerate(positive_2d_position): 294 | if abs(x_i - xy[0]) < 40 and abs(y_i - xy[1]) < 40: 295 | del_idx = i 296 | break 297 | if del_idx is not None: 298 | positive_prompts = torch.cat( 299 | [positive_prompts[:del_idx], positive_prompts[del_idx + 1 :]] 300 | ) 301 | del positions_3d_positives[del_idx] 302 | if ctrl_pressed and event == cv2.EVENT_MBUTTONDOWN: 303 | del_idx = None 304 | for i, (x_i, y_i) in enumerate(negative_2d_position): 305 | if abs(x_i - xy[0]) < 40 and abs(y_i - xy[1]) < 40: 306 | del_idx = i 307 | break 308 | if del_idx is not None: 309 | negative_prompts = torch.cat( 310 | [negative_prompts[: del_idx + 1], negative_prompts[del_idx + 2 :]] 311 | ) 312 | del positions_3d_negatives[del_idx] 313 | nonlocal mask_3d 314 | if not positions_3d_positives: 315 | mask_3d = None 316 | else: 317 | scores_pos = features @ positive_prompts.T # [N, P] 318 | scores_pos = scores_pos.max(dim=1) # [N] 319 | scores_neg = features @ negative_prompts.T # [N, P] 320 | scores_neg = scores_neg.max(dim=1) # [N] 321 | mask_3d = scores_pos.values > scores_neg.values 322 | 323 | ui_manager.set_trigger_callback(trigger_callback) 324 | 325 | while True: 326 | for image in splats["colmap_project"].images.values(): 327 | viewmat_cmap = get_viewmat_from_colmap_image(image) 328 | roll = ui_manager.params["Roll"] 329 | pitch = ui_manager.params["Pitch"] 330 | yaw = ui_manager.params["Yaw"] 331 | 332 | roll_rad = np.deg2rad(roll) 333 | pitch_rad = np.deg2rad(pitch) 334 | yaw_rad = np.deg2rad(yaw) 335 | 336 | viewmat = ( 337 | torch.tensor(get_rpy_matrix(roll_rad, pitch_rad, yaw_rad)) 338 | .float() 339 | .to(device) 340 | ) 341 | viewmat[0, 3] = ui_manager.params["X"] / 100.0 342 | viewmat[1, 3] = ui_manager.params["Y"] / 100.0 343 | viewmat[2, 3] = ui_manager.params["Z"] / 100.0 344 | scaling = 
ui_manager.params["Scaling"] / 100.0 345 | output, alphas, meta = rasterization( 346 | means, 347 | quats, 348 | scales * scaling, 349 | opacities, 350 | colors, 351 | viewmat[None], 352 | K[None], 353 | width=width, 354 | height=height, 355 | sh_degree=3, 356 | ) 357 | 358 | output_cv = torch_to_cv(output[0]) 359 | 360 | if mask_3d is not None: 361 | opacities_extracted = opacities.clone() 362 | opacities_deleted = opacities.clone() 363 | opacities_extracted[~mask_3d] = 0 364 | opacities_deleted[mask_3d] = 0 365 | else: 366 | opacities_extracted = opacities 367 | opacities_deleted = opacities * 0 368 | output, alphas, meta = rasterization( 369 | means, 370 | quats, 371 | scales * scaling, 372 | opacities_extracted, 373 | colors, 374 | viewmat_cmap[None], 375 | K[None], 376 | width=width, 377 | height=height, 378 | sh_degree=3, 379 | ) 380 | output_cv2 = torch_to_cv(output[0]) 381 | output, alphas, meta = rasterization( 382 | means, 383 | quats, 384 | scales * scaling, 385 | opacities_deleted, 386 | colors, 387 | viewmat_cmap[None], 388 | K[None], 389 | width=width, 390 | height=height, 391 | sh_degree=3, 392 | ) 393 | output_cv3 = torch_to_cv(output[0]) 394 | output_cv = cv2.hconcat([output_cv, output_cv2, output_cv3]) 395 | 396 | fx = K[0, 0] 397 | fy = K[1, 1] 398 | cx = K[0, 2] 399 | cy = K[1, 2] 400 | viewmat = viewmat.cpu().numpy() 401 | for pos in positions_3d_positives: 402 | x, y = calculate_3d_to_2d(viewmat, fx, fy, cx, cy, pos) 403 | cv2.circle(output_cv, (x, y), 10, (0, 255, 0), -1) 404 | for pos in positions_3d_negatives: 405 | x, y = calculate_3d_to_2d(viewmat, fx, fy, cx, cy, pos) 406 | cv2.circle(output_cv, (x, y), 10, (0, 0, 255), -1) 407 | cv2.imshow("Click and Segment", output_cv) 408 | key = cv2.waitKey(10) 409 | if key == ord("q"): 410 | break 411 | if key == ord("q"): 412 | break 413 | 414 | 415 | if __name__ == "__main__": 416 | tyro.cli(main) 417 | -------------------------------------------------------------------------------- /viewer.py: -------------------------------------------------------------------------------- 1 | # Basic OpenCV viewer with sliders for rotation and translation. 2 | # Can be easily customizable to different use cases. 3 | from dataclasses import dataclass 4 | from typing import Literal, Optional 5 | import torch 6 | from gsplat import rasterization 7 | import cv2 8 | from scipy.spatial.transform import Rotation as scipyR 9 | import pycolmap_scene_manager as pycolmap 10 | import warnings 11 | 12 | import numpy as np 13 | import json 14 | import tyro 15 | 16 | from utils import ( 17 | get_rpy_matrix, 18 | get_viewmat_from_colmap_image, 19 | prune_by_gradients, 20 | torch_to_cv, 21 | load_checkpoint, 22 | ) 23 | 24 | # Check if CUDA is available. Else raise an error. 25 | if not torch.cuda.is_available(): 26 | raise RuntimeError( 27 | "CUDA is not available. Please install the correct version of PyTorch with CUDA support." 28 | ) 29 | 30 | device = torch.device("cuda") 31 | torch.set_default_device("cuda") 32 | 33 | 34 | @dataclass 35 | class Args: 36 | checkpoint: str # Path to the 3DGS checkpoint file (.pth/.pt) to be visualized. 37 | data_dir: ( 38 | str # Path to the COLMAP project directory containing sparse reconstruction. 39 | ) 40 | format: Optional[Literal["inria", "gsplat", "ply"]] = ( 41 | "gsplat" # Format of the checkpoint: 'inria' (original 3DGS), 'gsplat', or 'ply'. 42 | ) 43 | rasterizer: Optional[Literal["inria", "gsplat"]] = ( 44 | None # [Deprecated] Use --format instead. Provided for backward compatibility. 
45 | ) 46 | data_factor: int = 4 # Downscaling factor for the renderings. 47 | turntable: bool = False # Whether to use a turntable mode for the viewer. 48 | 49 | 50 | @dataclass 51 | class ViewerArgs: 52 | turntable: bool = False 53 | 54 | 55 | class Viewer: 56 | def __init__(self, splats, viewer_args): 57 | self.splats = None 58 | self.camera_matrix = None 59 | self.width = None 60 | self.height = None 61 | self.viewmat = None 62 | self.window_name = "GSplat Explorer" 63 | self._init_sliders() 64 | self._load_splats(splats) 65 | self.turntable = viewer_args.turntable 66 | self.mouse_down = False 67 | self.mouse_x = 0 68 | self.mouse_y = 0 69 | cv2.setMouseCallback(self.window_name, self.handle_mouse_event) 70 | 71 | def _load_splats(self, splats): 72 | K = splats["camera_matrix"].cuda() 73 | width = int(K[0, 2] * 2) 74 | height = int(K[1, 2] * 2) 75 | 76 | means = splats["means"].float() 77 | opacities = splats["opacity"] 78 | quats = splats["rotation"] 79 | scales = splats["scaling"].float() 80 | 81 | opacities = torch.sigmoid(opacities) 82 | scales = torch.exp(scales) 83 | colors = torch.cat([splats["features_dc"], splats["features_rest"]], 1) 84 | 85 | self.splats = splats 86 | self.camera_matrix = K 87 | self.width = width 88 | self.height = height 89 | self.means = means 90 | self.opacities = opacities 91 | self.quats = quats 92 | self.scales = scales 93 | self.colors = colors 94 | 95 | def _init_sliders(self): 96 | cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL) 97 | trackbars = { 98 | "Roll": (-180, 180), 99 | "Pitch": (-180, 180), 100 | "Yaw": (-180, 180), 101 | "X": (-1000, 1000), 102 | "Y": (-1000, 1000), 103 | "Z": (-1000, 1000), 104 | "Scaling": (0, 100), 105 | } 106 | 107 | for name, (min_val, max_val) in trackbars.items(): 108 | cv2.createTrackbar(name, self.window_name, 0, max_val, lambda x: None) 109 | cv2.setTrackbarMin(name, self.window_name, min_val) 110 | cv2.setTrackbarMax(name, self.window_name, max_val) 111 | 112 | cv2.setTrackbarPos( 113 | "Scaling", self.window_name, 100 114 | ) # Default value for scaling is 100 115 | 116 | def update_trackbars_from_viewmat(self, world_to_camera): 117 | # if torch tensor is passed, convert to numpy 118 | if isinstance(world_to_camera, torch.Tensor): 119 | world_to_camera = world_to_camera.cpu().numpy() 120 | r = scipyR.from_matrix(world_to_camera[:3, :3]) 121 | roll, pitch, yaw = r.as_euler("xyz") 122 | cv2.setTrackbarPos("Roll", self.window_name, np.rad2deg(roll).astype(int)) 123 | cv2.setTrackbarPos("Pitch", self.window_name, np.rad2deg(pitch).astype(int)) 124 | cv2.setTrackbarPos("Yaw", self.window_name, np.rad2deg(yaw).astype(int)) 125 | cv2.setTrackbarPos("X", self.window_name, int(world_to_camera[0, 3] * 100)) 126 | cv2.setTrackbarPos("Y", self.window_name, int(world_to_camera[1, 3] * 100)) 127 | cv2.setTrackbarPos("Z", self.window_name, int(world_to_camera[2, 3] * 100)) 128 | 129 | def get_special_viewmat(self, viewmat, side="top"): 130 | if isinstance(viewmat, torch.Tensor): 131 | viewmat = viewmat.cpu().numpy() 132 | if not self.turntable: 133 | warnings.warn("Top view is only available in turntable mode.") 134 | return viewmat 135 | world_to_pcd = np.eye(4) 136 | 137 | # Trick: just put the new axes in columns, done! 
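# The 3x3 block assembled below stacks the scene frame's basis vectors
# (view_direction, cross(upvector, view_direction), upvector) as rows and then
# transposes, so they end up as columns. A rotation whose columns are the new
# frame's axes, expressed in world coordinates, maps coordinates given in that
# frame into world coordinates; together with center_point as the translation
# this gives the scene-frame pose, and its inverse (pcd_to_world) carries world
# coordinates into the scene frame before the canonical top/front/right pose is
# applied.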
138 | world_to_pcd[:3, :3] = np.array( 139 | [ 140 | self.view_direction, 141 | np.cross(self.upvector, self.view_direction), 142 | self.upvector, 143 | ] 144 | ).T 145 | world_to_pcd[:3, 3] = self.center_point 146 | pcd_to_world = np.linalg.inv(world_to_pcd) 147 | 148 | world_to_camera = np.eye(4) 149 | if side == "top": 150 | world_to_camera[:3, :3] = np.array( 151 | [ 152 | [1, 0, 0], 153 | [0, -1, 0], 154 | [0, 0, -1], 155 | ] 156 | ).T 157 | elif side == "front": 158 | world_to_camera[:3, :3] = np.array( 159 | [ 160 | [1, 0, 0], 161 | [0, 0, 1], 162 | [0, -1, 0], 163 | ] 164 | ).T 165 | elif side == "right": 166 | world_to_camera[:3, :3] = np.array( 167 | [ 168 | [0, 0, -1], 169 | [1, 0, 0], 170 | [0, -1, 0], 171 | ] 172 | ).T 173 | else: 174 | warnings.warn(f"Unknown view type: {side}.") 175 | 176 | world_to_camera_before = viewmat @ world_to_pcd 177 | dist = np.linalg.norm(world_to_camera_before[:3, 3]) 178 | world_to_camera[:3, 3] = np.array([0, 0, dist]) 179 | 180 | # cam_point = world_to_camera @ pcd_to_world @ pcd_coord 181 | # cam_point = viewmat @ pcd_coord 182 | # viewmat = world_to_camera @ pcd_to_world 183 | viewmat = world_to_camera @ pcd_to_world 184 | viewmat = torch.tensor(viewmat).float().to(device) 185 | return viewmat 186 | 187 | def _get_viewmat_from_trackbars(self): 188 | roll = cv2.getTrackbarPos("Roll", self.window_name) 189 | pitch = cv2.getTrackbarPos("Pitch", self.window_name) 190 | yaw = cv2.getTrackbarPos("Yaw", self.window_name) 191 | 192 | roll_rad = np.deg2rad(roll) 193 | pitch_rad = np.deg2rad(pitch) 194 | yaw_rad = np.deg2rad(yaw) 195 | 196 | viewmat = ( 197 | torch.tensor(get_rpy_matrix(roll_rad, pitch_rad, yaw_rad)) 198 | .float() 199 | .to(device) 200 | ) 201 | 202 | viewmat[0, 3] = cv2.getTrackbarPos("X", self.window_name) / 100.0 203 | viewmat[1, 3] = cv2.getTrackbarPos("Y", self.window_name) / 100.0 204 | viewmat[2, 3] = cv2.getTrackbarPos("Z", self.window_name) / 100.0 205 | 206 | return viewmat 207 | 208 | def render_gaussians(self, viewmat, scaling, anaglyph=False): 209 | output, _, _ = rasterization( 210 | self.means, 211 | self.quats, 212 | self.scales * scaling, 213 | self.opacities, 214 | self.colors, 215 | viewmat[None], 216 | self.camera_matrix[None], 217 | width=self.width, 218 | height=self.height, 219 | sh_degree=3, 220 | ) 221 | if not anaglyph: 222 | return np.ascontiguousarray(torch_to_cv(output[0])) 223 | left = torch_to_cv(output[0]) 224 | viewmat_right_eye = viewmat.clone() 225 | viewmat_right_eye[0, 3] -= 0.05 # Offset for the right eye 226 | output, _, _ = rasterization( 227 | self.means, 228 | self.quats, 229 | self.scales * scaling, 230 | self.opacities, 231 | self.colors, 232 | viewmat_right_eye[None], 233 | self.camera_matrix[None], 234 | width=self.width, 235 | height=self.height, 236 | sh_degree=3, 237 | ) 238 | right = torch_to_cv(output[0]) 239 | left_copy = left.copy() 240 | right_copy = right.copy() 241 | left_copy[..., :2] = 0 # Set left eye's red and green channels to zero 242 | right_copy[..., -1] = 0 # Set right eye's blue channel to zero 243 | return ( 244 | left_copy + right_copy, 245 | np.ascontiguousarray(left_copy), 246 | np.ascontiguousarray(right_copy), 247 | ) 248 | 249 | def compute_world_frame(self): 250 | """ 251 | Compute the new world frame (center_point, upvector, view_direction, ortho_direction) 252 | based on the average camera positions and orientations. 
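The up vector is the normalized mean of the cameras' -Y axes, the view
direction is the mean of their +Z axes re-orthogonalized against the up
vector, and ortho_direction completes the right-handed triad. The frame
origin (center_point) is taken as the mean of the Gaussian means.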
253 | """ 254 | # Initialize vectors 255 | center_point = np.zeros(3, dtype=np.float32) 256 | upvector_sum = np.zeros(3, dtype=np.float32) 257 | view_direction_sum = np.zeros(3, dtype=np.float32) 258 | 259 | # Iterate over camera images to compute average position and orientation 260 | for image in self.splats["colmap_project"].images.values(): 261 | viewmat = get_viewmat_from_colmap_image(image) 262 | viewmat_np = viewmat.cpu().numpy() 263 | c2w = np.linalg.inv(viewmat_np) 264 | center_point += c2w[:3, 3].squeeze() # camera position 265 | upvector_sum += -c2w[:3, 1].squeeze() # up direction 266 | view_direction_sum += c2w[:3, 2].squeeze() # viewing direction 267 | 268 | # Average position and orientation vectors 269 | num_images = len(self.splats["colmap_project"].images) 270 | center_point /= num_images 271 | upvector = upvector_sum / np.linalg.norm(upvector_sum) 272 | view_direction = view_direction_sum / np.linalg.norm(view_direction_sum) 273 | 274 | # Make view_direction orthogonal to upvector 275 | view_direction -= upvector * np.dot(view_direction, upvector) 276 | view_direction /= np.linalg.norm(view_direction) 277 | 278 | # Compute the orthogonal direction (right vector) 279 | ortho_direction = np.cross(upvector, view_direction) 280 | ortho_direction /= np.linalg.norm(ortho_direction) 281 | 282 | # Optionally override center_point with the mean of your 3D data 283 | center_point = torch.mean(self.means, dim=0).cpu().numpy() 284 | 285 | # Save the computed frame vectors as attributes 286 | self.center_point = center_point 287 | self.upvector = upvector 288 | self.view_direction = view_direction 289 | self.ortho_direction = ortho_direction 290 | 291 | def visualize_world_frame(self, output_cv, viewmat): 292 | viewmat_np = viewmat.cpu().numpy() 293 | T = np.eye(4) 294 | z_axis = self.upvector 295 | x_axis = self.view_direction 296 | y_axis = np.cross(z_axis, x_axis) 297 | T[:3, :3] = np.array([x_axis, y_axis, z_axis]).T 298 | T[:3, 3] = self.center_point 299 | T = viewmat_np @ T 300 | rvec = cv2.Rodrigues(T[:3, :3])[0] 301 | tvec = T[:3, 3] 302 | cv2.drawFrameAxes( 303 | output_cv, 304 | self.camera_matrix.cpu().numpy(), 305 | None, 306 | rvec, 307 | tvec, 308 | length=1, 309 | thickness=2, 310 | ) 311 | 312 | def run(self): 313 | """Run the interactive Gaussian Splat viewer loop once until exit.""" 314 | self.show_anaglyph = False 315 | self.compute_world_frame() 316 | 317 | while True: 318 | scaling = cv2.getTrackbarPos("Scaling", self.window_name) / 100.0 319 | viewmat = self._get_viewmat_from_trackbars() 320 | 321 | if self.show_anaglyph: 322 | output_cv, _, _ = self.render_gaussians(viewmat, scaling, anaglyph=True) 323 | else: 324 | output_cv = self.render_gaussians(viewmat, scaling) 325 | 326 | if self.turntable: 327 | self.visualize_world_frame(output_cv, viewmat) 328 | 329 | cv2.imshow(self.window_name, output_cv) 330 | full_key = cv2.waitKeyEx(1) 331 | key = full_key & 0xFF 332 | 333 | should_continue = self.handle_key_press(key, {"viewmat": viewmat}) 334 | if not should_continue: 335 | break 336 | 337 | cv2.destroyAllWindows() 338 | 339 | def handle_key_press(self, key, data): 340 | viewmat = data["viewmat"] 341 | if key == ord("q") or key == 27: 342 | return False # Exit the viewer 343 | if key == ord("3"): 344 | self.show_anaglyph = not self.show_anaglyph 345 | if key in [ord("w"), ord("a"), ord("s"), ord("d")]: 346 | # Modify viewmat and sync UI 347 | delta = 0.1 348 | if key == ord("w"): 349 | viewmat[2, 3] -= delta 350 | elif key == ord("s"): 351 | viewmat[2, 3] += 
delta 352 | elif key == ord("a"): 353 | viewmat[0, 3] += delta 354 | elif key == ord("d"): 355 | viewmat[0, 3] -= delta 356 | self.update_trackbars_from_viewmat(viewmat) 357 | if key in [ord("7")]: 358 | viewmat = self.get_special_viewmat(viewmat, side="top") 359 | self.update_trackbars_from_viewmat(viewmat) 360 | elif key in [ord("8")]: 361 | viewmat = self.get_special_viewmat(viewmat, side="front") 362 | self.update_trackbars_from_viewmat(viewmat) 363 | elif key in [ord("9")]: 364 | viewmat = self.get_special_viewmat(viewmat, side="right") 365 | self.update_trackbars_from_viewmat(viewmat) 366 | return True # Continue the viewer loop 367 | 368 | def handle_mouse_event(self, event, x, y, flags, param): 369 | if not self.turntable: 370 | return 371 | if event == cv2.EVENT_LBUTTONDOWN: 372 | self.mouse_down = True 373 | self.mouse_x = x 374 | self.mouse_y = y 375 | self.view_mat_progress = self._get_viewmat_from_trackbars() 376 | self.is_alt_pressed = flags & cv2.EVENT_FLAG_ALTKEY 377 | self.is_shift_pressed = flags & cv2.EVENT_FLAG_SHIFTKEY 378 | self.is_ctrl_pressed = flags & cv2.EVENT_FLAG_CTRLKEY 379 | elif event == cv2.EVENT_LBUTTONUP: 380 | self.mouse_down = False 381 | elif event == cv2.EVENT_MOUSEMOVE and self.mouse_down: 382 | dx = x - self.mouse_x 383 | dy = y - self.mouse_y 384 | 385 | if self.is_ctrl_pressed: 386 | viewmat = self._get_viewmat_from_trackbars() 387 | viewmat[2, 3] += dy / self.height * 10 # Move camera forward/backward 388 | self.update_trackbars_from_viewmat(viewmat) 389 | self.mouse_x = x 390 | self.mouse_y = y 391 | return 392 | 393 | # viewmat = self._get_viewmat_from_trackbars() 394 | viewmat = self.view_mat_progress.clone() 395 | viewmat_np = viewmat.cpu().numpy() # w2c 396 | world_to_pcd = np.eye(4) 397 | world_to_pcd[:3, :3] = np.array( 398 | [ 399 | self.view_direction, 400 | np.cross(self.upvector, self.view_direction), 401 | self.upvector, 402 | ] 403 | ).T 404 | world_to_pcd[:3, 3] = self.center_point 405 | pcd_to_world = np.linalg.inv(world_to_pcd) 406 | # camera_coordinates = viewmat @ world_to_pcd @ transform @ pcd_to_world @ pcd_coods 407 | # camera_coordinates = viewmat_new @ pcd_coords 408 | # ie. 
viewmat_new = viewmat @ world_to_pcd @ transform @ pcd_to_world 409 | transform = np.eye(4) 410 | height, width = self.height, self.width 411 | if self.is_shift_pressed: 412 | viewmat_np[0, 3] += dx / width * 10 413 | viewmat_np[1, 3] += dy / height * 10 414 | else: 415 | # Rotation of the world 416 | c2pcd = np.linalg.inv(viewmat_np) 417 | c2w = pcd_to_world @ c2pcd 418 | direction_with_respect_to_world = -c2w[:3, 2] 419 | lambda_ = -c2w[2, 3] / direction_with_respect_to_world[2] 420 | intersection_point = ( 421 | c2w[:3, 3] + lambda_ * direction_with_respect_to_world 422 | ) 423 | 424 | world_to_intersection = np.eye(4) 425 | world_to_intersection[:3, 3] = -intersection_point 426 | intersection_to_world = np.linalg.inv(world_to_intersection) 427 | transform = get_rpy_matrix(0, 0, dx / width * 10) 428 | if self.is_alt_pressed: 429 | world_to_intersection = np.eye(4) 430 | intersection_to_world = np.eye(4) 431 | 432 | # rotation of camera 433 | viewmat_np[:3, :3] = ( 434 | get_rpy_matrix(dy / height * 10, 0, 0)[:3, :3] @ viewmat_np[:3, :3] 435 | ) 436 | 437 | # rotation of the world 438 | viewmat_np = ( 439 | viewmat_np 440 | @ world_to_pcd 441 | @ intersection_to_world 442 | @ transform 443 | @ world_to_intersection 444 | @ pcd_to_world 445 | ) 446 | 447 | 448 | self.update_trackbars_from_viewmat( 449 | torch.tensor(viewmat_np).float().to(device) 450 | ) 451 | 452 | 453 | def main(args: Args): 454 | format = args.format or args.rasterizer 455 | if args.rasterizer: 456 | warnings.warn( 457 | "`rasterizer` is deprecated. Use `format` instead.", DeprecationWarning 458 | ) 459 | if not format: 460 | raise ValueError("Must specify --format or the deprecated --rasterizer") 461 | 462 | splats = load_checkpoint(args.checkpoint, args.data_dir, format, args.data_factor) 463 | splats = prune_by_gradients(splats) 464 | 465 | if args.turntable: 466 | viewer_args = ViewerArgs(turntable=True) 467 | else: 468 | viewer_args = ViewerArgs(turntable=False) 469 | 470 | viewer = Viewer(splats, viewer_args=viewer_args) 471 | viewer.run() 472 | 473 | 474 | if __name__ == "__main__": 475 | args = tyro.cli(Args) 476 | main(args) 477 | -------------------------------------------------------------------------------- /viewer_with_llm.py: -------------------------------------------------------------------------------- 1 | # Basic OpenCV viewer with sliders for rotation and translation. 2 | # Can be easily customizable to different use cases. 3 | from dataclasses import dataclass 4 | from typing import Literal, Optional 5 | import torch 6 | from gsplat import rasterization 7 | import cv2 8 | import clip 9 | from scipy.spatial.transform import Rotation as scipyR 10 | import pycolmap_scene_manager as pycolmap 11 | import warnings 12 | from torchvision.transforms import functional as TF 13 | 14 | import numpy as np 15 | import json 16 | import tyro 17 | from lseg import LSegNet 18 | 19 | from utils import ( 20 | get_rpy_matrix, 21 | get_viewmat_from_colmap_image, 22 | prune_by_gradients, 23 | torch_to_cv, 24 | load_checkpoint, 25 | ) 26 | 27 | # Check if CUDA is available. Else raise an error. 28 | if not torch.cuda.is_available(): 29 | raise RuntimeError( 30 | "CUDA is not available. Please install the correct version of PyTorch with CUDA support." 
31 | ) 32 | 33 | device = torch.device("cuda") 34 | torch.set_default_device("cuda") 35 | 36 | from transformers import AutoModelForCausalLM, AutoTokenizer 37 | import torch 38 | 39 | from transformers import pipeline 40 | 41 | 42 | 43 | from typing import List, Dict, Optional 44 | import json 45 | import warnings 46 | 47 | def get_mask3d_lseg(splats, features, prompt, neg_prompt, threshold=None): 48 | 49 | net = LSegNet( 50 | backbone="clip_vitl16_384", 51 | features=256, 52 | crop_size=480, 53 | arch_option=0, 54 | block_depth=0, 55 | activation="lrelu", 56 | ) 57 | # Load pre-trained weights 58 | net.load_state_dict(torch.load("./checkpoints/lseg_minimal_e200.ckpt")) 59 | net.eval() 60 | net.cuda() 61 | 62 | # Preprocess the text prompt 63 | clip_text_encoder = net.clip_pretrained.encode_text 64 | 65 | positive_prompts_length = len(prompt.split(";")) 66 | 67 | prompts = prompt.split(";") + neg_prompt.split(";") 68 | 69 | prompts = clip.tokenize(prompts) 70 | prompts = prompts.cuda() 71 | 72 | text_feat = clip_text_encoder(prompts) # N, 512, N - number of prompts 73 | text_feat_norm = torch.nn.functional.normalize(text_feat, dim=1) 74 | 75 | features = torch.nn.functional.normalize(features, dim=1) 76 | score = features @ text_feat_norm.float().T 77 | mask_3d = score[:, :positive_prompts_length].max(dim=1)[0] > score[:, positive_prompts_length:].max(dim=1)[0] 78 | if threshold is not None: 79 | mask_3d = mask_3d & (score[:, 0] > threshold) 80 | mask_3d_inv = ~mask_3d 81 | 82 | return mask_3d, mask_3d_inv 83 | 84 | COLOR_TO_RGB = { 85 | "red": [1.0, 0.0, 0.0], 86 | "green": [0.0, 1.0, 0.0], 87 | "blue": [0.0, 0.0, 1.0], 88 | "yellow": [1.0, 1.0, 0.0], 89 | "cyan": [0.0, 1.0, 1.0], 90 | "magenta": [1.0, 0.0, 1.0], 91 | "white": [1.0, 1.0, 1.0], 92 | "black": [0.0, 0.0, 0.0], 93 | } 94 | 95 | class Assistant: 96 | def __init__(self): 97 | model_id = "mistralai/Mistral-7B-Instruct-v0.3" 98 | tokenizer = AutoTokenizer.from_pretrained(model_id) 99 | model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") 100 | self.model = model 101 | self.tokenizer = tokenizer 102 | 103 | # Emulating system prompt with a series of example interactions 104 | self.system_prompt = [ 105 | { 106 | "role": "user", 107 | "content": """ 108 | You are a 3DGS viewer assistant. You understand commands like changing the view (top, front, right), 109 | segment 3d gaussians, change color, and exiting the application. 110 | Always respond in strict JSON format: {"request": "", "side": ""} or {"request": "exit"}. 111 | If the command is unclear, respond with {"request": "unknown"}. 112 | """, 113 | }, 114 | { 115 | "role": "assistant", 116 | "content": """ 117 | Ok! 118 | """, 119 | }, 120 | { 121 | "role": "user", 122 | "content": """ 123 | Can you show me the front view? 124 | """, 125 | }, 126 | { 127 | "role": "assistant", 128 | "content": """ 129 | {"request": "change_view", "side": "front"} 130 | """, 131 | }, 132 | { 133 | "role": "user", 134 | "content": """ 135 | Can you show me the right view? 136 | """, 137 | }, 138 | { 139 | "role": "assistant", 140 | "content": """ 141 | {"request": "change_view", "side": "right"} 142 | """ 143 | }, 144 | { 145 | "role": "user", 146 | "content": """ 147 | Please change view to the top view. 
148 | """, 149 | }, 150 | { 151 | "role": "assistant", 152 | "content": """ 153 | {"request": "change_view", "side": "top"} 154 | """, 155 | }, 156 | { 157 | "role": "user", 158 | "content": """ 159 | Abracadabra 160 | """, 161 | }, 162 | { 163 | "role": "assistant", 164 | "content": """ 165 | {"request": "unknown"} 166 | """, 167 | }, 168 | { 169 | "role": "user", 170 | "content": """ 171 | Who are you? 172 | """, 173 | }, 174 | { 175 | "role": "assistant", 176 | "content": """ 177 | {"request": "unknown"} 178 | """, 179 | }, 180 | { 181 | "role": "user", 182 | "content": """ 183 | Can you exit? 184 | """, 185 | }, 186 | { 187 | "role": "assistant", 188 | "content": """ 189 | {"request": "exit", "message": "Goodbye!"} 190 | """, 191 | }, 192 | { 193 | "role": "user", 194 | "content": """ 195 | Bye, please quit. 196 | """, 197 | }, 198 | { 199 | "role": "assistant", 200 | "content": """ 201 | {"request": "exit", "message": "Goodbye!"} 202 | """, 203 | }, 204 | { 205 | "role": "user", 206 | "content": """ 207 | Can you segment 3D Gaussians with the prompt "car" and negative prompt "tree"? 208 | """, 209 | }, 210 | { 211 | "role": "assistant", 212 | "content": """ 213 | {"request": "segment", "prompt": "car", "neg_prompt": "tree"} 214 | """, 215 | }, 216 | { 217 | "role": "user", 218 | "content": """ 219 | Can you segment 3D Gaussians with the prompt "table"? 220 | """, 221 | }, 222 | { 223 | "role": "assistant", 224 | "content": """ 225 | {"request": "segment", "prompt": "table", "neg_prompt": "none"} 226 | """, 227 | }, 228 | { 229 | "role": "user", 230 | "content": """ 231 | Can you segment 3D Gaussians containing table and exclude plant? 232 | """, 233 | }, 234 | { 235 | "role": "assistant", 236 | "content": """ 237 | {"request": "segment", "prompt": "table", "neg_prompt": "plant"} 238 | """, 239 | }, 240 | { 241 | "role": "user", 242 | "content": """Can you reset segmentation?""", 243 | }, 244 | { 245 | "role": "assistant", 246 | "content": """ 247 | {"request": "reset_segmentation"} 248 | """, 249 | }, 250 | { 251 | "role": "user", 252 | "content": "Change the color of grass to red.", 253 | }, { 254 | "role": "assistant", 255 | "content": """ 256 | {"request": "change_color", "object": "grass", "color": "red"} 257 | """, 258 | }, 259 | { 260 | "role": "user", 261 | "content": "Change the color of table to blue.", 262 | }, { 263 | "role": "assistant", 264 | "content": """ 265 | {"request": "change_color", "object": "table", "color": "blue"} 266 | """, 267 | }, { 268 | "role": "user", 269 | "content": "Reset the color of all objects.", 270 | }, { 271 | "role": "assistant", 272 | "content": """ 273 | {"request": "reset_color"} 274 | """, 275 | } 276 | ] 277 | self.pipeline = pipeline( 278 | "text-generation", 279 | model=self.model, 280 | tokenizer=self.tokenizer, 281 | torch_dtype=torch.bfloat16, 282 | device_map="auto",) 283 | 284 | def ask(self, query: str, max_new_tokens: int = 512, temperature: float = 0.7) -> Optional[Dict]: 285 | if query.startswith("`"): 286 | query = query[1:] 287 | output = self.pipeline(self.system_prompt + [{'role': 'user', 'content': query}], 288 | max_new_tokens=200, 289 | do_sample=True, 290 | temperature=temperature, 291 | ) 292 | response = output[0]['generated_text'][-1]["content"].strip() 293 | # Try to extract JSON from response 294 | try: 295 | json_start = response.find("{") 296 | json_end = response.rfind("}") + 1 297 | response_str = response[json_start:json_end] 298 | parsed_response = json.loads(response_str) 299 | return parsed_response 300 | 
except json.JSONDecodeError as e: 301 | warnings.warn(f"Failed to parse JSON from response: {e}\nResponse:\n{response}") 302 | return None 303 | 304 | def __call__(self, query: str, max_new_tokens: int = 512, temperature: float = 0.7) -> Optional[Dict]: 305 | """ 306 | Call the assistant with a query. 307 | 308 | Args: 309 | query: User's question or command. 310 | max_new_tokens: Maximum number of tokens to generate. 311 | temperature: Sampling temperature. 312 | 313 | Returns: 314 | Parsed JSON response as dict, or None if parsing fails. 315 | """ 316 | return self.ask(query, max_new_tokens, temperature) 317 | 318 | 319 | @dataclass 320 | class Args: 321 | checkpoint: str # Path to the 3DGS checkpoint file (.pth/.pt) to be visualized. 322 | data_dir: ( 323 | str # Path to the COLMAP project directory containing sparse reconstruction. 324 | ) 325 | format: Optional[Literal["inria", "gsplat", "ply"]] = ( 326 | "gsplat" # Format of the checkpoint: 'inria' (original 3DGS), 'gsplat', or 'ply'. 327 | ) 328 | rasterizer: Optional[Literal["inria", "gsplat"]] = ( 329 | None # [Deprecated] Use --format instead. Provided for backward compatibility. 330 | ) 331 | data_factor: int = 4 # Downscaling factor for the renderings. 332 | turntable: bool = True # Whether to use a turntable mode for the viewer. 333 | lseg_checkpoint: Optional[str] = None # Path to the LSEG checkpoint for segmentation, if available. 334 | 335 | 336 | @dataclass 337 | class ViewerArgs: 338 | turntable: bool = False 339 | 340 | from viewer import Viewer 341 | 342 | class ViewerWithAssistant(Viewer): 343 | def __init__(self, splats, viewer_args): 344 | super().__init__(splats, viewer_args) 345 | self.assistant = Assistant() 346 | self.current_view = None 347 | self.assistant_mode = False 348 | self.user_query = "" 349 | self.mask3d = None 350 | print(splats.keys()) 351 | self.opacities_backup = torch.sigmoid(splats["opacity"].clone().detach()) 352 | self.colors_backup = self.colors.clone().detach() 353 | def handle_key_press(self, key, data): 354 | if key == ord("`"): 355 | self.assistant_mode = True 356 | if not self.assistant_mode: 357 | should_continue = super().handle_key_press(key, data) 358 | return should_continue 359 | # if key == ord("q") or key == 27: 360 | # return False 361 | # Check if the key is backspace 362 | if key == ord("\b") or key == 8: # Backspace key 363 | if len(self.user_query) > 1: 364 | # Remove the last character from the assistant text 365 | self.user_query = self.user_query[:-1] 366 | # Check if the key is a printable character 367 | if 32 <= key <= 126: 368 | # Convert the key to a character and append it to the assistant text 369 | self.user_query += chr(key) 370 | elif key == ord("\n") or key == ord("\r"): 371 | # Process the assistant text when Enter is pressed 372 | json_response = self.assistant(self.user_query) 373 | print("json_response:", json_response) 374 | if json_response is None: 375 | warnings.warn( 376 | f"Failed to parse assistant response as JSON: {self.user_query}" 377 | ) 378 | self.user_query = "" 379 | return True 380 | if "request" not in json_response: 381 | warnings.warn( 382 | f"Invalid assistant response: {json_response}. Expected 'request' key." 
383 | ) 384 | self.user_query = "" 385 | return True 386 | if json_response["request"] == "change_view": 387 | side = json_response.get("side", "top") 388 | viewmat = self._get_viewmat_from_trackbars() 389 | self.current_view = self.get_special_viewmat(viewmat, side) 390 | self.update_trackbars_from_viewmat(self.current_view) 391 | if json_response["request"] == "exit": 392 | print(json_response.get("message", "Exiting...")) 393 | return False 394 | if json_response["request"] == "segment": 395 | # reset opacities 396 | self.opacities = self.opacities_backup.clone().detach() 397 | prompt = json_response.get("prompt", "") 398 | neg_prompt = json_response.get("neg_prompt", "none") 399 | if not prompt: 400 | warnings.warn("No prompt provided for segmentation.") 401 | if neg_prompt == "" or neg_prompt == "none": 402 | neg_prompt = "other" 403 | else: 404 | neg_prompt = neg_prompt + ";other" 405 | features = self.splats["lseg"] 406 | mask3d, mask3d_inv = get_mask3d_lseg( 407 | self.splats, 408 | features, 409 | prompt, 410 | neg_prompt, 411 | # threshold=0.5, 412 | ) 413 | self.opacities[~mask3d] = 0.0 414 | if json_response["request"] == "reset_segmentation": 415 | # reset opacities 416 | self.opacities = self.opacities_backup.clone().detach() 417 | self.mask3d = None 418 | print("Segmentation reset.") 419 | if json_response["request"] == "change_color": 420 | object_name = json_response.get("object", "") 421 | color = json_response.get("color", "white") 422 | if color not in COLOR_TO_RGB: 423 | warnings.warn(f"Color '{color}' is not recognized. Please change COLOR_TO_RGB dictionary.") 424 | else: 425 | features = self.splats["lseg"] 426 | mask3d, mask3d_inv = get_mask3d_lseg( 427 | self.splats, 428 | features, 429 | prompt=object_name, 430 | neg_prompt="other", 431 | # threshold=0.5, 432 | ) 433 | colors = self.colors_backup[mask3d, 0, :].clone().detach() * 0.2820947917738781 + 0.5 434 | grays = TF.rgb_to_grayscale(colors.permute(1,0).reshape(1,3,-1,1))[0,0] 435 | self.colors[mask3d,0,:] = (torch.tensor(COLOR_TO_RGB[color], device=device)*grays[:,0:1] - 0.5) / 0.2820947917738781 436 | if json_response["request"] == "reset_color": 437 | self.mask3d = None 438 | self.colors = self.colors_backup.clone().detach() 439 | if json_response["request"] == "unknown": 440 | warnings.warn( 441 | f"Assistant response is unknown: {json_response}. Please try again." 
442 | ) 443 | self.user_query = "" 444 | self.assistant_mode = False 445 | return True 446 | 447 | def render_gaussians(self, viewmat, scaling, anaglyph=False): 448 | outputs = super().render_gaussians(viewmat, scaling, anaglyph) 449 | if not self.assistant_mode: 450 | return outputs 451 | if isinstance(outputs, tuple): 452 | output_cv, left, right = outputs 453 | # Render the assistant text on the output image 454 | cv2.putText( 455 | output_cv, 456 | self.user_query, 457 | (10, 30), 458 | cv2.FONT_HERSHEY_SIMPLEX, 459 | 0.7, 460 | (255, 255, 255), 461 | 2, 462 | ) 463 | return ( 464 | np.ascontiguousarray(output_cv), 465 | np.ascontiguousarray(left), 466 | np.ascontiguousarray(right), 467 | ) 468 | else: 469 | output_cv = outputs 470 | # Render the assistant text on the output image 471 | cv2.putText( 472 | output_cv, 473 | self.user_query, 474 | (10, 30), 475 | cv2.FONT_HERSHEY_SIMPLEX, 476 | 0.7, 477 | (255, 255, 255), 478 | 2, 479 | ) 480 | return np.ascontiguousarray(output_cv) 481 | 482 | 483 | 484 | def main(args: Args): 485 | format = args.format or args.rasterizer 486 | if args.rasterizer: 487 | warnings.warn( 488 | "`rasterizer` is deprecated. Use `format` instead.", DeprecationWarning 489 | ) 490 | if not format: 491 | raise ValueError("Must specify --format or the deprecated --rasterizer") 492 | 493 | splats = load_checkpoint(args.checkpoint, args.data_dir, format, args.data_factor) 494 | splats = prune_by_gradients(splats) 495 | if args.lseg_checkpoint: 496 | splats["lseg"] = torch.load(args.lseg_checkpoint, map_location=device) 497 | print("splats['lseg'].shape:", splats["lseg"].shape) 498 | 499 | if args.turntable: 500 | viewer_args = ViewerArgs(turntable=True) 501 | else: 502 | viewer_args = ViewerArgs(turntable=False) 503 | 504 | viewer = ViewerWithAssistant(splats, viewer_args=viewer_args) 505 | viewer.run() 506 | 507 | 508 | if __name__ == "__main__": 509 | args = tyro.cli(Args) 510 | main(args) 511 | -------------------------------------------------------------------------------- /f3dgs/datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional 3 | 4 | import cv2 5 | import imageio.v2 as imageio 6 | import numpy as np 7 | import torch 8 | from pycolmap_scene_manager import SceneManager 9 | 10 | from .normalize import ( 11 | align_principle_axes, 12 | similarity_from_cameras, 13 | transform_cameras, 14 | transform_points, 15 | ) 16 | 17 | 18 | def _get_rel_paths(path_dir: str) -> List[str]: 19 | """Recursively get relative paths of files in a directory.""" 20 | paths = [] 21 | for dp, dn, fn in os.walk(path_dir): 22 | for f in fn: 23 | paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) 24 | return paths 25 | 26 | 27 | class Parser: 28 | """COLMAP parser.""" 29 | 30 | def __init__( 31 | self, 32 | data_dir: str, 33 | factor: int = 1, 34 | normalize: bool = False, 35 | test_every: int = 8, 36 | ): 37 | self.data_dir = data_dir 38 | self.factor = factor 39 | self.normalize = normalize 40 | self.test_every = test_every 41 | 42 | colmap_dir = os.path.join(data_dir, "sparse/0/") 43 | if not os.path.exists(colmap_dir): 44 | colmap_dir = os.path.join(data_dir, "sparse") 45 | assert os.path.exists( 46 | colmap_dir 47 | ), f"COLMAP directory {colmap_dir} does not exist." 48 | 49 | manager = SceneManager(colmap_dir) 50 | manager.load_cameras() 51 | manager.load_images() 52 | manager.load_points3D() 53 | 54 | # Extract extrinsic matrices in world-to-camera format. 
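# COLMAP stores each image pose in world-to-camera convention:
# x_cam = R @ x_world + t, with R recovered from the stored quaternion via
# im.R() and t given by im.tvec. The loop below stacks each [R | t] on top of a
# [0, 0, 0, 1] row to form a 4x4 matrix; camera-to-world poses are obtained
# later by inverting the whole stack.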
55 | imdata = manager.images 56 | w2c_mats = [] 57 | camera_ids = [] 58 | Ks_dict = dict() 59 | params_dict = dict() 60 | imsize_dict = dict() # width, height 61 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 62 | for k in imdata: 63 | im = imdata[k] 64 | rot = im.R() 65 | trans = im.tvec.reshape(3, 1) 66 | w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) 67 | w2c_mats.append(w2c) 68 | 69 | # support different camera intrinsics 70 | camera_id = im.camera_id 71 | camera_ids.append(camera_id) 72 | 73 | # camera intrinsics 74 | cam = manager.cameras[camera_id] 75 | fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy 76 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 77 | K[:2, :] /= factor 78 | Ks_dict[camera_id] = K 79 | 80 | # Get distortion parameters. 81 | type_ = cam.camera_type 82 | if type_ == 0 or type_ == "SIMPLE_PINHOLE": 83 | params = np.empty(0, dtype=np.float32) 84 | camtype = "perspective" 85 | elif type_ == 1 or type_ == "PINHOLE": 86 | params = np.empty(0, dtype=np.float32) 87 | camtype = "perspective" 88 | if type_ == 2 or type_ == "SIMPLE_RADIAL": 89 | params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) 90 | camtype = "perspective" 91 | elif type_ == 3 or type_ == "RADIAL": 92 | params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) 93 | camtype = "perspective" 94 | elif type_ == 4 or type_ == "OPENCV": 95 | params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) 96 | camtype = "perspective" 97 | elif type_ == 5 or type_ == "OPENCV_FISHEYE": 98 | params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) 99 | camtype = "fisheye" 100 | assert ( 101 | camtype == "perspective" 102 | ), f"Only support perspective camera model, got {type_}" 103 | 104 | params_dict[camera_id] = params 105 | 106 | # image size 107 | imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) 108 | 109 | print( 110 | f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." 111 | ) 112 | 113 | if len(imdata) == 0: 114 | raise ValueError("No images found in COLMAP.") 115 | if not (type_ == 0 or type_ == 1): 116 | print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") 117 | 118 | w2c_mats = np.stack(w2c_mats, axis=0) 119 | 120 | # Convert extrinsics to camera-to-world. 121 | camtoworlds = np.linalg.inv(w2c_mats) 122 | 123 | # Image names from COLMAP. No need for permuting the poses according to 124 | # image names anymore. 125 | image_names = [imdata[k].name for k in imdata] 126 | 127 | # Previous Nerf results were generated with images sorted by filename, 128 | # ensure metrics are reported on the same test set. 129 | inds = np.argsort(image_names) 130 | image_names = [image_names[i] for i in inds] 131 | camtoworlds = camtoworlds[inds] 132 | camera_ids = [camera_ids[i] for i in inds] 133 | 134 | # Load images. 135 | if factor > 1: 136 | image_dir_suffix = f"_{factor}" 137 | else: 138 | image_dir_suffix = "" 139 | colmap_image_dir = os.path.join(data_dir, "images") 140 | image_dir = os.path.join(data_dir, "images" + image_dir_suffix) 141 | for d in [image_dir, colmap_image_dir]: 142 | if not os.path.exists(d): 143 | raise ValueError(f"Image folder {d} does not exist.") 144 | 145 | # Downsampled images may have different names vs images used for COLMAP, 146 | # so we need to map between the two sorted lists of files. 
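# COLMAP records the original image names, while the downscaled folder may use
# different names or extensions (hypothetical example: images/IMG_0001.JPG vs.
# images_4/IMG_0001.png). Both relative path lists are therefore sorted and
# zipped into a lookup table, which assumes the two folders contain exactly the
# same files in the same sorted order.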
147 | colmap_files = sorted(_get_rel_paths(colmap_image_dir)) 148 | image_files = sorted(_get_rel_paths(image_dir)) 149 | colmap_to_image = dict(zip(colmap_files, image_files)) 150 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 151 | 152 | # 3D points and {image_name -> [point_idx]} 153 | points = manager.points3D.astype(np.float32) 154 | points_err = manager.point3D_errors.astype(np.float32) 155 | points_rgb = manager.point3D_colors.astype(np.uint8) 156 | point_indices = dict() 157 | 158 | image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} 159 | for point_id, data in manager.point3D_id_to_images.items(): 160 | for image_id, _ in data: 161 | image_name = image_id_to_name[image_id] 162 | point_idx = manager.point3D_id_to_point3D_idx[point_id] 163 | point_indices.setdefault(image_name, []).append(point_idx) 164 | point_indices = { 165 | k: np.array(v).astype(np.int32) for k, v in point_indices.items() 166 | } 167 | 168 | # Normalize the world space. 169 | if normalize: 170 | T1 = similarity_from_cameras(camtoworlds) 171 | camtoworlds = transform_cameras(T1, camtoworlds) 172 | points = transform_points(T1, points) 173 | 174 | T2 = align_principle_axes(points) 175 | camtoworlds = transform_cameras(T2, camtoworlds) 176 | points = transform_points(T2, points) 177 | 178 | transform = T2 @ T1 179 | else: 180 | transform = np.eye(4) 181 | 182 | self.image_names = image_names # List[str], (num_images,) 183 | self.image_paths = image_paths # List[str], (num_images,) 184 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) 185 | self.camera_ids = camera_ids # List[int], (num_images,) 186 | self.Ks_dict = Ks_dict # Dict of camera_id -> K 187 | self.params_dict = params_dict # Dict of camera_id -> params 188 | self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) 189 | self.points = points # np.ndarray, (num_points, 3) 190 | self.points_err = points_err # np.ndarray, (num_points,) 191 | self.points_rgb = points_rgb # np.ndarray, (num_points, 3) 192 | self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] 193 | self.transform = transform # np.ndarray, (4, 4) 194 | 195 | # load one image to check the size. In the case of tanksandtemples dataset, the 196 | # intrinsics stored in COLMAP corresponds to 2x upsampled images. 
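# The loaded image gives the true on-disk resolution. Since fx and cx scale
# with image width and fy and cy scale with image height, multiplying the first
# row of K by s_width and the second row by s_height keeps the projection
# consistent with the images that are actually read, and imsize_dict is updated
# to match.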
197 | actual_image = imageio.imread(self.image_paths[0])[..., :3] 198 | actual_height, actual_width = actual_image.shape[:2] 199 | colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] 200 | s_height, s_width = actual_height / colmap_height, actual_width / colmap_width 201 | for camera_id, K in self.Ks_dict.items(): 202 | K[0, :] *= s_width 203 | K[1, :] *= s_height 204 | self.Ks_dict[camera_id] = K 205 | width, height = self.imsize_dict[camera_id] 206 | self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) 207 | 208 | # undistortion 209 | self.mapx_dict = dict() 210 | self.mapy_dict = dict() 211 | self.roi_undist_dict = dict() 212 | for camera_id in self.params_dict.keys(): 213 | params = self.params_dict[camera_id] 214 | if len(params) == 0: 215 | continue # no distortion 216 | assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" 217 | assert ( 218 | camera_id in self.params_dict 219 | ), f"Missing params for camera {camera_id}" 220 | K = self.Ks_dict[camera_id] 221 | width, height = self.imsize_dict[camera_id] 222 | K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( 223 | K, params, (width, height), 0 224 | ) 225 | mapx, mapy = cv2.initUndistortRectifyMap( 226 | K, params, None, K_undist, (width, height), cv2.CV_32FC1 227 | ) 228 | self.Ks_dict[camera_id] = K_undist 229 | self.mapx_dict[camera_id] = mapx 230 | self.mapy_dict[camera_id] = mapy 231 | self.roi_undist_dict[camera_id] = roi_undist 232 | 233 | # size of the scene measured by cameras 234 | camera_locations = camtoworlds[:, :3, 3] 235 | scene_center = np.mean(camera_locations, axis=0) 236 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 237 | self.scene_scale = np.max(dists) 238 | 239 | class BlenderParser: 240 | """Blender dataset parser.""" 241 | 242 | def __init__( 243 | self, 244 | data_dir: str, 245 | factor: int = 1, 246 | normalize: bool = False, 247 | test_every: int = 8, 248 | ): 249 | self.data_dir = data_dir 250 | self.factor = factor 251 | self.normalize = normalize 252 | self.test_every = test_every 253 | 254 | blender_dir = os.path.join(data_dir, "train") 255 | 256 | # colmap_dir = os.path.join(data_dir, "sparse/0/") 257 | # if not os.path.exists(colmap_dir): 258 | # colmap_dir = os.path.join(data_dir, "sparse") 259 | # assert os.path.exists( 260 | # colmap_dir 261 | # ), f"COLMAP directory {colmap_dir} does not exist." 262 | 263 | # manager = SceneManager(colmap_dir) 264 | # manager.load_cameras() 265 | # manager.load_images() 266 | # manager.load_points3D() 267 | 268 | # # Extract extrinsic matrices in world-to-camera format. 269 | # imdata = manager.images 270 | # w2c_mats = [] 271 | # camera_ids = [] 272 | # Ks_dict = dict() 273 | # params_dict = dict() 274 | # imsize_dict = dict() # width, height 275 | # bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 276 | # for k in imdata: 277 | # im = imdata[k] 278 | # rot = im.R() 279 | # trans = im.tvec.reshape(3, 1) 280 | # w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) 281 | # w2c_mats.append(w2c) 282 | 283 | # # support different camera intrinsics 284 | # camera_id = im.camera_id 285 | # camera_ids.append(camera_id) 286 | 287 | # # camera intrinsics 288 | # cam = manager.cameras[camera_id] 289 | # fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy 290 | # K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 291 | # K[:2, :] /= factor 292 | # Ks_dict[camera_id] = K 293 | 294 | # # Get distortion parameters. 
295 | # type_ = cam.camera_type 296 | # if type_ == 0 or type_ == "SIMPLE_PINHOLE": 297 | # params = np.empty(0, dtype=np.float32) 298 | # camtype = "perspective" 299 | # elif type_ == 1 or type_ == "PINHOLE": 300 | # params = np.empty(0, dtype=np.float32) 301 | # camtype = "perspective" 302 | # if type_ == 2 or type_ == "SIMPLE_RADIAL": 303 | # params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) 304 | # camtype = "perspective" 305 | # elif type_ == 3 or type_ == "RADIAL": 306 | # params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) 307 | # camtype = "perspective" 308 | # elif type_ == 4 or type_ == "OPENCV": 309 | # params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) 310 | # camtype = "perspective" 311 | # elif type_ == 5 or type_ == "OPENCV_FISHEYE": 312 | # params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) 313 | # camtype = "fisheye" 314 | # assert ( 315 | # camtype == "perspective" 316 | # ), f"Only support perspective camera model, got {type_}" 317 | 318 | # params_dict[camera_id] = params 319 | 320 | # # image size 321 | # imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) 322 | 323 | # print( 324 | # f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." 325 | # ) 326 | 327 | # if len(imdata) == 0: 328 | # raise ValueError("No images found in COLMAP.") 329 | # if not (type_ == 0 or type_ == 1): 330 | # print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") 331 | 332 | # w2c_mats = np.stack(w2c_mats, axis=0) 333 | 334 | # # Convert extrinsics to camera-to-world. 335 | # camtoworlds = np.linalg.inv(w2c_mats) 336 | 337 | # # Image names from COLMAP. No need for permuting the poses according to 338 | # # image names anymore. 339 | # get files ending with .png 340 | image_names = [f for f in os.listdir(blender_dir) if f.endswith(".png")] 341 | import natsort 342 | image_names = natsort.natsorted(image_names) 343 | # print("image_names: ", image_names) 344 | # exit() 345 | image_paths = [os.path.join(blender_dir, f) for f in image_names] 346 | import json 347 | json_file = os.path.join(data_dir, "transforms_train.json") 348 | camtoworlds = [None for _ in range(len(image_names))] 349 | with open(json_file) as f: 350 | json_data = json.load(f) 351 | camera_angle_x = json_data["camera_angle_x"] 352 | width, height = 800, 800 353 | fx = 0.5 * width / np.tan(0.5 * camera_angle_x) 354 | fy = fx 355 | K = np.array([[fx, 0, 0.5 * width], [0, fy, 0.5 * height], [0, 0, 1]]) 356 | frames = json_data["frames"] 357 | camtoworlds = [None for _ in range(len(image_names))] 358 | for item in frames: 359 | image_name = item["file_path"].split("/")[-1]# + ".png" 360 | idx = image_names.index(image_name) 361 | print(idx, image_name, len(image_names)) 362 | # print("item: ", item) 363 | cam_matrix = np.array(item["transform_matrix"]) 364 | # Rotate 180 degrees around x-axis to convert from blender to colmap coordinate system 365 | R = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) 366 | cam_matrix[:3, :3] = cam_matrix[:3, :3] @ R 367 | # cam_matrix[:3, 3] = R @ cam_matrix[:3, 3] 368 | # cam_matrix_inv = np.linalg.inv(cam_matrix) 369 | print(cam_matrix) 370 | camtoworlds[idx] = cam_matrix 371 | 372 | # Check none values 373 | for idx, item in enumerate(camtoworlds): 374 | if item is None: 375 | print("None value at idx: ", idx) 376 | camtoworlds = np.stack(camtoworlds, axis=0) 377 | # exit() 378 | # print("camtoworlds shape: ", camtoworlds.shape) 379 | # exit() 380 | # image_names = [imdata[k].name 
for k in imdata] 381 | 382 | # # Previous Nerf results were generated with images sorted by filename, 383 | # # ensure metrics are reported on the same test set. 384 | # inds = np.argsort(image_names) 385 | # image_names = [image_names[i] for i in inds] 386 | # camtoworlds = camtoworlds[inds] 387 | # camera_ids = [camera_ids[i] for i in inds] 388 | 389 | # # Load images. 390 | # if factor > 1: 391 | # image_dir_suffix = f"_{factor}" 392 | # else: 393 | # image_dir_suffix = "" 394 | # colmap_image_dir = os.path.join(data_dir, "images") 395 | # image_dir = os.path.join(data_dir, "images" + image_dir_suffix) 396 | # for d in [image_dir, colmap_image_dir]: 397 | # if not os.path.exists(d): 398 | # raise ValueError(f"Image folder {d} does not exist.") 399 | 400 | # # Downsampled images may have different names vs images used for COLMAP, 401 | # # so we need to map between the two sorted lists of files. 402 | # colmap_files = sorted(_get_rel_paths(colmap_image_dir)) 403 | # image_files = sorted(_get_rel_paths(image_dir)) 404 | # colmap_to_image = dict(zip(colmap_files, image_files)) 405 | # image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 406 | 407 | # # 3D points and {image_name -> [point_idx]} 408 | # points = manager.points3D.astype(np.float32) 409 | # points_err = manager.point3D_errors.astype(np.float32) 410 | # points_rgb = manager.point3D_colors.astype(np.uint8) 411 | # point_indices = dict() 412 | 413 | # image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} 414 | # for point_id, data in manager.point3D_id_to_images.items(): 415 | # for image_id, _ in data: 416 | # image_name = image_id_to_name[image_id] 417 | # point_idx = manager.point3D_id_to_point3D_idx[point_id] 418 | # point_indices.setdefault(image_name, []).append(point_idx) 419 | # point_indices = { 420 | # k: np.array(v).astype(np.int32) for k, v in point_indices.items() 421 | # } 422 | 423 | # Normalize the world space. 
424 | # if normalize: 425 | # T1 = similarity_from_cameras(camtoworlds) 426 | # camtoworlds = transform_cameras(T1, camtoworlds) 427 | # points = transform_points(T1, points) 428 | 429 | # T2 = align_principle_axes(points) 430 | # camtoworlds = transform_cameras(T2, camtoworlds) 431 | # points = transform_points(T2, points) 432 | 433 | # transform = T2 @ T1 434 | # else: 435 | # transform = np.eye(4) 436 | 437 | self.image_names = image_names # List[str], (num_images,) 438 | self.image_paths = image_paths # List[str], (num_images,) 439 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) 440 | # self.camtoworlds = np.linalg.inv(camtoworlds) 441 | self.camera_ids = list(range(len(camtoworlds))) # List[int], (num_images,) 442 | self.Ks_dict = {i: K for i in self.camera_ids} # Dict of camera_id -> K 443 | self.params_dict = {i:[] for i in self.camera_ids} # Dict of camera_id -> params 444 | self.points = np.random.uniform(-1,1,(15000,3)) # np.ndarray, (num_points, 3) 445 | self.points_rgb = np.random.uniform(0,255,(15000,3)).astype(int) # np.ndarray, (num_points, 3) 446 | # print(self.points.shape, self.points_rgb.shape) 447 | self.imsize_dict = {i: (800, 800) for i in self.camera_ids} # Dict of camera_id -> (width, height) 448 | # self.camera_ids = camera_ids # List[int], (num_images,) 449 | # self.Ks_dict = Ks_dict # Dict of camera_id -> K 450 | # self.params_dict = params_dict # Dict of camera_id -> params 451 | # self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) 452 | # self.points = points # np.ndarray, (num_points, 3) 453 | # self.points_err = points_err # np.ndarray, (num_points,) 454 | # self.points_rgb = points_rgb # np.ndarray, (num_points, 3) 455 | # self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] 456 | # self.transform = transform # np.ndarray, (4, 4) 457 | 458 | # # load one image to check the size. In the case of tanksandtemples dataset, the 459 | # # intrinsics stored in COLMAP corresponds to 2x upsampled images. 
460 | # actual_image = imageio.imread(self.image_paths[0])[..., :3] 461 | # actual_height, actual_width = actual_image.shape[:2] 462 | # colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] 463 | # s_height, s_width = actual_height / colmap_height, actual_width / colmap_width 464 | # for camera_id, K in self.Ks_dict.items(): 465 | # K[0, :] *= s_width 466 | # K[1, :] *= s_height 467 | # self.Ks_dict[camera_id] = K 468 | # width, height = self.imsize_dict[camera_id] 469 | # self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) 470 | 471 | # # undistortion 472 | # self.mapx_dict = dict() 473 | # self.mapy_dict = dict() 474 | # self.roi_undist_dict = dict() 475 | # for camera_id in self.params_dict.keys(): 476 | # params = self.params_dict[camera_id] 477 | # if len(params) == 0: 478 | # continue # no distortion 479 | # assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" 480 | # assert ( 481 | # camera_id in self.params_dict 482 | # ), f"Missing params for camera {camera_id}" 483 | # K = self.Ks_dict[camera_id] 484 | # width, height = self.imsize_dict[camera_id] 485 | # K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( 486 | # K, params, (width, height), 0 487 | # ) 488 | # mapx, mapy = cv2.initUndistortRectifyMap( 489 | # K, params, None, K_undist, (width, height), cv2.CV_32FC1 490 | # ) 491 | # self.Ks_dict[camera_id] = K_undist 492 | # self.mapx_dict[camera_id] = mapx 493 | # self.mapy_dict[camera_id] = mapy 494 | # self.roi_undist_dict[camera_id] = roi_undist 495 | 496 | # # size of the scene measured by cameras 497 | camera_locations = camtoworlds[:, :3, 3] 498 | scene_center = np.mean(camera_locations, axis=0) 499 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 500 | self.scene_scale = np.max(dists) 501 | 502 | 503 | class Dataset: 504 | """A simple dataset class.""" 505 | 506 | def __init__( 507 | self, 508 | parser: Parser, 509 | split: str = "train", 510 | patch_size: Optional[int] = None, 511 | load_depths: bool = False, 512 | ): 513 | self.parser = parser 514 | self.split = split 515 | self.patch_size = patch_size 516 | self.load_depths = load_depths 517 | indices = np.arange(len(self.parser.image_names)) 518 | if split == "train": 519 | self.indices = indices[indices % self.parser.test_every != 0] 520 | else: 521 | self.indices = indices[indices % self.parser.test_every == 0] 522 | 523 | def __len__(self): 524 | return len(self.indices) 525 | 526 | def __getitem__(self, item: int) -> Dict[str, Any]: 527 | index = self.indices[item] 528 | image = imageio.imread(self.parser.image_paths[index])[..., :3] 529 | camera_id = self.parser.camera_ids[index] 530 | K = self.parser.Ks_dict[camera_id].copy() # undistorted K 531 | params = self.parser.params_dict[camera_id] 532 | camtoworlds = self.parser.camtoworlds[index] 533 | 534 | if len(params) > 0: 535 | # Images are distorted. Undistort them. 536 | mapx, mapy = ( 537 | self.parser.mapx_dict[camera_id], 538 | self.parser.mapy_dict[camera_id], 539 | ) 540 | image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) 541 | x, y, w, h = self.parser.roi_undist_dict[camera_id] 542 | image = image[y : y + h, x : x + w] 543 | 544 | if self.patch_size is not None: 545 | # Random crop. 
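# Cropping a patch_size window at a random offset (x, y) simply translates the
# origin of pixel coordinates, so shifting the principal point by the same
# offset (K[0, 2] -= x, K[1, 2] -= y below) keeps the intrinsics consistent
# with the cropped image.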
546 | h, w = image.shape[:2] 547 | x = np.random.randint(0, max(w - self.patch_size, 1)) 548 | y = np.random.randint(0, max(h - self.patch_size, 1)) 549 | image = image[y : y + self.patch_size, x : x + self.patch_size] 550 | K[0, 2] -= x 551 | K[1, 2] -= y 552 | 553 | data = { 554 | "K": torch.from_numpy(K).float(), 555 | "camtoworld": torch.from_numpy(camtoworlds).float(), 556 | "image": torch.from_numpy(image).float(), 557 | "image_id": item, # the index of the image in the dataset 558 | } 559 | 560 | if self.load_depths: 561 | # projected points to image plane to get depths 562 | worldtocams = np.linalg.inv(camtoworlds) 563 | image_name = self.parser.image_names[index] 564 | point_indices = self.parser.point_indices[image_name] 565 | points_world = self.parser.points[point_indices] 566 | points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T 567 | points_proj = (K @ points_cam.T).T 568 | points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2) 569 | depths = points_cam[:, 2] # (M,) 570 | if self.patch_size is not None: 571 | points[:, 0] -= x 572 | points[:, 1] -= y 573 | # filter out points outside the image 574 | selector = ( 575 | (points[:, 0] >= 0) 576 | & (points[:, 0] < image.shape[1]) 577 | & (points[:, 1] >= 0) 578 | & (points[:, 1] < image.shape[0]) 579 | & (depths > 0) 580 | ) 581 | points = points[selector] 582 | depths = depths[selector] 583 | data["points"] = torch.from_numpy(points).float() 584 | data["depths"] = torch.from_numpy(depths).float() 585 | 586 | return data 587 | 588 | 589 | if __name__ == "__main__": 590 | import argparse 591 | 592 | import imageio.v2 as imageio 593 | import tqdm 594 | 595 | parser = argparse.ArgumentParser() 596 | parser.add_argument("--data_dir", type=str, default="data/360_v2/garden") 597 | parser.add_argument("--factor", type=int, default=4) 598 | args = parser.parse_args() 599 | 600 | # Parse COLMAP data. 601 | parser = Parser( 602 | data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8 603 | ) 604 | dataset = Dataset(parser, split="train", load_depths=True) 605 | print(f"Dataset: {len(dataset)} images.") 606 | 607 | writer = imageio.get_writer("results/points.mp4", fps=30) 608 | for data in tqdm.tqdm(dataset, desc="Plotting points"): 609 | image = data["image"].numpy().astype(np.uint8) 610 | points = data["points"].numpy() 611 | depths = data["depths"].numpy() 612 | for x, y in points: 613 | cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1) 614 | writer.append_data(image) 615 | writer.close() 616 | --------------------------------------------------------------------------------