├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── media ├── POGS_servoing.gif └── POGS_teaser.gif ├── pogs ├── calibration_outputs │ └── world_to_extrinsic_zed_for_grasping_down.tf ├── camera │ ├── capture_utils.py │ └── zed_stereo.py ├── configs │ ├── base_config.py │ ├── camera_config.yaml │ └── conda_compile_pointnet_tfops.sh ├── contact_graspnet_wrapper │ ├── prime_config_utils.py │ ├── prime_inference.py │ └── prime_visualization_utils.py ├── data │ ├── depth_dataset.py │ ├── full_images_datamanager.py │ └── utils │ │ ├── detic_dataloader.py │ │ ├── dino_dataloader.py │ │ ├── dino_extractor.py │ │ ├── feature_dataloader.py │ │ ├── patch_embedding_dataloader.py │ │ └── pyramid_embedding_dataloader.py ├── encoders │ ├── image_encoder.py │ └── openclip_encoder.py ├── field_components │ └── gaussian_fieldheadnames.py ├── fields │ └── gaussian_field.py ├── grasping │ ├── generate_grasps_ply.py │ └── results │ │ └── predictions_global.ply.npz ├── model_components │ └── losses.py ├── pogs.py ├── pogs_config.py ├── pogs_pipeline.py ├── scripts │ ├── calibrate_cameras.py │ ├── scene_capture.py │ ├── track_main_demo.py │ └── track_main_online_demo.py └── tracking │ ├── atap_loss.py │ ├── data │ ├── ZED2.stl │ └── ZEDM.stl │ ├── motion.py │ ├── observation.py │ ├── optim.py │ ├── rigid_group_optimizer.py │ ├── toad_object.py │ ├── transforms │ ├── __init__.py │ ├── _base.py │ ├── _se3.py │ ├── _so3.py │ └── utils │ │ ├── __init__.py │ │ └── _utils.py │ ├── tri_zed.py │ ├── utils.py │ ├── utils2.py │ └── zed.py ├── pyproject.toml └── scripts ├── dino_pca_visualization.py ├── shelf_iron.png ├── shelf_iron.png_pca.png └── shelf_iron.png_pca_hist.png /.gitignore: -------------------------------------------------------------------------------- 1 | *.ckpt 2 | 3 | *.egg-info 4 | 5 | *.ipynb 6 | 7 | *.log 8 | 9 | *.model 10 | 11 | *.pth 12 | 13 | *.pyc 14 | 15 | *.npy 16 | 17 | *.ply 18 | 19 | __pycache__/ 20 | 21 | pogs/sample_data/* 22 | 23 | pogs/configs/__pycache__/* 24 | pogs/data/__pycache__/* 25 | pogs/data/utils/__pycache__/* 26 | pogs/encoders/__pycache__/* 27 | 28 | pogs/scripts/outputs/* 29 | 30 | pogs/ur5_interface/ur5_interface/data/* 31 | pogs/ur5_interface/ur5_interface/scripts/outputs/* 32 | outputs/* 33 | 34 | pogs/tracking/__pycache__/* 35 | pogs/ur5_interface/ur5_interface/outputs/* 36 | pogs/ur5_interface/ur5_interface/scripts/results/* 37 | 38 | data/* 39 | media/POGS.mp4 40 | outputs/* 41 | pogs/tracking/models/ 42 | *.mp4 43 | *.ipynb 44 | .vscode/ 45 | pogs/data/utils/datasets/ 46 | pogs/calibration_outputs/world_to_extrinsic_zed.tf 47 | pogs/calibration_outputs/wrist_to_zed_mini.tf 48 | pogs/pogs/grasping/results 49 | pogs/results 50 | pogs/robot_description 51 | pogs/scripts/goto_click_extrinsic_cam.py 52 | pogs/scripts/goto_click_wrist_cam.py 53 | pogs/scripts/test_grasp_visualization.py 54 | pogs/scripts/results -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pogs/dependencies/detectron2"] 2 | path = pogs/dependencies/detectron2 3 | url = https://github.com/facebookresearch/detectron2.git 4 | [submodule "pogs/dependencies/Detic"] 5 | path = pogs/dependencies/Detic 6 | url = https://github.com/facebookresearch/Detic.git 7 | [submodule "pogs/dependencies/ur5py"] 8 | path = pogs/dependencies/ur5py 9 | url = https://github.com/kushtimusPrime/ur5py 10 | [submodule "pogs/dependencies/raftstereo"] 11 | path = 
pogs/dependencies/raftstereo 12 | url = https://github.com/kushtimusPrime/RAFT-Stereo 13 | [submodule "pogs/dependencies/contact_graspnet"] 14 | path = pogs/dependencies/contact_graspnet 15 | url = https://github.com/NVlabs/contact_graspnet.git 16 | [submodule "pogs/dependencies/nerfstudio"] 17 | path = pogs/dependencies/nerfstudio 18 | url = https://github.com/nerfstudio-project/nerfstudio.git 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Justin Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Persistent Object Gaussian Splat (POGS) for Tracking Human and Robot Manipulation of Irregularly Shaped Objects 2 | 3 |
4 | 5 | [[Website]](https://berkeleyautomation.github.io/POGS/) 6 | 7 | 8 | 9 | 10 | 11 | 12 |
 
13 | 14 |
 
15 | 16 | 17 |
18 | 19 | This repository contains the official implementation for [POGS](https://berkeleyautomation.github.io/POGS/). 20 | 21 | Tested on Python 3.10 and CUDA 11.8, using conda. 22 | 23 | ## Installation 24 | 1. Create a conda environment and install the relevant packages 25 | ``` 26 | conda create --name pogs_env -y python=3.10 27 | conda activate pogs_env 28 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit 29 | 30 | pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu118 31 | 32 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 33 | pip install jaxtyping rich 34 | pip install gsplat --index-url https://docs.gsplat.studio/whl/pt20cu118 35 | pip install warp-lang 36 | ``` 37 | 38 | 2. [`cuml`](https://docs.rapids.ai/install) is required (for global clustering). 39 | The best way to install it is with pip: `pip install --extra-index-url=https://pypi.nvidia.com cudf-cu11==25.4.* cuml-cu11==25.4.*` 40 | 41 | 3. Install POGS! 42 | ``` 43 | git clone https://github.com/uynitsuj/POGS.git --recurse-submodules 44 | cd POGS 45 | python -m pip install -e . 46 | python -m pip install pogs/dependencies/nerfstudio/ 47 | pip install fast_simplification==0.1.9 48 | pip install numpy==1.26.4 49 | ns-install-cli 50 | ``` 51 | ### Robot Interaction Code Installation (UR5 Specific) 52 | 53 | There is also a physical robot component that uses a UR5 and ZED 2 cameras. To install the relevant libraries: 54 | #### ur5py 55 | ``` 56 | pip install ur_rtde==1.4.2 cowsay opt-einsum pyvista autolab-core 57 | pip install -e ~/POGS/pogs/dependencies/ur5py 58 | ``` 59 | 60 | #### RAFT-Stereo 61 | ``` 62 | cd ~/POGS/pogs/dependencies/raftstereo 63 | bash download_models.sh 64 | pip install -e . 65 | ``` 66 | 67 | #### Contact-Graspnet 68 | Contact-Graspnet relies on some older library versions, so we couldn't merge everything into one conda environment. However, we can make it work by creating a separate conda environment and calling it in a subprocess (a minimal sketch of such a call is shown below, after the model download instructions). 69 | ``` 70 | conda deactivate 71 | conda create --name contact_graspnet_env python=3.8 72 | conda activate contact_graspnet_env 73 | conda install -c conda-forge cudatoolkit=11.2 74 | conda install -c conda-forge cudnn=8.2 75 | # If you don't have cuda installed at /usr/local/cuda then you can install on your conda env and run these two lines 76 | conda install -c conda-forge cudatoolkit-dev 77 | export CUDA_HOME=/path/to/anaconda/envs/contact_graspnet_env/bin/nvcc 78 | pip install tensorflow==2.5 tensorflow-gpu==2.5 79 | pip install opencv-python-headless pyyaml pyrender tqdm mayavi 80 | pip install open3d==0.10.0 typing-extensions==3.7.4 trimesh==3.8.12 configobj==5.0.6 matplotlib==3.3.2 pyside2==5.11.0 scikit-image==0.19.0 numpy==1.19.2 scipy==1.9.1 vtk==9.3.1 81 | # if you have cuda installed at /usr/local/cuda run these lines 82 | cd ~/POGS/pogs/dependencies/contact_graspnet 83 | sh compile_pointnet_tfops.sh 84 | # if you have cuda installed on your conda env run these lines 85 | cd ~/POGS/pogs/configs 86 | cp conda_compile_pointnet_tfops.sh ~/POGS/pogs/dependencies/contact_graspnet/ 87 | cd ~/POGS/pogs/dependencies/contact_graspnet 88 | sh conda_compile_pointnet_tfops.sh 89 | pip install autolab-core 90 | ``` 91 | 92 | #### Download Models and Data 93 | ##### Model 94 | Download trained models from [here](https://drive.google.com/drive/folders/1tBHKf60K8DLM5arm-Chyf7jxkzOr5zGl?usp=sharing) and copy them into the `checkpoints/` folder.
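As noted in the Contact-Graspnet section above, grasp prediction runs in its own conda environment (`contact_graspnet_env`) and is called as a subprocess from `pogs_env`. The snippet below is only a minimal sketch of that pattern, not the repository's documented interface: the use of `conda run`, the choice of `pogs/contact_graspnet_wrapper/prime_inference.py` as the entry point, and the `--np_path` argument are assumptions for illustration.

```
# Minimal sketch (illustrative only): invoke the grasp predictor that lives in
# the separate contact_graspnet_env and read its output when it finishes.
# The script path and the --np_path flag are assumptions, not the actual CLI.
import subprocess

result = subprocess.run(
    [
        "conda", "run", "-n", "contact_graspnet_env",
        "python", "pogs/contact_graspnet_wrapper/prime_inference.py",
        "--np_path", "/tmp/segmented_object_points.npy",
    ],
    capture_output=True,
    text=True,
    check=True,  # raise if the grasp predictor exits with a non-zero status
)
print(result.stdout)
```

Keeping the TensorFlow-based grasp predictor behind a subprocess boundary is what lets its pinned dependencies coexist with the PyTorch stack in `pogs_env`.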
95 | ##### Test data 96 | Download the test data from [here](https://drive.google.com/drive/folders/1TqpM2wHAAo0j3i1neu3Xeru3_WnsYQnx?usp=sharing) and copy them into the `test_data/` folder. 97 | 98 | ## Usage 99 | ### Calibrate wrist-mounted and third-person cameras 100 | Before training/tracking POGS, make sure the wrist-mounted camera and the third-person view camera are calibrated. We use an ArUco marker for the calibration. 101 | ``` 102 | conda activate pogs_env 103 | cd ~/POGS/pogs/scripts 104 | python calibrate_cameras.py 105 | ``` 106 | 107 | ### Scene Capture 108 | Script used to perform a hemisphere capture of the tabletop scene with the robot. We used a manually defined trajectory, but you can also put the robot in "teach" mode to capture a trajectory. 109 | ``` 110 | conda activate pogs_env 111 | cd ~/POGS/pogs/scripts 112 | python scene_capture.py --scene DATA_NAME 113 | ``` 114 | 115 | ### Train POGS 116 | Script used to train the POGS for 4000 steps. 117 | ``` 118 | conda activate pogs_env 119 | ns-train pogs --data /path/to/data/folder 120 | ``` 121 | Once the POGS has completed training, follow these steps to define and save the object clusters: 122 | 1. Hit the Cluster Scene button. 123 | 2. Clustering takes 10-20 seconds; afterwards, you should see your objects as distinct clusters. If not, hit Toggle RGB/Cluster and try clustering the scene again with a different Cluster Eps (lower normally works better). 124 | 3. Once you have your scene clustered, hit Toggle RGB/Cluster. 125 | 4. Then, hit Click and click on your desired object (a green ball will appear on the object). 126 | 5. Hit Crop to Click, and it should isolate the object. 127 | 6. A draggable coordinate frame will pop up to indicate the object's origin; drag it to where you want it to be. (For experiments, this is what we used to align objects for reset or tool servoing.) 128 | 7. Hit Add Crop to Group List. 129 | 8. Repeat steps 4-7 for all objects in the scene. 130 | 9. Hit View Crop Group List. 131 | Once you have trained the POGS, make sure to save the config file path and checkpoint directory printed in the terminal. 132 | 133 | ### Run POGS for grasping 134 | Script that lets you use a POGS to track an object online and grasp it. 135 | ``` 136 | conda activate pogs_env 137 | python ~/POGS/pogs/scripts/track_main_online_demo.py --config_path /path/to/config/yml 138 | ``` 139 | 140 | ## Bibtex 141 | If you find POGS useful for your work, please consider citing: 142 | ``` 143 | @article{yu2025pogs, 144 | author = {Yu, Justin and Hari, Kush and El-Refai, Karim and Dalil, Arnav and Kerr, Justin and Kim, Chung-Min and Cheng, Richard and Irshad, Muhammad Z. 
and Goldberg, Ken}, 145 | title = {Persistent Object Gaussian Splat (POGS) for Tracking Human and Robot Manipulation of Irregularly Shaped Objects}, 146 | journal = {ICRA}, 147 | year = {2025}, 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /media/POGS_servoing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/media/POGS_servoing.gif -------------------------------------------------------------------------------- /media/POGS_teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/media/POGS_teaser.gif -------------------------------------------------------------------------------- /pogs/calibration_outputs/world_to_extrinsic_zed_for_grasping_down.tf: -------------------------------------------------------------------------------- 1 | zed_extrinsic_for_grasping 2 | world 3 | 0.0 0.0 0.0 4 | 0.0 -1.0 0.0 5 | -1.0 0.0 0.0 6 | 0.0 0.0 -1.0 -------------------------------------------------------------------------------- /pogs/camera/capture_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from autolab_core import RigidTransform, CameraIntrinsics 3 | from autolab_core.transformations import euler_matrix, euler_from_matrix 4 | from PIL import Image 5 | import json 6 | from scipy.optimize import minimize 7 | from typing import List 8 | 9 | def estimate_cam2rob(H_chess_cams: List[RigidTransform], H_rob_worlds: List[RigidTransform]): 10 | ''' 11 | Estimates transform between camera and robot end-effector frame using least-squares optimization. 12 | This implements the hand-eye calibration procedure using multiple observations. 13 | 14 | Parameters: 15 | ----------- 16 | H_chess_cams : List[RigidTransform] 17 | List of transformations from chessboard frame to camera frame for each observation 18 | H_rob_worlds : List[RigidTransform] 19 | List of transformations from robot end-effector frame to world frame for each observation 20 | 21 | Returns: 22 | -------- 23 | H_cam_rob : RigidTransform 24 | Estimated transformation from camera frame to robot end-effector frame 25 | H_chess_world : RigidTransform 26 | Estimated transformation from chessboard frame to world frame 27 | ''' 28 | def residual(x): 29 | ''' 30 | Computes the residual error for the current estimate of transformations. 
31 | 32 | Parameters: 33 | ----------- 34 | x : numpy.ndarray 35 | Parameter vector formatted as [x, y, z, rx, ry, rz] for chessboard-to-world 36 | followed by [x, y, z, rx, ry, rz] for camera-to-robot 37 | 38 | Returns: 39 | -------- 40 | float 41 | Sum of position and orientation errors across all observations 42 | ''' 43 | err = 0 44 | # Extract transformations from parameter vector 45 | H_chess_world = RigidTransform(translation=x[:3], 46 | rotation=euler_matrix(x[3], x[4], x[5])[:3, :3], 47 | from_frame='chess', to_frame='world') 48 | H_cam_rob = RigidTransform(translation=x[6:9], 49 | rotation=euler_matrix(x[9], x[10], x[11])[:3, :3], 50 | from_frame='cam', to_frame='rob') 51 | 52 | # Compute error across all observations 53 | for H_chess_cam, H_rob_world in zip(H_chess_cams, H_rob_worlds): 54 | # Estimate chessboard-to-world transform using current parameters 55 | H_chess_world_est = H_rob_world * H_cam_rob * H_chess_cam 56 | 57 | # Compute translation error 58 | err += np.linalg.norm(H_chess_world.translation - H_chess_world_est.translation) 59 | 60 | # Compute rotation error using Euler angles 61 | rot_diff = H_chess_world.rotation @ np.linalg.inv(H_chess_world_est.rotation) 62 | eul_diff = euler_from_matrix(rot_diff) 63 | err += np.linalg.norm(eul_diff) 64 | 65 | print(err) 66 | return err 67 | 68 | # Initialize parameters with zeros 69 | x0 = np.zeros(12) 70 | 71 | # Perform optimization using Sequential Least Squares Programming 72 | res = minimize(residual, x0, method='SLSQP') 73 | print(res) 74 | 75 | # Alert if optimization did not converge 76 | if not res.success: 77 | input("Optimization was not successful, press enter to acknowledge") 78 | 79 | # Extract optimized parameters 80 | x = res.x 81 | H_chess_world = RigidTransform(translation=x[:3], 82 | rotation=euler_matrix(x[3], x[4], x[5])[:3, :3], 83 | from_frame='chess', to_frame='world') 84 | H_cam_rob = RigidTransform(translation=x[6:9], 85 | rotation=euler_matrix(x[9], x[10], x[11])[:3, :3], 86 | from_frame='cam', to_frame='rob') 87 | 88 | return H_cam_rob, H_chess_world 89 | 90 | 91 | def point_at(cam_t, obstacle_t, extra_R=np.eye(3)): 92 | ''' 93 | Computes a transformation that orients a camera to point at a specific 3D point. 
94 | 95 | Parameters: 96 | ----------- 97 | cam_t : numpy.ndarray 98 | 3D position of the camera/gripper in world coordinates 99 | obstacle_t : numpy.ndarray 100 | 3D position of the target location to point the camera at 101 | extra_R : numpy.ndarray, optional 102 | Additional rotation matrix to apply to the camera frame, useful for fine-tuning orientation 103 | 104 | Returns: 105 | -------- 106 | RigidTransform 107 | Transformation representing camera pose pointing at the target 108 | ''' 109 | # Compute the direction vector from camera to target 110 | dir = obstacle_t - cam_t 111 | z_axis = dir / np.linalg.norm(dir) 112 | 113 | # Compute the x-axis perpendicular to z-axis and world z-axis 114 | # Note: change the sign if camera positioning is difficult 115 | x_axis_dir = -np.cross(np.array((0, 0, 1)), z_axis) 116 | 117 | # Handle special case when camera direction is aligned with world z-axis 118 | if np.linalg.norm(x_axis_dir) < 1e-10: 119 | x_axis_dir = np.array((0, 1, 0)) 120 | x_axis = x_axis_dir / np.linalg.norm(x_axis_dir) 121 | 122 | # Complete the right-handed coordinate system with y-axis 123 | y_axis_dir = np.cross(z_axis, x_axis) 124 | y_axis = y_axis_dir / np.linalg.norm(y_axis_dir) 125 | 126 | # Create rotation matrix from axes and apply extra rotation 127 | # Post-multiply to rotate the camera with respect to itself 128 | R = RigidTransform.rotation_from_axes(x_axis, y_axis, z_axis) @ extra_R 129 | 130 | # Create and return the complete transformation 131 | H = RigidTransform(translation=cam_t, rotation=R, from_frame='camera', to_frame='base_link') 132 | return H 133 | 134 | def save_data(imgs, poses, savedir, intr: CameraIntrinsics): 135 | ''' 136 | Saves a collection of images and camera poses in the NeRF dataset format. 137 | 138 | Parameters: 139 | ----------- 140 | imgs : List[numpy.ndarray] 141 | List of images captured from each viewpoint 142 | poses : List[RigidTransform] 143 | List of camera poses corresponding to each image 144 | savedir : str 145 | Directory path to save the dataset 146 | intr : CameraIntrinsics 147 | Camera intrinsic parameters 148 | 149 | Notes: 150 | ------ 151 | The NeRF format includes: 152 | - Individual image files 153 | - A transforms.json file containing camera parameters and poses 154 | - Special conventions for coordinate systems (flipped y and z axes) 155 | ''' 156 | import os 157 | os.makedirs(savedir, exist_ok=True) 158 | 159 | # Initialize data dictionary for transforms.json 160 | data_dict = dict() 161 | data_dict['frames'] = [] 162 | 163 | # Add camera intrinsic parameters 164 | data_dict['fl_x'] = intr.fx 165 | data_dict['fl_y'] = intr.fy 166 | data_dict['cx'] = intr.cx 167 | data_dict['cy'] = intr.cy 168 | data_dict['h'] = imgs[0].shape[0] 169 | data_dict['w'] = imgs[0].shape[1] 170 | 171 | # NeRF-specific parameters 172 | data_dict['aabb_scale'] = 2 173 | data_dict['scale'] = 1.2 174 | 175 | pil_images = [] 176 | for i, (im, p) in enumerate(zip(imgs, poses)): 177 | # Remove alpha channel if present 178 | if im.shape[2] == 4: 179 | im = im[..., :3] 180 | 181 | # Save image file 182 | img = Image.fromarray(im) 183 | pil_images.append(img) 184 | img.save(f'{savedir}/img{i}.jpg') 185 | 186 | # Convert pose to NeRF format (flip y and z axes) 187 | mat = p.matrix 188 | mat[:3, 1] *= -1 189 | mat[:3, 2] *= -1 190 | 191 | # Add frame information to data dictionary 192 | frame = {'file_path': f'img{i}.jpg', 'transform_matrix': mat.tolist()} 193 | data_dict['frames'].append(frame) 194 | 195 | # Save transforms.json file 196 | with 
open(f"{savedir}/transforms.json", 'w') as fp: 197 | json.dump(data_dict, fp) 198 | 199 | def load_data(savedir): 200 | ''' 201 | Loads a NeRF dataset from a directory. 202 | 203 | Parameters: 204 | ----------- 205 | savedir : str 206 | Directory path containing the NeRF dataset 207 | 208 | Returns: 209 | -------- 210 | Tuple[List[numpy.ndarray], List[RigidTransform]] 211 | Tuple containing (images, camera_poses) 212 | 213 | Notes: 214 | ------ 215 | This function reverses the coordinate system conversions applied in save_data 216 | to restore the original camera poses. 217 | ''' 218 | import os 219 | if not os.path.exists(savedir): 220 | raise FileNotFoundError(f'{savedir} does not exist') 221 | 222 | # Load transforms.json file 223 | with open(f"{savedir}/transforms.json", 'r') as fp: 224 | data_dict = json.load(fp) 225 | 226 | poses = [] 227 | imgs = [] 228 | 229 | # Process each frame in the dataset 230 | for frame in data_dict['frames']: 231 | # Load image from file 232 | img = Image.open(f'{savedir}/{frame["file_path"]}') 233 | imgs.append(np.array(img)) 234 | 235 | # Convert pose matrix from NeRF format back to original format 236 | mat = np.array(frame['transform_matrix']) 237 | mat[:3, 1] *= -1 238 | mat[:3, 2] *= -1 239 | 240 | # Create RigidTransform from matrix 241 | poses.append(RigidTransform(*RigidTransform.rotation_and_translation_from_matrix(mat))) 242 | 243 | return imgs, poses -------------------------------------------------------------------------------- /pogs/camera/zed_stereo.py: -------------------------------------------------------------------------------- 1 | import pyzed.sl as sl 2 | import numpy as np 3 | from autolab_core import CameraIntrinsics, PointCloud, RgbCloud 4 | from raftstereo.raft_stereo import * 5 | from raftstereo.utils.utils import InputPadder 6 | import argparse 7 | from dataclasses import dataclass, field 8 | from typing import List 9 | import os 10 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | # Configuration class for the RAFT stereo model parameters 13 | class RAFTConfig: 14 | # Path to the pre-trained RAFT stereo model checkpoint 15 | restore_ckpt: str = str(os.path.join(DIR_PATH,'../dependencies/raftstereo/models/raftstereo-middlebury.pth')) 16 | # Hidden dimensions for network layers 17 | hidden_dims: List[int] = [128]*3 18 | # Correlation implementation type 19 | corr_implementation: str = "reg" 20 | # Whether to use a shared backbone for both images 21 | shared_backbone: bool = False 22 | # Correlation pyramid levels for feature matching 23 | corr_levels: int = 4 24 | # Radius of correlation lookups 25 | corr_radius: int = 4 26 | # Number of downsampling layers 27 | n_downsample: int = 2 28 | # Normalization type for context network 29 | context_norm: str = "batch" 30 | # Whether to use slow-fast GRU for iterative updates 31 | slow_fast_gru: bool = False 32 | # Enable mixed precision for faster computation and memory efficiency 33 | mixed_precision: bool = False 34 | # Number of GRU layers in the update block 35 | n_gru_layers: int = 3 36 | 37 | # ZED stereo camera interface with RAFT stereo depth estimation 38 | class Zed: 39 | """ 40 | ZED camera wrapper for stereo image capture and depth estimation using RAFT-Stereo algorithm. 41 | Provides functions for camera initialization, image capture, depth estimation, and point cloud generation. 42 | """ 43 | def __init__(self, flip_mode, resolution, fps, cam_id=None, recording_file=None, start_time=0.0): 44 | """ 45 | Initialize the ZED camera with specified parameters. 
46 | 47 | Args: 48 | flip_mode (bool): Whether to flip the camera image 49 | resolution (str): Camera resolution ('720p', '1080p', or '2k') 50 | fps (int): Frames per second 51 | cam_id (int, optional): Camera serial number for multi-camera setups 52 | recording_file (str, optional): Path to SVO recording file for playback 53 | start_time (float, optional): Start time in seconds for SVO playback 54 | """ 55 | init = sl.InitParameters() 56 | if cam_id is not None: 57 | init.set_from_serial_number(cam_id) 58 | self.cam_id = cam_id 59 | self.height_ = None 60 | self.width_ = None 61 | 62 | # Set camera flip mode 63 | if flip_mode: 64 | init.camera_image_flip = sl.FLIP_MODE.ON 65 | else: 66 | init.camera_image_flip = sl.FLIP_MODE.OFF 67 | 68 | # Configure for SVO file playback if provided 69 | if recording_file is not None: 70 | init.set_from_svo_file(recording_file) 71 | 72 | # Configure camera resolution 73 | if resolution == '720p': 74 | init.camera_resolution = sl.RESOLUTION.HD720 75 | self.height_ = 720 76 | self.width_ = 1280 77 | elif resolution == '1080p': 78 | init.camera_resolution = sl.RESOLUTION.HD1080 79 | self.height_ = 1080 80 | self.width_ = 1920 81 | elif resolution == '2k': 82 | init.camera_resolution = sl.RESOLUTION.HD2k 83 | self.height_ = 1242 84 | self.width_ = 2208 85 | else: 86 | print("Only 720p, 1080p, and 2k supported by Zed") 87 | exit() 88 | 89 | # Disable native ZED depth computation (we'll use RAFT-Stereo instead) 90 | init.depth_mode = sl.DEPTH_MODE.NONE 91 | init.sdk_verbose = 1 92 | init.camera_fps = fps 93 | self.cam = sl.Camera() 94 | init.camera_disable_self_calib = True 95 | 96 | # Open the camera with the specified parameters 97 | status = self.cam.open(init) 98 | self.recording_file = recording_file 99 | self.start_time = start_time 100 | 101 | # Set SVO playback position if applicable 102 | if recording_file is not None: 103 | fps = self.cam.get_camera_information().camera_configuration.fps 104 | self.cam.set_svo_position(int(start_time * fps)) 105 | 106 | # Check if camera opened successfully 107 | if status != sl.ERROR_CODE.SUCCESS: 108 | print("Camera Open : " + repr(status) + ". Exit program.") 109 | exit() 110 | else: 111 | print("Opened camera") 112 | 113 | # Calculate stereo parameters for depth estimation 114 | left_cx = self.get_K(cam="left")[0, 2] 115 | right_cx = self.get_K(cam="right")[0, 2] 116 | self.cx_diff = right_cx - left_cx # Horizontal principal point difference for disparity calculation 117 | self.f_ = self.get_K(cam="left")[0,0] # Focal length 118 | self.cx_ = left_cx # Principal point x-coordinate 119 | self.cy_ = self.get_K(cam="left")[1,2] # Principal point y-coordinate 120 | self.Tx_ = self.get_stereo_transform()[0,3] # Baseline (translation between cameras) 121 | 122 | # RAFT-Stereo parameters 123 | self.valid_iters_ = 32 # Number of iterations for RAFT-Stereo flow estimation 124 | self.padder_ = InputPadder(torch.empty(1,3,self.height_,self.width_).shape, divis_by=32) 125 | self.model = self.create_raft() # Initialize RAFT-Stereo model 126 | 127 | def create_raft(self): 128 | """ 129 | Initialize and load the RAFT-Stereo model for depth estimation. 
130 | 131 | Returns: 132 | torch.nn.Module: Loaded RAFT-Stereo model in evaluation mode 133 | """ 134 | raft_args = RAFTConfig() 135 | model = torch.nn.DataParallel(RAFTStereo(raft_args), device_ids=[0]) 136 | model.load_state_dict(torch.load(raft_args.restore_ckpt)) 137 | 138 | model = model.module 139 | model = model.to('cuda') 140 | model = model.eval() 141 | return model 142 | 143 | def load_image_raft(self, im): 144 | """ 145 | Prepare image for RAFT-Stereo model inference. 146 | 147 | Args: 148 | im (numpy.ndarray): RGB image array 149 | 150 | Returns: 151 | torch.Tensor: Processed image tensor on GPU 152 | """ 153 | img = torch.from_numpy(im).permute(2,0,1).float() 154 | return img[None].to('cuda') 155 | 156 | def get_depth_image_and_pointcloud(self, left_img, right_img, from_frame): 157 | """ 158 | Compute depth image and point cloud from stereo images using RAFT-Stereo. 159 | 160 | Args: 161 | left_img (numpy.ndarray): Left RGB image 162 | right_img (numpy.ndarray): Right RGB image 163 | from_frame (str): Coordinate frame name for point cloud 164 | 165 | Returns: 166 | tuple: (depth_image, points, rgbs) containing depth map and colored point cloud 167 | """ 168 | with torch.no_grad(): 169 | # Prepare images for RAFT model 170 | image1 = self.load_image_raft(left_img) 171 | image2 = self.load_image_raft(right_img) 172 | image1, image2 = self.padder_.pad(image1, image2) 173 | 174 | # Run RAFT-Stereo to compute disparity (negative of optical flow) 175 | _, flow_up = self.model(image1, image2, iters=self.valid_iters_, test_mode=True) 176 | flow_up = self.padder_.unpad(flow_up).squeeze() 177 | 178 | flow_up_np = -flow_up.detach().cpu().numpy().squeeze() 179 | 180 | # Convert disparity to depth using the stereo camera equation: 181 | # depth = (focal_length * baseline) / disparity 182 | depth_image = (self.f_ * self.Tx_) / abs(flow_up_np + self.cx_diff) 183 | rows, cols = depth_image.shape 184 | y, x = np.meshgrid(range(rows), range(cols), indexing="ij") 185 | 186 | # Convert depth image to x,y,z point cloud using the pinhole camera model 187 | Z = depth_image # Depth values 188 | # X = (x - cx) * Z / fx 189 | X = (x - self.cx_) * Z / self.f_ 190 | # Y = (y - cy) * Z / fy 191 | Y = (y - self.cy_) * Z / self.f_ 192 | points = np.stack((X,Y,Z), axis=-1) 193 | rgbs = left_img 194 | 195 | # Remove points with zero depth (invalid or background points) 196 | non_zero_indices = np.all(points != [0, 0, 0], axis=-1) 197 | points = points[non_zero_indices] 198 | rgbs = rgbs[non_zero_indices] 199 | 200 | # Format as PointCloud and RgbCloud objects for compatibility with autolab_core 201 | points = points.reshape(-1,3) 202 | rgbs = rgbs.reshape(-1,3) 203 | points = PointCloud(points.T, from_frame) 204 | rgbs = RgbCloud(rgbs.T, from_frame) 205 | return depth_image, points, rgbs 206 | 207 | def get_frame(self, depth=True, cam="left"): 208 | """ 209 | Capture a frame from the ZED camera and optionally compute depth. 
210 | 211 | Args: 212 | depth (bool): Whether to compute depth 213 | cam (str): Which camera to use as reference ("left" or "right") 214 | 215 | Returns: 216 | tuple: (left_image, right_image, depth_map) or None if frame capture failed 217 | """ 218 | res = sl.Resolution() 219 | res.width = self.width_ 220 | res.height = self.height_ 221 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS: 222 | # Retrieve stereo images 223 | left_rgb = sl.Mat() 224 | right_rgb = sl.Mat() 225 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res) 226 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res) 227 | left, right = ( 228 | torch.from_numpy( 229 | np.flip(left_rgb.get_data()[..., :3], axis=2).copy() 230 | ).cuda(), 231 | torch.from_numpy( 232 | np.flip(right_rgb.get_data()[..., :3], axis=2).copy() 233 | ).cuda(), 234 | ) 235 | 236 | # Compute depth if requested 237 | if depth: 238 | left_torch, right_torch = left.permute(2, 0, 1), right.permute(2, 0, 1) 239 | 240 | # Handle different reference views (left or right camera) 241 | if cam == "left": 242 | flow = raft_inference(left_torch, right_torch, self.model) 243 | else: 244 | right_torch = torch.flip(right_torch, dims=[2]) 245 | left_torch = torch.flip(left_torch, dims=[2]) 246 | flow = raft_inference(right_torch, left_torch, self.model) 247 | 248 | # Compute depth from disparity using stereo camera equation 249 | fx = self.get_K()[0, 0] # Focal length 250 | depth = ( 251 | fx * self.get_stereo_transform()[0, 3] / (flow.abs() + self.cx_diff) 252 | ) 253 | 254 | if cam != "left": 255 | depth = torch.flip(depth, dims=[1]) 256 | else: 257 | depth = None 258 | return left, right, depth 259 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED: 260 | print("End of recording file") 261 | return None, None, None 262 | else: 263 | raise RuntimeError("Could not grab frame") 264 | 265 | def get_K(self, cam="left"): 266 | """ 267 | Get camera intrinsic matrix (calibration matrix K). 268 | 269 | Args: 270 | cam (str): Which camera to use ("left" or "right") 271 | 272 | Returns: 273 | numpy.ndarray: 3x3 intrinsic matrix K 274 | """ 275 | calib = ( 276 | self.cam.get_camera_information().camera_configuration.calibration_parameters 277 | ) 278 | if cam == "left": 279 | intrinsics = calib.left_cam 280 | else: 281 | intrinsics = calib.right_cam 282 | K = np.array( 283 | [ 284 | [intrinsics.fx, 0, intrinsics.cx], 285 | [0, intrinsics.fy, intrinsics.cy], 286 | [0, 0, 1], 287 | ] 288 | ) 289 | return K 290 | 291 | def get_intr(self, cam="left"): 292 | """ 293 | Get camera intrinsics in autolab_core format. 294 | 295 | Args: 296 | cam (str): Which camera to use ("left" or "right") 297 | 298 | Returns: 299 | CameraIntrinsics: Camera intrinsics object 300 | """ 301 | calib = ( 302 | self.cam.get_camera_information().camera_configuration.calibration_parameters 303 | ) 304 | if cam == "left": 305 | intrinsics = calib.left_cam 306 | else: 307 | intrinsics = calib.right_cam 308 | return CameraIntrinsics( 309 | frame="zed", 310 | fx=intrinsics.fx, 311 | fy=intrinsics.fy, 312 | cx=intrinsics.cx, 313 | cy=intrinsics.cy, 314 | width=1280, 315 | height=720, 316 | ) 317 | 318 | def get_stereo_transform(self): 319 | """ 320 | Get transformation matrix from left to right camera. 
321 | 322 | Returns: 323 | numpy.ndarray: 4x4 transformation matrix (in meters) 324 | """ 325 | transform = ( 326 | self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m 327 | ) 328 | transform[:3, 3] /= 1000 # Convert from millimeters to meters 329 | return transform 330 | 331 | def start_record(self, out_path): 332 | """ 333 | Start recording to an SVO file. 334 | 335 | Args: 336 | out_path (str): Output file path 337 | """ 338 | recordingParameters = sl.RecordingParameters() 339 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264 340 | recordingParameters.video_filename = out_path 341 | err = self.cam.enable_recording(recordingParameters) 342 | 343 | def stop_record(self): 344 | """Stop recording to SVO file.""" 345 | self.cam.disable_recording() 346 | 347 | def get_rgb_depth(self, cam="left"): 348 | """ 349 | Get RGB images and depth map. 350 | 351 | Args: 352 | cam (str): Which camera to use as reference ("left" or "right") 353 | 354 | Returns: 355 | tuple: (left_image, right_image, depth_map) as numpy arrays 356 | """ 357 | left, right, depth = self.get_frame(cam=cam) 358 | return left.cpu().numpy(), right.cpu().numpy(), depth.cpu().numpy() 359 | 360 | def get_rgb(self, cam="left"): 361 | """ 362 | Get only RGB images without computing depth. 363 | 364 | Args: 365 | cam (str): Which camera to use as reference ("left" or "right") 366 | 367 | Returns: 368 | tuple: (left_image, right_image) as numpy arrays 369 | """ 370 | left, right, _ = self.get_frame(depth=False, cam=cam) 371 | return left.cpu().numpy(), right.cpu().numpy() 372 | 373 | def get_ns_intrinsics(self): 374 | """ 375 | Get camera intrinsics in NeRF Studio format. 376 | 377 | Returns: 378 | dict: Camera intrinsics dictionary compatible with NeRF Studio 379 | """ 380 | calib = ( 381 | self.cam.get_camera_information().camera_configuration.calibration_parameters 382 | ) 383 | calibration_parameters_l = calib.left_cam 384 | return { 385 | "w": self.width_, 386 | "h": self.height_, 387 | "fl_x": calibration_parameters_l.fx, 388 | "fl_y": calibration_parameters_l.fy, 389 | "cx": calibration_parameters_l.cx, 390 | "cy": calibration_parameters_l.cy, 391 | "k1": calibration_parameters_l.disto[0], 392 | "k2": calibration_parameters_l.disto[1], 393 | "p1": calibration_parameters_l.disto[3], 394 | "p2": calibration_parameters_l.disto[4], 395 | "camera_model": "OPENCV", 396 | } 397 | 398 | def get_zed_depth(self): 399 | """ 400 | Get native ZED depth map (not using RAFT-Stereo). 401 | Note: This requires depth_mode to be enabled in InitParameters. 
402 | 403 | Returns: 404 | numpy.ndarray: Depth map 405 | """ 406 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS: 407 | depth = sl.Mat() 408 | self.cam.retrieve_measure(depth, sl.MEASURE.DEPTH) 409 | return depth.get_data() 410 | else: 411 | raise RuntimeError("Could not grab frame") 412 | 413 | def close(self): 414 | """Close the camera and release resources.""" 415 | self.cam.close() 416 | self.cam = None 417 | print("Closed camera") 418 | 419 | def reopen(self): 420 | """Reopen the camera with previously specified parameters.""" 421 | if self.cam is None: 422 | init = sl.InitParameters() 423 | # Configuration for SVO file playback 424 | if self.recording_file is not None: 425 | init.set_from_svo_file(self.recording_file) 426 | init.camera_image_flip = sl.FLIP_MODE.OFF 427 | init.depth_mode = sl.DEPTH_MODE.NONE 428 | init.camera_resolution = sl.RESOLUTION.HD1080 429 | init.sdk_verbose = 1 430 | init.camera_fps = 15 431 | # Configuration for live camera 432 | else: 433 | init.camera_resolution = sl.RESOLUTION.HD720 434 | init.sdk_verbose = 1 435 | init.camera_fps = 15 436 | init.camera_image_flip = sl.FLIP_MODE.OFF 437 | init.depth_mode = sl.DEPTH_MODE.NONE 438 | init.depth_minimum_distance = 100 # millimeters 439 | 440 | self.cam = sl.Camera() 441 | init.camera_disable_self_calib = True 442 | status = self.cam.open(init) 443 | 444 | # Set SVO playback position if applicable 445 | if self.recording_file is not None: 446 | fps = self.cam.get_camera_information().camera_configuration.fps 447 | self.cam.set_svo_position(int(self.start_time * fps)) 448 | 449 | # Check if camera opened successfully 450 | if status != sl.ERROR_CODE.SUCCESS: 451 | print("Camera Open : " + repr(status) + ". Exit program.") 452 | exit() 453 | else: 454 | print("Opened camera") 455 | print( 456 | "Current Exposure is set to: ", 457 | self.cam.get_camera_settings(sl.VIDEO_SETTINGS.EXPOSURE), 458 | ) 459 | 460 | # Reinitialize RAFT model and stereo parameters 461 | self.model = self.create_raft() 462 | left_cx = self.get_K(cam="left")[0, 2] 463 | right_cx = self.get_K(cam="right")[0, 2] 464 | self.cx_diff = right_cx - left_cx -------------------------------------------------------------------------------- /pogs/configs/base_config.py: -------------------------------------------------------------------------------- 1 | 2 | """Base Configs""" 3 | 4 | 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | from typing import Any, List, Literal, Optional, Tuple, Type 10 | 11 | # Pretty printing class 12 | class PrintableConfig: 13 | """Printable Config defining str function""" 14 | 15 | def __str__(self): 16 | lines = [self.__class__.__name__ + ":"] 17 | for key, val in vars(self).items(): 18 | if isinstance(val, Tuple): 19 | flattened_val = "[" 20 | for item in val: 21 | flattened_val += str(item) + "\n" 22 | flattened_val = flattened_val.rstrip("\n") 23 | val = flattened_val + "]" 24 | lines += f"{key}: {str(val)}".split("\n") 25 | return "\n ".join(lines) 26 | 27 | 28 | # Base instantiate configs 29 | @dataclass 30 | class InstantiateConfig(PrintableConfig): 31 | """Config class for instantiating an the class specified in the _target attribute.""" 32 | 33 | _target: Type 34 | 35 | def setup(self, **kwargs) -> Any: 36 | """Returns the instantiated object using the config.""" 37 | return self._target(self, **kwargs) -------------------------------------------------------------------------------- /pogs/configs/camera_config.yaml: 
-------------------------------------------------------------------------------- 1 | third_view_zed: 2 | exposure: 100 3 | flip_mode: false 4 | fps: 30 5 | gain: 31 6 | id: 22008760 7 | resolution: 1080p 8 | wrist_mounted_zed: 9 | exposure: 67 10 | flip_mode: true 11 | fps: 30 12 | gain: 28 13 | id: 16347230 14 | resolution: 720p 15 | -------------------------------------------------------------------------------- /pogs/configs/conda_compile_pointnet_tfops.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Make sure the conda environment is activated so that CONDA_PREFIX is defined. 3 | 4 | # Replace the system paths with conda environment paths: 5 | CUDA_INCLUDE=" -I${CONDA_PREFIX}/include/" 6 | CUDA_LIB=" -L${CONDA_PREFIX}/lib/" 7 | 8 | # Get TensorFlow compile and link flags from the active Python: 9 | TF_CFLAGS=$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 10 | TF_LFLAGS=$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') 11 | 12 | cd pointnet2/tf_ops/sampling 13 | 14 | nvcc -std=c++11 -c -o tf_sampling_g.cu.o tf_sampling_g.cu \ 15 | ${CUDA_INCLUDE} ${TF_CFLAGS} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 16 | 17 | g++ -std=c++11 -shared -o tf_sampling_so.so tf_sampling.cpp \ 18 | tf_sampling_g.cu.o ${CUDA_INCLUDE} ${TF_CFLAGS} -fPIC -lcudart ${TF_LFLAGS} ${CUDA_LIB} 19 | 20 | echo 'testing sampling' 21 | python3 tf_sampling.py 22 | 23 | cd ../grouping 24 | 25 | nvcc -std=c++11 -c -o tf_grouping_g.cu.o tf_grouping_g.cu \ 26 | ${CUDA_INCLUDE} ${TF_CFLAGS} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 27 | 28 | g++ -std=c++11 -shared -o tf_grouping_so.so tf_grouping.cpp \ 29 | tf_grouping_g.cu.o ${CUDA_INCLUDE} ${TF_CFLAGS} -fPIC -lcudart ${TF_LFLAGS} ${CUDA_LIB} 30 | 31 | echo 'testing grouping' 32 | python3 tf_grouping_op_test.py 33 | 34 | cd ../3d_interpolation 35 | 36 | g++ -std=c++11 tf_interpolate.cpp -o tf_interpolate_so.so -shared -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2 37 | 38 | echo 'testing interpolate' 39 | python3 tf_interpolate_op_test.py 40 | -------------------------------------------------------------------------------- /pogs/contact_graspnet_wrapper/prime_config_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | def recursive_key_value_assign(d,ks,v): 5 | """ 6 | Recursive value assignment to a nested dict 7 | 8 | Arguments: 9 | d {dict} -- dict 10 | ks {list} -- list of hierarchical keys 11 | v {value} -- value to assign 12 | """ 13 | 14 | if len(ks) > 1: 15 | recursive_key_value_assign(d[ks[0]],ks[1:],v) 16 | elif len(ks) == 1: 17 | d[ks[0]] = v 18 | 19 | def load_config(checkpoint_dir, batch_size=None, max_epoch=None, data_path=None, arg_configs=[], save=False): 20 | """ 21 | Loads yaml config file and overwrites parameters with function arguments and --arg_config parameters 22 | 23 | Arguments: 24 | checkpoint_dir {str} -- Checkpoint directory where config file was copied to 25 | 26 | Keyword Arguments: 27 | batch_size {int} -- [description] (default: {None}) 28 | max_epoch {int} -- "epochs" (number of scenes) to train (default: {None}) 29 | data_path {str} -- path to scenes with contact grasp data (default: {None}) 30 | arg_configs {list} -- Overwrite config parameters by hierarchical command line arguments (default: {[]}) 31 | save {bool} -- Save overwritten config file (default: {False}) 32 | 33 | Returns: 34 | [dict] -- Config 35 | """ 36 | 37 | config_path = 
os.path.join(checkpoint_dir, 'config.yaml') 38 | config_path = config_path if os.path.exists(config_path) else os.path.join(os.path.dirname(__file__),'config.yaml') 39 | with open(config_path,'r') as f: 40 | global_config = yaml.safe_load(f) 41 | 42 | for conf in arg_configs: 43 | k_str, v = conf.split(':') 44 | try: 45 | v = eval(v) 46 | except: 47 | pass 48 | ks = [int(k) if k.isdigit() else k for k in k_str.split('.')] 49 | 50 | recursive_key_value_assign(global_config, ks, v) 51 | 52 | if batch_size is not None: 53 | global_config['OPTIMIZER']['batch_size'] = int(batch_size) 54 | if max_epoch is not None: 55 | global_config['OPTIMIZER']['max_epoch'] = int(max_epoch) 56 | if data_path is not None: 57 | global_config['DATA']['data_path'] = data_path 58 | 59 | global_config['DATA']['classes'] = None 60 | 61 | if save: 62 | with open(os.path.join(checkpoint_dir, 'config.yaml'),'w') as f: 63 | yaml.dump(global_config, f) 64 | 65 | return global_config 66 | 67 | -------------------------------------------------------------------------------- /pogs/contact_graspnet_wrapper/prime_visualization_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mayavi.mlab as mlab 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | 6 | import mesh_utils 7 | 8 | def plot_mesh(mesh, cam_trafo=np.eye(4), mesh_pose=np.eye(4)): 9 | """ 10 | Plots mesh in mesh_pose from 11 | 12 | Arguments: 13 | mesh {trimesh.base.Trimesh} -- input mesh, e.g. gripper 14 | 15 | Keyword Arguments: 16 | cam_trafo {np.ndarray} -- 4x4 transformation from world to camera coords (default: {np.eye(4)}) 17 | mesh_pose {np.ndarray} -- 4x4 transformation from mesh to world coords (default: {np.eye(4)}) 18 | """ 19 | 20 | homog_mesh_vert = np.pad(mesh.vertices, (0, 1), 'constant', constant_values=(0, 1)) 21 | mesh_cam = homog_mesh_vert.dot(mesh_pose.T).dot(cam_trafo.T)[:,:3] 22 | mlab.triangular_mesh(mesh_cam[:, 0], 23 | mesh_cam[:, 1], 24 | mesh_cam[:, 2], 25 | mesh.faces, 26 | colormap='Blues', 27 | opacity=0.5) 28 | 29 | def plot_coordinates(t,r, tube_radius=0.005): 30 | """ 31 | plots coordinate frame 32 | 33 | Arguments: 34 | t {np.ndarray} -- translation vector 35 | r {np.ndarray} -- rotation matrix 36 | 37 | Keyword Arguments: 38 | tube_radius {float} -- radius of the plotted tubes (default: {0.005}) 39 | """ 40 | mlab.plot3d([t[0],t[0]+0.2*r[0,0]], [t[1],t[1]+0.2*r[1,0]], [t[2],t[2]+0.2*r[2,0]], color=(1,0,0), tube_radius=tube_radius, opacity=1) 41 | mlab.plot3d([t[0],t[0]+0.2*r[0,1]], [t[1],t[1]+0.2*r[1,1]], [t[2],t[2]+0.2*r[2,1]], color=(0,1,0), tube_radius=tube_radius, opacity=1) 42 | mlab.plot3d([t[0],t[0]+0.2*r[0,2]], [t[1],t[1]+0.2*r[1,2]], [t[2],t[2]+0.2*r[2,2]], color=(0,0,1), tube_radius=tube_radius, opacity=1) 43 | 44 | def show_image(rgb, segmap): 45 | """ 46 | Overlay rgb image with segmentation and imshow segment 47 | 48 | Arguments: 49 | rgb {np.ndarray} -- color image 50 | segmap {np.ndarray} -- integer segmap of same size as rgb 51 | """ 52 | plt.figure() 53 | figManager = plt.get_current_fig_manager() 54 | 55 | plt.ion() 56 | plt.show() 57 | 58 | if rgb is not None: 59 | plt.imshow(rgb) 60 | if segmap is not None: 61 | cmap = plt.get_cmap('rainbow') 62 | cmap.set_under(alpha=0.0) 63 | plt.imshow(segmap, cmap=cmap, alpha=0.5, vmin=0.0001) 64 | plt.draw() 65 | plt.pause(0.001) 66 | 67 | def visualize_grasps(full_pc, pred_grasps_cam, scores, plot_opencv_cam=False, pc_colors=None, gripper_openings=None, gripper_width=0.08): 68 | 
"""Visualizes colored point cloud and predicted grasps. If given, colors grasps by segmap regions. 69 | Thick grasp is most confident per segment. For scene point cloud predictions, colors grasps according to confidence. 70 | 71 | Arguments: 72 | full_pc {np.ndarray} -- Nx3 point cloud of the scene 73 | pred_grasps_cam {dict[int:np.ndarray]} -- Predicted 4x4 grasp trafos per segment or for whole point cloud 74 | scores {dict[int:np.ndarray]} -- Confidence scores for grasps 75 | 76 | Keyword Arguments: 77 | plot_opencv_cam {bool} -- plot camera coordinate frame (default: {False}) 78 | pc_colors {np.ndarray} -- Nx3 point cloud colors (default: {None}) 79 | gripper_openings {dict[int:np.ndarray]} -- Predicted grasp widths (default: {None}) 80 | gripper_width {float} -- If gripper_openings is None, plot grasp widths (default: {0.008}) 81 | """ 82 | 83 | print('Visualizing...takes time') 84 | cm = plt.get_cmap('rainbow') 85 | cm2 = plt.get_cmap('gist_rainbow') 86 | 87 | fig = mlab.figure('Pred Grasps') 88 | mlab.view(azimuth=180, elevation=180, distance=0.2) 89 | draw_pc_with_colors(full_pc, pc_colors) 90 | colors = [cm(1. * i/len(pred_grasps_cam))[:3] for i in range(len(pred_grasps_cam))] 91 | colors2 = {k:cm2(0.5*np.max(scores[k]))[:3] for k in pred_grasps_cam if np.any(pred_grasps_cam[k])} 92 | 93 | if plot_opencv_cam: 94 | plot_coordinates(np.zeros(3,),np.eye(3,3)) 95 | for i,k in enumerate(pred_grasps_cam): 96 | if np.any(pred_grasps_cam[k]): 97 | gripper_openings_k = np.ones(len(pred_grasps_cam[k]))*gripper_width if gripper_openings is None else gripper_openings[k] 98 | if len(pred_grasps_cam) > 1: 99 | draw_grasps(pred_grasps_cam[k], np.eye(4), color=colors[i], gripper_openings=gripper_openings_k) 100 | draw_grasps([pred_grasps_cam[k][np.argmax(scores[k])]], np.eye(4), color=colors2[k], 101 | gripper_openings=[gripper_openings_k[np.argmax(scores[k])]], tube_radius=0.0025) 102 | else: 103 | colors3 = [cm2(0.5*score)[:3] for score in scores[k]] 104 | draw_grasps(pred_grasps_cam[k], np.eye(4), colors=colors3, gripper_openings=gripper_openings_k) 105 | mlab.show() 106 | 107 | def draw_pc_with_colors(pc, pc_colors=None, single_color=(0.3,0.3,0.3), mode='2dsquare', scale_factor=0.0018): 108 | """ 109 | Draws colored point clouds 110 | 111 | Arguments: 112 | pc {np.ndarray} -- Nx3 point cloud 113 | pc_colors {np.ndarray} -- Nx3 point cloud colors 114 | 115 | Keyword Arguments: 116 | single_color {tuple} -- single color for point cloud (default: {(0.3,0.3,0.3)}) 117 | mode {str} -- primitive type to plot (default: {'point'}) 118 | scale_factor {float} -- Scale of primitives. Does not work for points. 
(default: {0.002}) 119 | 120 | """ 121 | 122 | if pc_colors is None: 123 | mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2], color=single_color, scale_factor=scale_factor, mode=mode) 124 | else: 125 | #create direct grid as 256**3 x 4 array 126 | def create_8bit_rgb_lut(): 127 | xl = np.mgrid[0:256, 0:256, 0:256] 128 | lut = np.vstack((xl[0].reshape(1, 256**3), 129 | xl[1].reshape(1, 256**3), 130 | xl[2].reshape(1, 256**3), 131 | 255 * np.ones((1, 256**3)))).T 132 | return lut.astype('int32') 133 | 134 | scalars = pc_colors[:,0]*256**2 + pc_colors[:,1]*256 + pc_colors[:,2] 135 | rgb_lut = create_8bit_rgb_lut() 136 | points_mlab = mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2], scalars, mode=mode, scale_factor=.0018) 137 | points_mlab.glyph.scale_mode = 'scale_by_vector' 138 | points_mlab.module_manager.scalar_lut_manager.lut._vtk_obj.SetTableRange(0, rgb_lut.shape[0]) 139 | points_mlab.module_manager.scalar_lut_manager.lut.number_of_colors = rgb_lut.shape[0] 140 | points_mlab.module_manager.scalar_lut_manager.lut.table = rgb_lut 141 | 142 | def draw_grasps(grasps, cam_pose, gripper_openings, color=(0,1.,0), colors=None, show_gripper_mesh=False, tube_radius=0.0008): 143 | """ 144 | Draws wireframe grasps from given camera pose and with given gripper openings 145 | 146 | Arguments: 147 | grasps {np.ndarray} -- Nx4x4 grasp pose transformations 148 | cam_pose {np.ndarray} -- 4x4 camera pose transformation 149 | gripper_openings {np.ndarray} -- Nx1 gripper openings 150 | 151 | Keyword Arguments: 152 | color {tuple} -- color of all grasps (default: {(0,1.,0)}) 153 | colors {np.ndarray} -- Nx3 color of each grasp (default: {None}) 154 | tube_radius {float} -- Radius of the grasp wireframes (default: {0.0008}) 155 | show_gripper_mesh {bool} -- Renders the gripper mesh for one of the grasp poses (default: {False}) 156 | """ 157 | 158 | gripper = mesh_utils.create_gripper('panda') 159 | gripper_control_points = gripper.get_control_point_tensor(1, False, convex_hull=False).squeeze() 160 | mid_point = 0.5*(gripper_control_points[1, :] + gripper_control_points[2, :]) 161 | grasp_line_plot = np.array([np.zeros((3,)), mid_point, gripper_control_points[1], gripper_control_points[3], 162 | gripper_control_points[1], gripper_control_points[2], gripper_control_points[4]]) 163 | 164 | if show_gripper_mesh and len(grasps) > 0: 165 | plot_mesh(gripper.hand, cam_pose, grasps[0]) 166 | 167 | all_pts = [] 168 | connections = [] 169 | index = 0 170 | N = 7 171 | for i,(g,g_opening) in enumerate(zip(grasps, gripper_openings)): 172 | gripper_control_points_closed = grasp_line_plot.copy() 173 | gripper_control_points_closed[2:,0] = np.sign(grasp_line_plot[2:,0]) * g_opening/2 174 | 175 | pts = np.matmul(gripper_control_points_closed, g[:3, :3].T) 176 | pts += np.expand_dims(g[:3, 3], 0) 177 | pts_homog = np.concatenate((pts, np.ones((7, 1))),axis=1) 178 | pts = np.dot(pts_homog, cam_pose.T)[:,:3] 179 | 180 | color = color if colors is None else colors[i] 181 | 182 | all_pts.append(pts) 183 | connections.append(np.vstack([np.arange(index, index + N - 1.5), 184 | np.arange(index + 1, index + N - .5)]).T) 185 | index += N 186 | # mlab.plot3d(pts[:, 0], pts[:, 1], pts[:, 2], color=color, tube_radius=tube_radius, opacity=1.0) 187 | 188 | # speeds up plot3d because only one vtk object 189 | all_pts = np.vstack(all_pts) 190 | connections = np.vstack(connections) 191 | src = mlab.pipeline.scalar_scatter(all_pts[:,0], all_pts[:,1], all_pts[:,2]) 192 | src.mlab_source.dataset.lines = connections 193 | src.update() 194 | lines 
=mlab.pipeline.tube(src, tube_radius=tube_radius, tube_sides=12) 195 | mlab.pipeline.surface(lines, color=color, opacity=1.0) 196 | 197 | -------------------------------------------------------------------------------- /pogs/data/depth_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Depth dataset. 17 | """ 18 | 19 | import json 20 | from pathlib import Path 21 | from typing import Dict, Union 22 | 23 | import numpy as np 24 | import torch 25 | from PIL import Image 26 | from rich.progress import track 27 | 28 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs 29 | from nerfstudio.data.datasets.base_dataset import InputDataset 30 | from nerfstudio.data.utils.data_utils import get_depth_image_from_path 31 | from nerfstudio.model_components import losses 32 | from nerfstudio.utils.misc import torch_compile 33 | from nerfstudio.utils.rich_utils import CONSOLE 34 | 35 | 36 | class DepthDataset(InputDataset): 37 | """Dataset that returns images and depths. If no depths are found, then we generate them with Zoe Depth. 38 | 39 | Args: 40 | dataparser_outputs: description of where and how to read input images. 41 | scale_factor: The scaling factor for the dataparser outputs. 42 | """ 43 | 44 | def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): 45 | super().__init__(dataparser_outputs, scale_factor) 46 | # if there are no depth images than we want to generate them all with zoe depth 47 | if len(dataparser_outputs.image_filenames) > 0 and ( 48 | "depth_filenames" not in dataparser_outputs.metadata.keys() 49 | or dataparser_outputs.metadata["depth_filenames"] is None 50 | ): 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | CONSOLE.print("[bold yellow] No depth data found!") 53 | losses.FORCE_PSEUDODEPTH_LOSS = True 54 | CONSOLE.print("[bold red] Using psueodepth: forcing depth loss to be ranking loss.") 55 | cache = dataparser_outputs.image_filenames[0].parent / "depths.npy" 56 | # Note: this should probably be saved to disk as images, and then loaded with the dataparser. 57 | # That will allow multi-gpu training. 
58 | if cache.exists(): 59 | CONSOLE.print("[bold yellow] Loading pseudodata depth from cache!") 60 | # load all the depths 61 | self.depths = np.load(cache) 62 | self.depths = torch.from_numpy(self.depths).to(device) 63 | else: 64 | CONSOLE.print("[bold yellow] No cache found...") 65 | dataparser_outputs.metadata["depth_filenames"] = None 66 | dataparser_outputs.metadata["depth_unit_scale_factor"] = 1.0 67 | self.metadata["depth_filenames"] = None 68 | self.metadata["depth_unit_scale_factor"] = 1.0 69 | 70 | self.depth_filenames = self.metadata["depth_filenames"] 71 | self.depth_unit_scale_factor = self.metadata["depth_unit_scale_factor"] 72 | 73 | def get_metadata(self, data: Dict) -> Dict: 74 | if self.depth_filenames is None: 75 | return {} 76 | 77 | filepath = self.depth_filenames[data["image_idx"]] 78 | height = int(self._dataparser_outputs.cameras.height[data["image_idx"]]) 79 | width = int(self._dataparser_outputs.cameras.width[data["image_idx"]]) 80 | 81 | # Scale depth images to meter units and also by scaling applied to cameras 82 | scale_factor = self.depth_unit_scale_factor * self._dataparser_outputs.dataparser_scale 83 | depth_image = get_depth_image_from_path( 84 | filepath=filepath, height=height, width=width, scale_factor=scale_factor 85 | ) 86 | 87 | return {"depth_image": depth_image} 88 | 89 | def _find_transform(self, image_path: Path) -> Union[Path, None]: 90 | while image_path.parent != image_path: 91 | transform_path = image_path.parent / "transforms.json" 92 | if transform_path.exists(): 93 | return transform_path 94 | image_path = image_path.parent 95 | return None -------------------------------------------------------------------------------- /pogs/data/utils/detic_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
2 | 3 | import argparse 4 | import json 5 | import os 6 | import pickle 7 | import random 8 | from typing import List, Dict, Tuple 9 | import sys 10 | import time 11 | import cv2 12 | from pogs.data.utils.feature_dataloader import FeatureDataloader 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from matplotlib.animation import FuncAnimation 16 | import matplotlib.colors as mcolors 17 | # from mpl_toolkits.mplot3d import Axes3D 18 | import open3d as o3d 19 | from PIL import Image 20 | import torch 21 | from tqdm import tqdm 22 | 23 | # Change the current working directory to 'Detic' 24 | dir_path = os.path.dirname(os.path.realpath(__file__)) 25 | cwd = os.getcwd() 26 | os.chdir(dir_path+'/../../dependencies/Detic') 27 | 28 | # Setup detectron2 logger 29 | import detectron2 30 | from detectron2.utils.logger import setup_logger 31 | setup_logger() 32 | 33 | # import common libraries 34 | sys.path.insert(0, os.getcwd()+'/third_party/CenterNet2/') 35 | 36 | # import some common detectron2 utilities 37 | from detectron2 import model_zoo 38 | from detectron2.engine import DefaultPredictor 39 | from detectron2.config import get_cfg 40 | from detectron2.utils.visualizer import Visualizer 41 | from detectron2.data import MetadataCatalog, DatasetCatalog 42 | 43 | # Detic libraries 44 | from collections import defaultdict 45 | from centernet.config import add_centernet_config 46 | from pogs.dependencies.Detic.detic.config import add_detic_config 47 | from pogs.dependencies.Detic.detic.modeling.utils import reset_cls_test 48 | from sklearn.cluster import DBSCAN 49 | import matplotlib.patches as patches 50 | import torch.nn.functional as F 51 | 52 | os.chdir(cwd) 53 | 54 | class DeticDataloader(FeatureDataloader): 55 | def __init__( 56 | self, 57 | cfg: dict, 58 | device: torch.device, 59 | image_list: torch.Tensor = None, 60 | cache_path: str = None, 61 | ): 62 | # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | self.sam = cfg["sam"] 64 | self.cstm_vocab = cfg["custom_vocab"] 65 | self.downscale_factor = cfg["downscale_factor"] 66 | # image_list: torch.Tensor = None, 67 | self.outs = [], 68 | super().__init__(cfg, device, image_list, cache_path) 69 | 70 | def create(self, image_list = None): 71 | # os.makedirs(self.cache_path, exist_ok=True) 72 | # Build the detector and download our pretrained weights 73 | cfg = get_cfg() 74 | add_centernet_config(cfg) 75 | add_detic_config(cfg) 76 | cfg.merge_from_file(dir_path+'/../../dependencies/Detic/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml') 77 | cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth' 78 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model 79 | cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand' 80 | cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True # For better visualization purpose. Set to False for all classes. 81 | # cfg.MODEL.DEVICE='cpu' # uncomment this to use cpu-only mode. 
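        # Descriptive note: the chdir into the Detic checkout below keeps Detic's
        # relative asset paths valid (e.g. the datasets/metadata/*.npy vocabulary
        # classifiers used in default_vocab) while the predictor and vocabulary are
        # set up; the original working directory is restored afterwards.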
82 | os.chdir(dir_path+'/../../dependencies/Detic') 83 | self.detic_predictor = DefaultPredictor(cfg) 84 | os.chdir(cwd) 85 | if self.sam == True: 86 | from segment_anything import sam_model_registry, SamPredictor 87 | sam_checkpoint = "../sam_model/sam_vit_h_4b8939.pth" 88 | model_type = "vit_h" 89 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 90 | sam.to(device=self.device) 91 | print('SAM + Detic on device: ', self.device) 92 | self.sam_predictor = SamPredictor(sam) 93 | if len(self.cstm_vocab) > 0: 94 | print("Using custom vocabulary with classes: ", self.cstm_vocab) 95 | self.custom_vocab(self.cstm_vocab) 96 | else: 97 | print("Using default vocabulary") 98 | os.chdir(dir_path+'/../../dependencies/Detic') 99 | self.default_vocab() 100 | os.chdir(cwd) 101 | 102 | if image_list is not None: 103 | 104 | start_time = time.time() 105 | for idx, img in enumerate(tqdm(image_list, desc="Detic Detector", leave=False)): 106 | H, W = img.shape[-2:] 107 | img = (img.permute(1, 2, 0).numpy()*255).astype(np.uint8) 108 | 109 | output = self.detic_predictor(img[:, :, ::-1]) 110 | instances = output["instances"].to('cpu') 111 | 112 | boxes = instances.pred_boxes.tensor.numpy() 113 | 114 | masks = None 115 | components = torch.zeros(H, W) 116 | if self.sam: 117 | if len(boxes) > 0: 118 | # Only run SAM if there are bboxes 119 | masks = self.SAM(img, boxes) 120 | for i in range(masks.shape[0]): 121 | if torch.sum(masks[i][0]) <= H*W/3.5: 122 | components[masks[i][0]] = i + 1 123 | else: 124 | masks = output['instances'].pred_masks.unsqueeze(1) 125 | for i in range(masks.shape[0]): 126 | if torch.sum(masks[i][0]) <= H*W/3.5: 127 | components[masks[i][0]] = i + 1 128 | bg_mask = (components == 0).to(self.device) 129 | 130 | # Erode all masks using 3x3 kernel 131 | eroded_masks = torch.conv2d( 132 | (~masks).float().cuda(), 133 | torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"), 134 | padding=1, 135 | ) 136 | eroded_masks = ~(eroded_masks >= 2) 137 | 138 | # Filter out small masks 139 | filtered_idx = [] 140 | for i in range(len(masks)): 141 | if masks[i].sum(dim=(1,2)) <= H*W/3.5: 142 | filtered_idx.append(i) 143 | filtered_masks = torch.cat([eroded_masks[filtered_idx], bg_mask.unsqueeze(0).unsqueeze(0)], dim=0).cpu().numpy() 144 | 145 | if self.downscale_factor > 1: 146 | scaled_height = H//self.downscale_factor 147 | scaled_width = W//self.downscale_factor 148 | filtered_masks = F.interpolate(torch.from_numpy(filtered_masks).to(float), (scaled_height, scaled_width), mode = 'nearest').to(bool).squeeze(1).view(-1, scaled_height*scaled_width) 149 | filtered_masks = filtered_masks.numpy() 150 | 151 | outputs = { 152 | # "vis": out, 153 | "boxes": boxes, 154 | "masks": masks, 155 | "masks_filtered": filtered_masks, 156 | # "class_idx": class_idx, 157 | # "class_name": class_name, 158 | # "clip_embeds": clip_embeds, 159 | "components": components, 160 | "scores" : output["instances"].scores, 161 | } 162 | 163 | self.outs[0].append(outputs['masks_filtered']) 164 | 165 | self.data = np.empty(len(image_list), dtype=object) 166 | self.data[:] = self.outs[0] 167 | 168 | print("Detic batch inference time: ", time.time() - start_time) 169 | 170 | # Overridden load method 171 | def load(self): 172 | cache_info_path = self.cache_path.with_suffix(".info") 173 | # print(cache_info_path) 174 | # print(cache_info_path.exists()) 175 | # import pdb; pdb.set_trace() 176 | if not cache_info_path.exists(): 177 | raise FileNotFoundError 178 | with open(cache_info_path, "r") as f: 179 | cfg = 
json.loads(f.read()) 180 | if cfg != self.cfg: 181 | raise ValueError("Config mismatch") 182 | self.data = np.load(self.cache_path, allow_pickle=True) 183 | 184 | # Overridden save method 185 | def save(self): 186 | os.makedirs(self.cache_path.parent, exist_ok=True) 187 | cache_info_path = self.cache_path.with_suffix(".info") 188 | with open(cache_info_path, "w") as f: 189 | f.write(json.dumps(self.cfg)) 190 | np.save(self.cache_path, self.data) 191 | 192 | def __call__(self, img_idx): 193 | return NotImplementedError 194 | 195 | def default_vocab(self): 196 | # detic_predictor = self.detic_predictor 197 | # Setup the model's vocabulary using build-in datasets 198 | BUILDIN_CLASSIFIER = { 199 | 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy', 200 | 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy', 201 | 'openimages': 'datasets/metadata/oid_clip_a+cname.npy', 202 | 'coco': 'datasets/metadata/coco_clip_a+cname.npy', 203 | } 204 | 205 | BUILDIN_METADATA_PATH = { 206 | 'lvis': 'lvis_v1_val', 207 | 'objects365': 'objects365_v2_val', 208 | 'openimages': 'oid_val_expanded', 209 | 'coco': 'coco_2017_val', 210 | } 211 | 212 | vocabulary = 'lvis' # change to 'lvis', 'objects365', 'openimages', or 'coco' 213 | self.metadata = MetadataCatalog.get(BUILDIN_METADATA_PATH[vocabulary]) 214 | classifier = BUILDIN_CLASSIFIER[vocabulary] 215 | 216 | num_classes = len(self.metadata.thing_classes) 217 | reset_cls_test(self.detic_predictor.model, classifier, num_classes) 218 | 219 | def custom_vocab(self, classes): 220 | os.chdir(dir_path+'/../../dependencies/Detic') 221 | self.metadata = MetadataCatalog.get("__unused2") 222 | os.chdir(cwd) 223 | self.metadata.thing_classes = classes 224 | classifier = self.get_clip_embeddings(self.metadata.thing_classes) 225 | num_classes = len(self.metadata.thing_classes) 226 | reset_cls_test(self.detic_predictor.model, classifier, num_classes) 227 | 228 | # Reset visualization threshold 229 | output_score_threshold = 0.3 230 | for cascade_stages in range(len(self.detic_predictor.model.roi_heads.box_predictor)): 231 | self.detic_predictor.model.roi_heads.box_predictor[cascade_stages].test_score_thresh = output_score_threshold 232 | 233 | def get_clip_embeddings(self, vocabulary, prompt='a '): 234 | self.text_encoder.eval() 235 | texts = [prompt + x for x in vocabulary] 236 | emb = self.text_encoder(texts).detach().permute(1, 0).contiguous().cpu() 237 | return emb 238 | 239 | def SAM(self, im, boxes, class_idx = None, metadata = None): 240 | self.sam_predictor.set_image(im) 241 | input_boxes = torch.tensor(boxes, device=self.sam_predictor.device) 242 | transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(input_boxes, im.shape[:2]) 243 | masks, _, _ = self.sam_predictor.predict_torch( 244 | point_coords=None, 245 | point_labels=None, 246 | boxes=transformed_boxes, 247 | multimask_output=False, 248 | ) 249 | return masks 250 | 251 | def visualize_detic(self, output): 252 | output_im = output.get_image()[:, :, ::-1] 253 | cv2.imshow("Detic Predictions", output_im) 254 | cv2.waitKey(0) 255 | cv2.destroyAllWindows() 256 | 257 | def predict(self, im): 258 | if im is None: 259 | print("Error: Unable to read the image file") 260 | 261 | H, W = im.shape[:2] 262 | 263 | # Run model and show results 264 | start_time = time.time() 265 | output = self.detic_predictor(im[:, :, ::-1]) # Detic expects BGR images. 
266 | print("Inference time: ", time.time() - start_time) 267 | v = Visualizer(im, self.metadata) 268 | out = v.draw_instance_predictions(output["instances"].to('cpu')) 269 | instances = output["instances"].to('cpu') 270 | boxes = instances.pred_boxes.tensor.numpy() 271 | # class_idx = instances.pred_classes.numpy() 272 | # class_name = [self.metadata.thing_classes[idx] for idx in class_idx] 273 | # clip_embeds = self.get_clip_embeddings(class_name) 274 | 275 | masks = None 276 | components = torch.zeros(H, W) 277 | if self.sam: 278 | if len(boxes) > 0: 279 | # Only run SAM if there are bboxes 280 | masks = self.SAM(im, boxes) 281 | for i in range(masks.shape[0]): 282 | if torch.sum(masks[i][0]) <= H*W/3.5: 283 | components[masks[i][0]] = i + 1 284 | else: 285 | masks = output['instances'].pred_masks.unsqueeze(1) 286 | for i in range(masks.shape[0]): 287 | if torch.sum(masks[i][0]) <= H*W/3.5: 288 | components[masks[i][0]] = i + 1 289 | bg_mask = (components == 0).to(self.device) 290 | 291 | # Filter out small masks 292 | filtered_idx = [] 293 | for i in range(len(masks)): 294 | if masks[i].sum(dim=(1,2)) <= H*W/3.5: 295 | filtered_idx.append(i) 296 | filtered_masks = torch.cat([masks[filtered_idx], bg_mask.unsqueeze(0).unsqueeze(0)], dim=0) 297 | 298 | # invert_masks = ~filtered_masks 299 | # # erode all masks using 3x3 kernel 300 | # eroded_masks = torch.conv2d( 301 | # invert_masks.float(), 302 | # torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"), 303 | # padding=1, 304 | # ) 305 | # filtered_masks = ~(eroded_masks >= 5).squeeze(1) # (num_masks, H, W) 306 | 307 | 308 | outputs = { 309 | "vis": out, 310 | "boxes": boxes, 311 | "masks": masks, 312 | "masks_filtered": filtered_masks, 313 | # "class_idx": class_idx, 314 | # "class_name": class_name, 315 | # "clip_embeds": clip_embeds, 316 | "components": components, 317 | "scores" : output["instances"].scores, 318 | } 319 | return outputs 320 | 321 | 322 | def show_mask(self, mask, ax, random_color=False): 323 | if random_color: 324 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 325 | else: 326 | color = np.array([30/255, 144/255, 255/255, 0.6]) 327 | h, w = mask.shape[-2:] 328 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 329 | ax.imshow(mask_image) 330 | 331 | 332 | def show_box(self, box, ax): 333 | x0, y0 = box[0], box[1] 334 | w, h = box[2] - box[0], box[3] - box[1] 335 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 336 | 337 | 338 | def visualize_output(self, im, masks, input_boxes, classes, image_save_path, mask_only=False): 339 | plt.figure(figsize=(10, 10)) 340 | plt.imshow(im) 341 | for mask in masks: 342 | self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True) 343 | if not mask_only: 344 | for box, class_name in zip(input_boxes, classes): 345 | self.show_box(box, plt.gca()) 346 | x, y = box[:2] 347 | plt.gca().text(x, y - 5, class_name, color='white', fontsize=12, fontweight='bold', bbox=dict(facecolor='green', edgecolor='green', alpha=0.5)) 348 | plt.axis('off') 349 | plt.savefig(image_save_path) 350 | #plt.show() 351 | 352 | 353 | def generate_colors(self, num_colors): 354 | hsv_colors = [] 355 | for i in range(num_colors): 356 | hue = i / float(num_colors) 357 | hsv_colors.append((hue, 1.0, 1.0)) 358 | 359 | return [mcolors.hsv_to_rgb(color) for color in hsv_colors] -------------------------------------------------------------------------------- /pogs/data/utils/dino_dataloader.py: 
-------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | from pogs.data.utils.dino_extractor import ViTExtractor 5 | from pogs.data.utils.feature_dataloader import FeatureDataloader 6 | from tqdm import tqdm 7 | from torchvision import transforms 8 | from typing import Tuple 9 | import numpy as np 10 | 11 | #usually 1260 max size, 1050 for vit-L with ROI 12 | MAX_DINO_SIZE = 1260 13 | def get_img_resolution(H, W, p=14): 14 | if H torch.Tensor: 102 | """ 103 | returns BxHxWxC 104 | """ 105 | return self.data[img_ind].to(self.device) -------------------------------------------------------------------------------- /pogs/data/utils/feature_dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import typing 4 | from abc import ABC, ABCMeta, abstractmethod 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class FeatureDataloader(ABC): 12 | def __init__( 13 | self, 14 | cfg: dict, 15 | device: torch.device, 16 | image_list: torch.Tensor, # (N, 3, H, W) 17 | cache_path: Path, 18 | ): 19 | self.cfg = cfg 20 | self.device = device 21 | self.cache_path = cache_path 22 | self.data = None # only expect data to be cached, nothing else 23 | self.try_load(image_list) # don't save image_list, avoid duplicates 24 | 25 | @abstractmethod 26 | def __call__(self, img_points): 27 | pass 28 | 29 | @abstractmethod 30 | def create(self, image_list: torch.Tensor): 31 | pass 32 | 33 | def load(self): 34 | cache_info_path = self.cache_path.with_suffix(".info") 35 | print(cache_info_path) 36 | import os 37 | print(os.getcwd()) 38 | if not cache_info_path.exists(): 39 | raise FileNotFoundError 40 | with open(cache_info_path, "r") as f: 41 | cfg = json.loads(f.read()) 42 | if cfg != self.cfg: 43 | raise ValueError("Config mismatch") 44 | self.data = torch.from_numpy(np.load(self.cache_path)).to(self.device) 45 | 46 | def save(self): 47 | os.makedirs(self.cache_path.parent, exist_ok=True) 48 | cache_info_path = self.cache_path.with_suffix(".info") 49 | with open(cache_info_path, "w") as f: 50 | f.write(json.dumps(self.cfg)) 51 | np.save(self.cache_path, self.data.numpy(force=True)) 52 | 53 | def try_load(self, img_list: torch.Tensor): 54 | try: 55 | self.load() 56 | except (FileNotFoundError, ValueError): 57 | self.create(img_list) 58 | self.save() -------------------------------------------------------------------------------- /pogs/data/utils/patch_embedding_dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import torch 5 | from pogs.data.utils.feature_dataloader import FeatureDataloader 6 | from pogs.encoders.image_encoder import BaseImageEncoder 7 | from tqdm import tqdm 8 | 9 | 10 | class PatchEmbeddingDataloader(FeatureDataloader): 11 | def __init__( 12 | self, 13 | cfg: dict, 14 | device: torch.device, 15 | model: BaseImageEncoder, 16 | image_list: torch.Tensor = None, 17 | cache_path: str = None, 18 | ): 19 | assert "tile_ratio" in cfg 20 | assert "stride_ratio" in cfg 21 | assert "image_shape" in cfg 22 | assert "model_name" in cfg 23 | 24 | self.kernel_size = int(cfg["image_shape"][0] * cfg["tile_ratio"]) 25 | self.stride = int(self.kernel_size * cfg["stride_ratio"]) 26 | self.padding = self.kernel_size // 2 27 | self.center_x = ( 28 | (self.kernel_size - 1) / 2 29 | - self.padding 30 | + self.stride 31 | * np.arange( 32 | 
np.floor((cfg["image_shape"][0] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1) 33 | ) 34 | ) 35 | self.center_y = ( 36 | (self.kernel_size - 1) / 2 37 | - self.padding 38 | + self.stride 39 | * np.arange( 40 | np.floor((cfg["image_shape"][1] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1) 41 | ) 42 | ) 43 | self.center_x = torch.from_numpy(self.center_x).half() 44 | self.center_y = torch.from_numpy(self.center_y).half() 45 | self.start_x = self.center_x[0].float() 46 | self.start_y = self.center_y[0].float() 47 | 48 | self.model = model 49 | self.embed_size = self.model.embedding_dim 50 | super().__init__(cfg, device, image_list, cache_path) 51 | 52 | def load(self): 53 | cache_info_path = self.cache_path.with_suffix(".info") 54 | if not cache_info_path.exists(): 55 | raise FileNotFoundError 56 | with open(cache_info_path, "r") as f: 57 | cfg = json.loads(f.read()) 58 | if cfg != self.cfg: 59 | raise ValueError("Config mismatch") 60 | self.data = torch.from_numpy(np.load(self.cache_path)).half().to(self.device) 61 | 62 | def create(self, image_list): 63 | assert self.model is not None, "model must be provided to generate features" 64 | assert image_list is not None, "image_list must be provided to generate features" 65 | 66 | unfold_func = torch.nn.Unfold( 67 | kernel_size=self.kernel_size, 68 | stride=self.stride, 69 | padding=self.padding, 70 | ).to(self.device) 71 | 72 | img_embeds = [] 73 | for img in tqdm(image_list, desc="CLIP Embedding Images", leave=False): 74 | img_embeds.append(self._embed_clip_tiles(img.unsqueeze(0), unfold_func)) 75 | self.data = torch.from_numpy(np.stack(img_embeds)).half().to(self.device) 76 | 77 | def __call__(self, img_points): 78 | # img_points: (B, 3) # (img_ind, x, y) (img_ind, row, col) 79 | # return: (B, 512) 80 | img_points = img_points.cpu() 81 | img_ind, img_points_x, img_points_y = img_points[:, 0], img_points[:, 1], img_points[:, 2] 82 | 83 | x_ind = torch.floor((img_points_x - (self.start_x)) / self.stride).long() 84 | y_ind = torch.floor((img_points_y - (self.start_y)) / self.stride).long() 85 | return self._interp_inds(img_ind, x_ind, y_ind, img_points_x, img_points_y) 86 | 87 | def _interp_inds(self, img_ind, x_ind, y_ind, img_points_x, img_points_y): 88 | img_ind = img_ind.to(self.data.device) # self.data is on cpu to save gpu memory, hence this line 89 | topleft = self.data[img_ind, x_ind, y_ind].to(self.device) 90 | topright = self.data[img_ind, x_ind + 1, y_ind].to(self.device) 91 | botleft = self.data[img_ind, x_ind, y_ind + 1].to(self.device) 92 | botright = self.data[img_ind, x_ind + 1, y_ind + 1].to(self.device) 93 | 94 | x_stride = self.stride 95 | y_stride = self.stride 96 | right_w = ((img_points_x - (self.center_x[x_ind])) / x_stride).to(self.device) # .half() 97 | top = torch.lerp(topleft, topright, right_w[:, None]) 98 | bot = torch.lerp(botleft, botright, right_w[:, None]) 99 | 100 | bot_w = ((img_points_y - (self.center_y[y_ind])) / y_stride).to(self.device) # .half() 101 | return torch.lerp(top, bot, bot_w[:, None]) 102 | 103 | def _embed_clip_tiles(self, image, unfold_func): 104 | # image augmentation: slow-ish (0.02s for 600x800 image per augmentation) 105 | aug_imgs = torch.cat([image]) 106 | 107 | tiles = unfold_func(aug_imgs).permute(2, 0, 1).reshape(-1, 3, self.kernel_size, self.kernel_size).to("cuda") 108 | 109 | with torch.no_grad(): 110 | clip_embeds = self.model.encode_image(tiles) 111 | clip_embeds /= clip_embeds.norm(dim=-1, keepdim=True) 112 | 113 | clip_embeds = 
clip_embeds.reshape((self.center_x.shape[0], self.center_y.shape[0], -1)) 114 | clip_embeds = torch.concat((clip_embeds, clip_embeds[:, [-1], :]), dim=1) 115 | clip_embeds = torch.concat((clip_embeds, clip_embeds[[-1], :, :]), dim=0) 116 | return clip_embeds.detach().cpu().numpy() -------------------------------------------------------------------------------- /pogs/data/utils/pyramid_embedding_dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from pogs.data.utils.feature_dataloader import FeatureDataloader 8 | from pogs.data.utils.patch_embedding_dataloader import PatchEmbeddingDataloader 9 | from pogs.encoders.image_encoder import BaseImageEncoder 10 | from tqdm import tqdm 11 | 12 | 13 | class PyramidEmbeddingDataloader(FeatureDataloader): 14 | def __init__( 15 | self, 16 | cfg: dict, 17 | device: torch.device, 18 | model: BaseImageEncoder, 19 | image_list: torch.Tensor = None, 20 | cache_path: str = None, 21 | ): 22 | assert "tile_size_range" in cfg 23 | assert "tile_size_res" in cfg 24 | assert "stride_scaler" in cfg 25 | assert "image_shape" in cfg 26 | assert "model_name" in cfg 27 | 28 | self.tile_sizes = torch.linspace(*cfg["tile_size_range"], cfg["tile_size_res"]).to(device) 29 | self.strider_scaler_list = [self._stride_scaler(tr.item(), cfg["stride_scaler"]) for tr in self.tile_sizes] 30 | 31 | self.model = model 32 | self.embed_size = self.model.embedding_dim 33 | self.data_dict = {} 34 | super().__init__(cfg, device, image_list, cache_path) 35 | 36 | def __call__(self, img_points, scale=None): 37 | if scale is None: 38 | return self._random_scales(img_points) 39 | else: 40 | return self._uniform_scales(img_points, scale) 41 | 42 | def _stride_scaler(self, tile_ratio, stride_scaler): 43 | return np.interp(tile_ratio, [0.05, 0.15], [1.0, stride_scaler]) 44 | 45 | def load(self): 46 | # don't create anything, PatchEmbeddingDataloader will create itself 47 | cache_info_path = self.cache_path.with_suffix(".info") 48 | 49 | # check if cache exists 50 | if not cache_info_path.exists(): 51 | raise FileNotFoundError 52 | 53 | # if config is different, remove all cached content 54 | with open(cache_info_path, "r") as f: 55 | cfg = json.loads(f.read()) 56 | if cfg != self.cfg: 57 | for f in os.listdir(self.cache_path): 58 | os.remove(os.path.join(self.cache_path, f)) 59 | raise ValueError("Config mismatch") 60 | 61 | raise FileNotFoundError # trigger create 62 | 63 | def create(self, image_list): 64 | os.makedirs(self.cache_path, exist_ok=True) 65 | for i, tr in enumerate(tqdm(self.tile_sizes, desc="Scales")): 66 | stride_scaler = self.strider_scaler_list[i] 67 | self.data_dict[i] = PatchEmbeddingDataloader( 68 | cfg={ 69 | "tile_ratio": tr.item(), 70 | "stride_ratio": stride_scaler, 71 | "image_shape": self.cfg["image_shape"], 72 | "model_name": self.cfg["model_name"], 73 | }, 74 | device=self.device, 75 | model=self.model, 76 | image_list=image_list, 77 | cache_path=Path(f"{self.cache_path}/level_{i}.npy"), 78 | ) 79 | print(image_list.shape) 80 | 81 | def save(self): 82 | cache_info_path = self.cache_path.with_suffix(".info") 83 | with open(cache_info_path, "w") as f: 84 | f.write(json.dumps(self.cfg)) 85 | # don't save anything, PatchEmbeddingDataloader will save itself 86 | pass 87 | 88 | def _random_scales(self, img_points): 89 | # img_points: (B, 3) # (img_ind, x, y) 90 | # return: (B, 512), some random scale (between 0, 1) 91 | 
img_points = img_points.to(self.device) 92 | random_scale_bin = torch.randint(self.tile_sizes.shape[0] - 1, size=(img_points.shape[0],), device=self.device) 93 | random_scale_weight = torch.rand(img_points.shape[0], dtype=torch.float16, device=self.device) 94 | 95 | stepsize = (self.tile_sizes[1] - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0]) 96 | 97 | bottom_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device) 98 | top_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device) 99 | 100 | for i in range(len(self.tile_sizes) - 1): 101 | ids = img_points[random_scale_bin == i] 102 | bottom_interp[random_scale_bin == i] = self.data_dict[i](ids) 103 | top_interp[random_scale_bin == i] = self.data_dict[i + 1](ids) 104 | 105 | return ( 106 | torch.lerp(bottom_interp, top_interp, random_scale_weight[..., None]), 107 | (random_scale_bin * stepsize + random_scale_weight * stepsize)[..., None], 108 | ) 109 | 110 | def _uniform_scales(self, img_points, scale): 111 | # img_points: (B, 3) # (img_ind, x, y) 112 | scale_bin = torch.floor( 113 | (scale - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0]) * (self.tile_sizes.shape[0] - 1) 114 | ).to(torch.int64) 115 | scale_weight = (scale - self.tile_sizes[scale_bin]) / ( 116 | self.tile_sizes[scale_bin + 1] - self.tile_sizes[scale_bin] 117 | ) 118 | interp_lst = torch.stack([interp(img_points) for interp in self.data_dict.values()]) 119 | point_inds = torch.arange(img_points.shape[0]) 120 | interp = torch.lerp( 121 | interp_lst[scale_bin, point_inds], 122 | interp_lst[scale_bin + 1, point_inds], 123 | torch.Tensor([scale_weight]).half().to(self.device)[..., None], 124 | ) 125 | return interp / interp.norm(dim=-1, keepdim=True), scale -------------------------------------------------------------------------------- /pogs/encoders/image_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, abstractproperty 2 | from dataclasses import dataclass, field 3 | from typing import Type 4 | 5 | import torch 6 | from torch import nn 7 | from pogs.configs import base_config as cfg 8 | 9 | @dataclass 10 | class BaseImageEncoderConfig(cfg.InstantiateConfig): 11 | _target: Type = field(default_factory=lambda: BaseImageEncoder) 12 | 13 | 14 | class BaseImageEncoder(nn.Module): 15 | @abstractproperty 16 | def name(self) -> str: 17 | """ 18 | returns the name of the encoder 19 | """ 20 | 21 | @abstractproperty 22 | def embedding_dim(self) -> int: 23 | """ 24 | returns the dimension of the embeddings 25 | """ 26 | 27 | @abstractmethod 28 | def encode_image(self, input: torch.Tensor) -> torch.Tensor: 29 | """ 30 | Given a batch of input images, return their encodings 31 | """ 32 | 33 | @abstractmethod 34 | def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor: 35 | """ 36 | Given a batch of embeddings, return the relevancy to the given positive id 37 | """ -------------------------------------------------------------------------------- /pogs/encoders/openclip_encoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Tuple, Type 3 | 4 | import torch 5 | import torchvision 6 | 7 | try: 8 | import open_clip 9 | except ImportError: 10 | assert False, "open_clip is not installed, install it with `pip install open-clip-torch`" 11 | 12 | from 
pogs.encoders.image_encoder import (BaseImageEncoder, 13 | BaseImageEncoderConfig) 14 | from nerfstudio.viewer.viewer_elements import ViewerText 15 | 16 | 17 | @dataclass 18 | class OpenCLIPNetworkConfig(BaseImageEncoderConfig): 19 | _target: Type = field(default_factory=lambda: OpenCLIPNetwork) 20 | clip_model_type: str = "ViT-B-16" 21 | clip_model_pretrained: str = "laion2b_s34b_b88k" 22 | clip_n_dims: int = 512 23 | negatives: Tuple[str] = ("object", "things", "stuff", "texture") 24 | device: str = 'cuda:0' 25 | 26 | @property 27 | def name(self) -> str: 28 | return "openclip_{}_{}".format(self.clip_model_type, self.clip_model_pretrained) 29 | 30 | 31 | class OpenCLIPNetwork(BaseImageEncoder): 32 | def __init__(self, config: OpenCLIPNetworkConfig): 33 | super().__init__() 34 | self.config = config 35 | self.process = torchvision.transforms.Compose( 36 | [ 37 | torchvision.transforms.Resize((224, 224)), 38 | torchvision.transforms.Normalize( 39 | mean=[0.48145466, 0.4578275, 0.40821073], 40 | std=[0.26862954, 0.26130258, 0.27577711], 41 | ), 42 | ] 43 | ) 44 | model, _, _ = open_clip.create_model_and_transforms( 45 | self.config.clip_model_type, # e.g., ViT-B-16 46 | pretrained=self.config.clip_model_pretrained, # e.g., laion2b_s34b_b88k 47 | precision="fp16", 48 | device=self.config.device, 49 | # device='cuda:1', 50 | ) 51 | # model.to('cuda:1') 52 | model.eval() 53 | self.tokenizer = open_clip.get_tokenizer(self.config.clip_model_type) 54 | self.model = model.to(self.config.device) 55 | self.clip_n_dims = self.config.clip_n_dims 56 | 57 | self.positive_input = ViewerText("Positives", "", cb_hook=self.gui_cb) 58 | 59 | self.positives = self.positive_input.value.split(";") 60 | self.negatives = self.config.negatives 61 | with torch.no_grad(): 62 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.config.device) 63 | self.pos_embeds = model.encode_text(tok_phrases) 64 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(self.config.device) 65 | self.neg_embeds = model.encode_text(tok_phrases) 66 | self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True) 67 | self.neg_embeds /= self.neg_embeds.norm(dim=-1, keepdim=True) 68 | 69 | assert ( 70 | self.pos_embeds.shape[1] == self.neg_embeds.shape[1] 71 | ), "Positive and negative embeddings must have the same dimensionality" 72 | assert ( 73 | self.pos_embeds.shape[1] == self.clip_n_dims 74 | ), "Embedding dimensionality must match the model dimensionality" 75 | 76 | @property 77 | def name(self) -> str: 78 | return "openclip_{}_{}".format(self.config.clip_model_type, self.config.clip_model_pretrained) 79 | 80 | @property 81 | def embedding_dim(self) -> int: 82 | return self.config.clip_n_dims 83 | 84 | def gui_cb(self,element): 85 | # element = element.to(self.config.device) 86 | self.set_positives(element.value.split(";")) 87 | 88 | def set_positives(self, text_list): 89 | self.positives = text_list 90 | with torch.no_grad(): 91 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.config.device) 92 | self.pos_embeds = self.model.encode_text(tok_phrases) 93 | self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True) 94 | self.pos_embeds = self.pos_embeds.to(self.config.device) 95 | 96 | def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor: 97 | phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0).to(self.config.device) 98 | p = phrases_embeds.to(embed.dtype) # phrases x 512 99 | embed = 
embed.to(p.device) 100 | output = torch.mm(embed, p.T) # rays x phrases 101 | positive_vals = output[..., positive_id : positive_id + 1] # rays x 1 102 | negative_vals = output[..., len(self.positives) :] # rays x N_phrase 103 | repeated_pos = positive_vals.repeat(1, len(self.negatives)) # rays x N_phrase 104 | 105 | sims = torch.stack((repeated_pos, negative_vals), dim=-1) # rays x N-phrase x 2 106 | softmax = torch.softmax(10 * sims, dim=-1) # rays x n-phrase x 2 107 | best_id = softmax[..., 0].argmin(dim=1) # rays x 2 108 | return torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.negatives), 2))[ 109 | :, 0, : 110 | ] 111 | 112 | def encode_image(self, input): 113 | processed_input = self.process(input).half() 114 | return self.model.encode_image(processed_input) -------------------------------------------------------------------------------- /pogs/field_components/gaussian_fieldheadnames.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | # from nerfstudio.field_components.field_heads import FieldHeadNames 3 | 4 | class GaussianFieldHeadNames(Enum): 5 | """Possible field outputs""" 6 | HASHGRID = "hashgrid" 7 | CLIP = "clip" 8 | # DINO = "dino" 9 | INSTANCE = "instance" -------------------------------------------------------------------------------- /pogs/fields/gaussian_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Field for compound nerf model, adds scene contraction and image embeddings to instant ngp 17 | """ 18 | 19 | 20 | from typing import Dict, Literal, Optional, Tuple 21 | import numpy as np 22 | 23 | import torch 24 | from torch import Tensor, nn 25 | 26 | from nerfstudio.cameras.rays import RaySamples 27 | from nerfstudio.data.scene_box import SceneBox 28 | from nerfstudio.field_components.activations import trunc_exp 29 | from nerfstudio.field_components.embedding import Embedding 30 | from nerfstudio.field_components.encodings import HashEncoding, NeRFEncoding, SHEncoding 31 | from pogs.field_components.gaussian_fieldheadnames import GaussianFieldHeadNames 32 | 33 | from nerfstudio.field_components.mlp import MLP 34 | from nerfstudio.field_components.spatial_distortions import SpatialDistortion, SceneContraction 35 | from nerfstudio.fields.base_field import Field, get_normalized_directions 36 | 37 | try: 38 | import tinycudann as tcnn 39 | except ImportError: 40 | pass 41 | 42 | 43 | class GaussianField(Field): 44 | """Compound Field that uses TCNN 45 | 46 | Args: 47 | aabb: parameters of scene aabb bounds 48 | num_images: number of images in the dataset 49 | num_layers: number of hidden layers 50 | hidden_dim: dimension of hidden layers 51 | geo_feat_dim: output geo feat dimensions 52 | num_levels: number of levels of the hashmap for the base mlp 53 | base_res: base resolution of the hashmap for the base mlp 54 | max_res: maximum resolution of the hashmap for the base mlp 55 | log2_hashmap_size: size of the hashmap for the base mlp 56 | num_layers_color: number of hidden layers for color network 57 | num_layers_transient: number of hidden layers for transient network 58 | features_per_level: number of features per level for the hashgrid 59 | hidden_dim_color: dimension of hidden layers for color network 60 | hidden_dim_transient: dimension of hidden layers for transient network 61 | appearance_embedding_dim: dimension of appearance embedding 62 | transient_embedding_dim: dimension of transient embedding 63 | use_transient_embedding: whether to use transient embedding 64 | use_semantics: whether to use semantic segmentation 65 | num_semantic_classes: number of semantic classes 66 | use_pred_normals: whether to use predicted normals 67 | use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference 68 | spatial_distortion: spatial distortion to apply to the scene 69 | """ 70 | 71 | def __init__( 72 | self, 73 | num_levels: int = 16, 74 | base_res: int = 16, 75 | max_res: int = 2048, 76 | log2_hashmap_size: int = 19, 77 | features_per_level: int = 2, 78 | implementation: Literal["tcnn", "torch"] = "tcnn", 79 | grid_layers: Tuple[int] = (12, 12), 80 | grid_sizes: Tuple[Tuple[int]] = (19, 19), 81 | grid_resolutions: Tuple[int] = ((16, 128), (128, 512)), 82 | n_features_level: int = 4, 83 | clip_n_dims: int = 512, 84 | 85 | feature_dims: int = 64, 86 | ) -> None: 87 | super().__init__() 88 | 89 | self.spatial_distortion: SceneContraction = SceneContraction() 90 | 91 | self.register_buffer("max_res", torch.tensor(max_res)) 92 | self.register_buffer("num_levels", torch.tensor(num_levels)) 93 | self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size)) 94 | self.clip_encs = torch.nn.ModuleList( 95 | [ 96 | GaussianField._get_encoding( 97 | grid_resolutions[i][0], grid_resolutions[i][1], grid_layers[i], indim=3, hash_size=grid_sizes[i], features_per_level=n_features_level, 98 | ) for i in range(len(grid_layers)) 99 | ] 100 | 
) 101 | tot_out_dims = sum([e.n_output_dims for e in self.clip_encs]) 102 | instance_n_dims = 128 103 | print("Total output dims: ", tot_out_dims) 104 | 105 | self.clip_net = tcnn.Network( 106 | n_input_dims=tot_out_dims+1, 107 | n_output_dims=clip_n_dims, 108 | network_config={ 109 | "otype": "CutlassMLP", 110 | "activation": "ReLU", 111 | "output_activation": "None", 112 | "n_neurons": 256, 113 | "n_hidden_layers": 3, 114 | }, 115 | ) 116 | 117 | self.instance_net = tcnn.Network( 118 | n_input_dims=tot_out_dims, 119 | n_output_dims=instance_n_dims, 120 | network_config={ 121 | "otype": "CutlassMLP", 122 | "activation": "ReLU", 123 | "output_activation": "None", 124 | "n_neurons": 256, 125 | "n_hidden_layers": 4, 126 | }, 127 | ) 128 | 129 | 130 | @staticmethod 131 | def _get_encoding(start_res, end_res, levels, indim=3, hash_size=19, features_per_level=8): 132 | growth = np.exp((np.log(end_res) - np.log(start_res)) / (levels - 1)) 133 | enc = tcnn.Encoding( 134 | n_input_dims=indim, 135 | encoding_config={ 136 | "otype": "HashGrid", 137 | "n_levels": levels, 138 | "n_features_per_level": features_per_level, 139 | "log2_hashmap_size": hash_size, 140 | "base_resolution": start_res, 141 | "per_level_scale": growth, 142 | }, 143 | ) 144 | return enc 145 | 146 | def get_outputs(self, positions, clip_scales) -> Dict[GaussianFieldHeadNames, Tensor]: 147 | # random scales, one scale 148 | positions = self.spatial_distortion(positions) 149 | positions = (positions + 2.0) / 4.0 150 | 151 | outputs = {} 152 | xs = [e(positions.view(-1, 3)) for e in self.clip_encs] 153 | x = torch.concat(xs, dim=-1) 154 | 155 | outputs[GaussianFieldHeadNames.HASHGRID] = x.view(positions.shape[0], -1) 156 | 157 | clip_pass = self.clip_net(torch.cat([x, clip_scales.view(-1, 1)], dim=-1)).view(positions.shape[0], -1) 158 | 159 | outputs[GaussianFieldHeadNames.CLIP] = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32) 160 | 161 | epsilon = 1e-5 162 | instance_pass = self.instance_net(x).view(positions.shape[0], -1) 163 | outputs[GaussianFieldHeadNames.INSTANCE] = instance_pass / (instance_pass.norm(dim=-1, keepdim=True) + epsilon) 164 | 165 | return outputs 166 | 167 | def get_hash(self, positions) -> Tensor: 168 | positions = self.spatial_distortion(positions) 169 | positions = (positions + 2.0) / 4.0 170 | 171 | encodings = [e(positions.view(-1, 3)) for e in self.clip_encs] 172 | encoding = torch.concat(encodings, dim=-1) 173 | 174 | return encoding.to(torch.float32) 175 | 176 | def get_outputs_from_feature(self, features, clip_scale, random_pixels = None) -> Dict[GaussianFieldHeadNames, Tensor]: 177 | outputs = {} 178 | 179 | #clip_features is Nx32, and clip scale is a number, I want to cat clip scale to the end of clip_features where clip scale is an int 180 | if random_pixels is not None: 181 | clip_features = features[random_pixels] 182 | else: 183 | clip_features = features 184 | clip_pass = self.clip_net(torch.cat([clip_features, clip_scale.view(-1, 1)], dim=-1)) 185 | 186 | outputs[GaussianFieldHeadNames.CLIP] = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32) 187 | 188 | epsilon = 1e-5 189 | instance_pass = self.instance_net(features) 190 | outputs[GaussianFieldHeadNames.INSTANCE] = instance_pass / (instance_pass.norm(dim=-1, keepdim=True) + epsilon) 191 | 192 | return outputs 193 | 194 | def get_instance_outputs_from_feature(self, features) -> Tensor: 195 | epsilon = 1e-5 196 | instance_pass = self.instance_net(features) 197 | outputs = instance_pass / 
(instance_pass.norm(dim=-1, keepdim=True) + epsilon) 198 | 199 | return outputs 200 | 201 | def get_clip_outputs_from_feature(self, features, clip_scale) -> Tensor: 202 | clip_pass = self.clip_net(torch.cat([features, clip_scale.view(-1, 1)], dim=-1)) 203 | outputs = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32) 204 | return outputs -------------------------------------------------------------------------------- /pogs/grasping/generate_grasps_ply.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | dir_path = os.path.dirname(os.path.realpath(__file__)) 5 | contact_graspnet_path = os.path.join(dir_path,'../contact_graspnet_wrapper') 6 | 7 | sys.path.append(contact_graspnet_path) 8 | from prime_inference import modified_inference 9 | import argparse 10 | from prime_config_utils import load_config 11 | import numpy as np 12 | from prime_visualization_utils import visualize_grasps 13 | from autolab_core import RigidTransform 14 | import open3d as o3d 15 | tool_to_wrist = RigidTransform() 16 | # 0.1651 was old measurement is the measure dist from suction to 17 | # 0.1857375 Parallel Jaw gripper 18 | tool_to_wrist.translation = np.array([0, 0, 0]) 19 | tool_to_wrist.from_frame = "tool" 20 | tool_to_wrist.to_frame = "wrist" 21 | 22 | segmented_ply_filepath = "/home/lifelong/sms/sms/data/utils/Detic/outputs/2024_07_22_green_tape_bowl/prime_seg_gaussians.ply" 23 | full_ply_filepath = "/home/lifelong/sms/sms/data/utils/Detic/outputs/2024_07_22_green_tape_bowl/prime_full_gaussians.ply" 24 | bounding_box_filepath = "/home/lifelong/sms/sms/data/utils/Detic/2024_07_22_green_tape_bowl/table_bounding_cube.json" 25 | 26 | 27 | def filter_noise(points, colors=None): 28 | from sklearn.cluster import DBSCAN 29 | eps = 0.005 30 | min_samples = 10 31 | dbscan = DBSCAN(eps=eps, min_samples=min_samples) 32 | labels = dbscan.fit_predict(points) 33 | filtered_pointcloud = points[labels != -1] 34 | if colors is not None: 35 | filtered_colors = colors[labels != -1] 36 | else: 37 | filtered_colors = None 38 | return filtered_pointcloud, filtered_colors 39 | 40 | def generate_grasps(seg_np_path, full_np_path, pc_bounding_box_path, ckpt_dir, z_range, K, local_regions, filter_grasps, skip_border_objects, forward_passes, segmap_id, arg_configs, save_dir): 41 | 42 | global_config = load_config(ckpt_dir, batch_size=forward_passes, arg_configs=arg_configs) 43 | 44 | print(str(global_config)) 45 | print('pid: %s'%(str(os.getpid()))) 46 | 47 | pred_grasps_world, scores, contact_pts, points_world, pc_colors = modified_inference(global_config, ckpt_dir, seg_np_path, full_np_path,pc_bounding_box_path, z_range=z_range, 48 | K=K, local_regions=local_regions, filter_grasps=filter_grasps, segmap_id=segmap_id, 49 | forward_passes=forward_passes, skip_border_objects=skip_border_objects,debug=False) 50 | print("GENERATED GRASPS") 51 | sorted_idxs = np.argsort(scores[0])[::-1] 52 | best_scores = {0:scores[0][sorted_idxs][:1]} 53 | best_grasps = {0:pred_grasps_world[0][sorted_idxs][:1]} 54 | best_contact_pts = {0:contact_pts[0][sorted_idxs][:1]} 55 | 56 | point_cloud_world = o3d.geometry.PointCloud() 57 | point_cloud_world.points = o3d.utility.Vector3dVector(points_world) 58 | point_cloud_world.colors = o3d.utility.Vector3dVector(pc_colors) 59 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 60 | 61 | final_grasp_world_frame = best_grasps[0][0] 62 | grasp_point_world = 
o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 63 | grasp_point_world.transform(final_grasp_world_frame) 64 | pre_grasp_tf = np.array([[1,0,0,0], 65 | [0,1,0,0], 66 | [0,0,1,-0.1], 67 | [0,0,0,1]]) 68 | 69 | pre_grasp_world_frame = final_grasp_world_frame @ pre_grasp_tf 70 | 71 | pre_grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 72 | 73 | pre_grasp_point_world.transform(pre_grasp_world_frame) 74 | 75 | #o3d.visualization.draw_geometries([point_cloud_world,coordinate_frame,grasp_point_world,pre_grasp_point_world]) 76 | np.save(f'{FLAGS.save_dir}/pred_grasps_world.npy', pred_grasps_world[0]) 77 | 78 | np.save(f'{FLAGS.save_dir}/scores.npy', scores[0]) 79 | 80 | np.save(f'{FLAGS.save_dir}/contact_pts.npy', contact_pts[0]) 81 | 82 | np.save(f'{FLAGS.save_dir}/point_cloud_world.npy', points_world) 83 | np.save(f'{FLAGS.save_dir}/rgb_cloud_world.npy', pc_colors) 84 | np.save(f'{FLAGS.save_dir}/grasp_point_world.npy', final_grasp_world_frame) 85 | np.save(f'{FLAGS.save_dir}/pre_grasp_point_world.npy', pre_grasp_world_frame) 86 | 87 | return pred_grasps_world, scores[0] 88 | 89 | import pdb 90 | 91 | pdb.set_trace() 92 | world_to_cam_tf = np.array([[0,-1,0,0], 93 | [-1,0,0,0], 94 | [0,0,-1,0], 95 | [0,0,0,1]]) 96 | 97 | 98 | #visualize_grasps(pc_full, best_grasps, best_scores, plot_opencv_cam=True, pc_colors=pc_colors) 99 | # Create an Open3D point cloud object 100 | point_cloud_cam = o3d.geometry.PointCloud() 101 | 102 | # Set the points and colors 103 | point_cloud_cam.points = o3d.utility.Vector3dVector(pc_full) 104 | point_cloud_cam.colors = o3d.utility.Vector3dVector(pc_colors) 105 | 106 | # Step 2: Visualize the point cloud 107 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 108 | grasp_point = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 109 | grasp_point.transform(best_grasps[0][0]) 110 | # o3d.visualization.draw_geometries([point_cloud_cam,coordinate_frame,grasp_point]) 111 | 112 | ones = np.ones((pc_full.shape[0],1)) 113 | world_to_cam_tf = RigidTransform.load('/home/lifelong/sms/sms/ur5_interface/ur5_interface/calibration_outputs/world_to_extrinsic_zed_for_grasping.tf').matrix 114 | homogenous_points_cam = np.hstack((pc_full,ones)) 115 | homogenous_points_world = world_to_cam_tf @ homogenous_points_cam.T 116 | points_world = homogenous_points_world[:3,:] / homogenous_points_world[3,:][np.newaxis,:] 117 | points_world = points_world.T 118 | 119 | point_cloud_world = o3d.geometry.PointCloud() 120 | 121 | # Set the points and colors 122 | point_cloud_world.points = o3d.utility.Vector3dVector(points_world) 123 | point_cloud_world.colors = o3d.utility.Vector3dVector(pc_colors) 124 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 125 | panda_grasp_point_to_robotiq_grasp_point = np.array([[1,0,0,0],[0,1,0,0],[0,0,1,-0.03],[0,0,0,1]]) # -0.06 126 | final_grasp_world_frame = world_to_cam_tf @ best_grasps[0][0] @ panda_grasp_point_to_robotiq_grasp_point 127 | grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 128 | grasp_point_world.transform(final_grasp_world_frame) 129 | pre_grasp_tf = np.array([[1,0,0,0], 130 | [0,1,0,0], 131 | [0,0,1,-0.1], 132 | [0,0,0,1]]) 133 | pre_grasp_world_frame = final_grasp_world_frame @ pre_grasp_tf 134 | pre_grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 
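    # Descriptive note: pre_grasp_tf above offsets the grasp by -0.1 m along the
    # grasp frame's local z (approach) axis, so pre_grasp_world_frame is a standoff
    # pose 10 cm behind the final grasp from which the gripper can move in along a
    # straight line.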
135 | pre_grasp_point_world.transform(pre_grasp_world_frame) 136 | # o3d.visualization.draw_geometries([point_cloud_world,coordinate_frame,grasp_point_world,pre_grasp_point_world]) 137 | pred_grasps_world = [] 138 | for i in range(len(pred_grasps_cam[0])): 139 | grasp = world_to_cam_tf @ pred_grasps_cam[0][i] @ panda_grasp_point_to_robotiq_grasp_point 140 | pred_grasps_world.append(grasp) 141 | pred_grasps_world = np.array(pred_grasps_world) 142 | np.save(f'{FLAGS.save_dir}/pred_grasps_world.npy', pred_grasps_world) 143 | np.save(f'{FLAGS.save_dir}/scores.npy', scores[0]) 144 | np.save(f'{FLAGS.save_dir}/contact_pts.npy', contact_pts[0]) 145 | 146 | return pred_grasps_world, scores[0] 147 | 148 | if __name__ == "__main__": 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--seg_np_path', default='',required=True) 151 | parser.add_argument('--full_np_path', default='',required=True) 152 | parser.add_argument('--save_dir', default='',required=True) 153 | checkpoint_dir = os.path.join(dir_path,'../dependencies/contact_graspnet/checkpoints/scene_test_2048_bs3_hor_sigma_001') 154 | parser.add_argument('--ckpt_dir', default=checkpoint_dir, help='Log dir [default: checkpoints/scene_test_2048_bs3_hor_sigma_001]') 155 | parser.add_argument('--pc_bounding_box_path', default='', help='Input data: npz/npy file with keys either "depth" & camera matrix "K" or just point cloud "pc" in meters. Optionally, a 2D "segmap"',required=True) 156 | parser.add_argument('--K', default=None, help='Flat Camera Matrix, pass as "[fx, 0, cx, 0, fy, cy, 0, 0 ,1]"') 157 | parser.add_argument('--z_range', default=None, help='Z value threshold to crop the input point cloud') 158 | parser.add_argument('--local_regions', action='store_true', default=False, help='Crop 3D local regions around given segments.') 159 | parser.add_argument('--filter_grasps', action='store_true', default=True, help='Filter grasp contacts according to segmap.') 160 | parser.add_argument('--skip_border_objects', action='store_true', default=False, help='When extracting local_regions, ignore segments at depth map boundary.') 161 | parser.add_argument('--forward_passes', type=int, default=1, help='Run multiple parallel forward passes to mesh_utils more potential contact points.') 162 | parser.add_argument('--segmap_id', type=int, default=0, help='Only return grasps of the given object id') 163 | parser.add_argument('--arg_configs', nargs="*", type=str, default=[], help='overwrite config parameters') 164 | FLAGS = parser.parse_args() 165 | generate_grasps(FLAGS.seg_np_path, FLAGS.full_np_path, FLAGS.pc_bounding_box_path, FLAGS.ckpt_dir, FLAGS.z_range, FLAGS.K, FLAGS.local_regions, 166 | FLAGS.filter_grasps, FLAGS.skip_border_objects, FLAGS.forward_passes, FLAGS.segmap_id, FLAGS.arg_configs, FLAGS.save_dir) -------------------------------------------------------------------------------- /pogs/grasping/results/predictions_global.ply.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/grasping/results/predictions_global.ply.npz -------------------------------------------------------------------------------- /pogs/pogs_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | pogs configuration file. 
3 | """ 4 | 5 | from nerfstudio.configs.base_config import ViewerConfig 6 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 7 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 8 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 9 | from nerfstudio.plugins.types import MethodSpecification 10 | from nerfstudio.engine.trainer import TrainerConfig as TrainerConfigBase 11 | from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig 12 | from pogs.pogs import POGSModelConfig 13 | 14 | 15 | from pogs.pogs_pipeline import POGSPipelineConfig 16 | from pogs.data.full_images_datamanager import FullImageDatamanagerConfig, FullImageDatamanager 17 | from pogs.data.depth_dataset import DepthDataset 18 | 19 | pogs_method = MethodSpecification( 20 | config = TrainerConfigBase( 21 | method_name="pogs", 22 | steps_per_eval_image=100, 23 | steps_per_eval_batch=100, 24 | steps_per_save=1000, 25 | max_num_iterations=4000, 26 | mixed_precision=False, 27 | gradient_accumulation_steps = {'camera_opt': 100,'color':10,'shs':10, 'lerf': 3}, 28 | pipeline=POGSPipelineConfig( 29 | datamanager=FullImageDatamanagerConfig( 30 | _target=FullImageDatamanager[DepthDataset], # Comment out the [DepthDataset] part to use RGB only datasets (e.g. polycam datasets) 31 | dataparser=NerfstudioDataParserConfig(load_3D_points=True, orientation_method='none', center_method='none', auto_scale_poses=False, depth_unit_scale_factor=1.0), 32 | network=OpenCLIPNetworkConfig( 33 | clip_model_type="ViT-B-16", clip_model_pretrained="laion2b_s34b_b88k", clip_n_dims=512, device='cuda:0' 34 | ), 35 | ), 36 | model=POGSModelConfig(), 37 | ), 38 | optimizers={ 39 | "means": { 40 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 41 | "scheduler": ExponentialDecaySchedulerConfig( 42 | lr_final=1.6e-6, 43 | max_steps=30000, 44 | ), 45 | }, 46 | "features_dc": { 47 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 48 | "scheduler": None, 49 | }, 50 | "features_rest": { 51 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 52 | "scheduler": None, 53 | }, 54 | "opacities": { 55 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 56 | "scheduler": None, 57 | }, 58 | "scales": { 59 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 60 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=30000) 61 | }, 62 | "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, 63 | "camera_opt": { 64 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 65 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000), 66 | }, 67 | "lerf": { 68 | "optimizer": AdamOptimizerConfig(lr=2.5e-3, eps=1e-15), 69 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=15000), 70 | }, 71 | "dino_feats": { 72 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 73 | "scheduler": ExponentialDecaySchedulerConfig( 74 | lr_final=1e-3, 75 | max_steps=6000, 76 | ), 77 | }, 78 | "nn_projection": { 79 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 80 | "scheduler": ExponentialDecaySchedulerConfig( 81 | lr_final=1e-3, 82 | max_steps=6000, 83 | ), 84 | }, 85 | }, 86 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 87 | vis="viewer", 88 | ), 89 | description="Persistent Object Gaussian Splatting", 90 | ) -------------------------------------------------------------------------------- /pogs/scripts/track_main_demo.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import viser 3 | import viser.transforms as vtf 4 | import time 5 | import numpy as np 6 | import tyro 7 | from pathlib import Path 8 | from autolab_core import RigidTransform 9 | from pogs.tracking.zed import Zed 10 | from pogs.tracking.optim import Optimizer 11 | import warp as wp 12 | from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig, OpenCLIPNetwork 13 | 14 | import os 15 | import cv2 16 | dir_path = os.path.dirname(os.path.realpath(__file__)) 17 | print(dir_path) 18 | WORLD_TO_ZED2 = RigidTransform.load(dir_path+"/../tracking/data/calibration_outputs/world_to_extrinsic_zed.tf") 19 | DEVICE = 'cuda:0' 20 | set_initial_frame = False 21 | 22 | def main( 23 | config_path: str = "/home/yujustin/pogs/outputs/drill/pogs/2025-02-07_004727/config.yml", 24 | offline_folder: str = '/home/yujustin/pogs/data/demonstrations/drill' 25 | ): 26 | """Quick interactive demo for object tracking. 27 | 28 | Args: 29 | config_path: Path to the nerfstudio POGS model config file. 30 | offline_folder: Path to the offline folder with images and depth 31 | """ 32 | image_folder = os.path.join(offline_folder,"left") 33 | depth_folder = os.path.join(offline_folder,"depth") 34 | image_paths = sorted(os.listdir(image_folder)) 35 | depth_paths = sorted(os.listdir(depth_folder)) 36 | 37 | server = viser.ViserServer() 38 | wp.init() 39 | # Set up the camera. 40 | 41 | clip_encoder = OpenCLIPNetworkConfig( 42 | clip_model_type="ViT-B-16", 43 | clip_model_pretrained="laion2b_s34b_b88k", 44 | clip_n_dims=512, 45 | device=DEVICE 46 | ).setup() # OpenCLIP encoder for language querying utils 47 | assert isinstance(clip_encoder, OpenCLIPNetwork) 48 | 49 | camera_tf = WORLD_TO_ZED2 50 | 51 | # Visualize the camera. 52 | camera_frame = server.add_frame( 53 | "camera", 54 | position=camera_tf.translation, # rough alignment. 
55 | wxyz=camera_tf.quaternion, 56 | show_axes=True, 57 | axes_length=0.1, 58 | axes_radius=0.005, 59 | ) 60 | 61 | initial_image_path = os.path.join(image_folder,image_paths[0]) 62 | initial_depth_path = os.path.join(depth_folder,depth_paths[0]) 63 | img_numpy = cv2.imread(initial_image_path) 64 | depth_numpy = np.load(initial_depth_path) 65 | l = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_RGB2BGR)).to(DEVICE) 66 | depth = torch.from_numpy(depth_numpy).to(DEVICE) 67 | 68 | zedK = np.array([[1.05576221e+03, 0.00000000e+00, 9.62041199e+02], 69 | [0.00000000e+00, 1.05576221e+03, 5.61746765e+02], 70 | [0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]) 71 | toad_opt = Optimizer( # Initialize the optimizer 72 | Path(config_path), 73 | zedK, 74 | l.shape[1], 75 | l.shape[0], 76 | init_cam_pose=torch.from_numpy( 77 | vtf.SE3( 78 | wxyz_xyz=np.array([*camera_frame.wxyz, *camera_frame.position]) 79 | ).as_matrix()[None, :3, :] 80 | ).float(), 81 | ) 82 | real_frames = [] 83 | rendered_rgb_frames = [] 84 | part_deltas = [] 85 | save_videos = True 86 | obj_label_list = [None for _ in range(toad_opt.num_groups)] 87 | initial_image_path = os.path.join(image_folder,image_paths[0]) 88 | initial_depth_path = os.path.join(depth_folder,depth_paths[0]) 89 | img_numpy = cv2.imread(initial_image_path) 90 | depth_numpy = np.load(initial_depth_path) 91 | l = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_RGB2BGR)).to(DEVICE) 92 | depth = torch.from_numpy(depth_numpy).to(DEVICE) 93 | toad_opt.set_frame(l,toad_opt.cam2world_ns_ds,depth) 94 | 95 | toad_opt.init_obj_pose() 96 | print("Starting main tracking loop") 97 | 98 | assert isinstance(toad_opt, Optimizer) 99 | while not toad_opt.initialized: 100 | time.sleep(0.1) 101 | if toad_opt.initialized: 102 | # start_time3 = time.time() 103 | for(img_path,depth_path) in zip(image_paths,depth_paths): 104 | full_image_path = os.path.join(image_folder,img_path) 105 | full_depth_path = os.path.join(depth_folder,depth_path) 106 | img_numpy = cv2.imread(full_image_path) 107 | depth_numpy = np.load(full_depth_path) 108 | left = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_RGB2BGR)).to(DEVICE) 109 | depth = torch.from_numpy(depth_numpy).to(DEVICE) 110 | # import pdb; pdb.set_trace 111 | start_time3 = time.time() 112 | toad_opt.set_observation(left,toad_opt.cam2world_ns,depth) 113 | print("Set observation in ", time.time()-start_time3) 114 | start_time5 = time.time() 115 | n_opt_iters = 15 116 | # with zed.raft_lock: 117 | outputs = toad_opt.step_opt(niter=n_opt_iters) 118 | print(f"{n_opt_iters} opt steps in ", time.time()-start_time5) 119 | 120 | # Add ZED img and GS render to viser 121 | rgb_img = left.cpu().numpy() 122 | for i in range(len(toad_opt.group_masks)): 123 | frame = toad_opt.optimizer.frame.roi_frames[i] 124 | xmin = frame.xmin 125 | xmax = frame.xmax 126 | ymin = frame.ymin 127 | ymax = frame.ymax 128 | rgb_img = cv2.rectangle(rgb_img, (xmin, ymin), (xmax, ymax),(255,0,0), 2) 129 | 130 | server.scene.add_image( 131 | "cam/zed_left", 132 | rgb_img, 133 | render_width=rgb_img.shape[1]/2500, 134 | render_height=rgb_img.shape[0]/2500, 135 | position = (-0.5, -0.5, 0.5), 136 | wxyz=(0, -1, 0, 0), 137 | visible=True 138 | ) 139 | 140 | server.scene.add_image( 141 | "cam/gs_render", 142 | outputs["rgb"].cpu().detach().numpy(), 143 | render_width=outputs["rgb"].shape[1]/2500, 144 | render_height=outputs["rgb"].shape[0]/2500, 145 | position = (0.5, -0.5, 0.5), 146 | wxyz=(0, -1, 0, 0), 147 | visible=True 148 | ) 149 | 150 | if save_videos: 151 | 
real_frames.append(rgb_img) 152 | rendered_rgb_frames.append(outputs["rgb"].cpu().detach().numpy()) 153 | 154 | tf_list = toad_opt.get_parts2world() 155 | part_deltas.append(tf_list) 156 | for idx, tf in enumerate(tf_list): 157 | server.add_frame( 158 | f"object/group_{idx}", 159 | position=tf.translation(), 160 | wxyz=tf.rotation().wxyz, 161 | show_axes=True, 162 | axes_length=0.05, 163 | axes_radius=.001 164 | ) 165 | mesh = toad_opt.toad_object.meshes[idx] 166 | server.add_mesh_trimesh( 167 | f"object/group_{idx}/mesh", 168 | mesh=mesh, 169 | ) 170 | if idx == toad_opt.max_relevancy_label: 171 | obj_label_list[idx] = server.add_label( 172 | f"object/group_{idx}/label", 173 | text=toad_opt.max_relevancy_text, 174 | position = (0,0,0.05), 175 | ) 176 | else: 177 | if obj_label_list[idx] is not None: 178 | obj_label_list[idx].remove() 179 | 180 | # Visualize pointcloud. 181 | start_time4 = time.time() 182 | K = torch.from_numpy(zedK).float().cuda() 183 | assert isinstance(left, torch.Tensor) and isinstance(depth, torch.Tensor) 184 | points, colors = Zed.project_depth(left, depth, K, depth_threshold=1.0, subsample=6) 185 | server.add_point_cloud( 186 | "camera/points", 187 | points=points, 188 | colors=colors, 189 | point_size=0.001, 190 | ) 191 | 192 | 193 | # except KeyboardInterrupt: 194 | # # Generate videos from the frames if the user interrupts the loop with ctrl+c 195 | # frames_dict = {"real_frames": real_frames, 196 | # "rendered_rgb": rendered_rgb_frames} 197 | # timestr = generate_videos(frames_dict, fps = 5, config_path=config_path.parent) 198 | 199 | # # Save part deltas to npy file 200 | # path = config_path.parent.joinpath(f"{timestr}") 201 | # np.save(path.joinpath("part_deltas_traj.npy"), np.array(part_deltas)) 202 | # exit() 203 | # except Exception as e: 204 | # print("An exception occured: ", e) 205 | # exit() 206 | 207 | 208 | if __name__ == "__main__": 209 | tyro.cli(main) 210 | -------------------------------------------------------------------------------- /pogs/tracking/atap_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List 4 | from pogs.pogs import POGSModel 5 | import warp as wp 6 | 7 | #https://openaccess.thecvf.com/content_CVPR_2019/papers/Barron_A_General_and_Adaptive_Robust_Loss_Function_CVPR_2019_paper.pdf 8 | @wp.func 9 | def jon_loss(x: float,alpha:float, c:float): 10 | pow_part = (((x/c)**2.0)/wp.abs(alpha-2.0) + 1.0) 11 | return (wp.abs(alpha-2.0)/alpha) * (wp.pow(pow_part,alpha/2.0) - 1.0) 12 | 13 | @wp.kernel 14 | def atap_loss(cur_means: wp.array(dtype = wp.vec3), dists: wp.array(dtype = float), ids: wp.array(dtype = int), 15 | match_ids: wp.array(dtype = int), group_ids1: wp.array(dtype = int), group_ids2: wp.array(dtype=int), 16 | connectivity_weights: wp.array(dtype = float,ndim = 2), loss: wp.array(dtype = float), alpha: float): 17 | tid = wp.tid() 18 | id1 = ids[tid] 19 | id2 = match_ids[tid] 20 | gid1 = group_ids1[tid] 21 | gid2 = group_ids2[tid] 22 | con_weight = connectivity_weights[gid1,gid2] 23 | curdist = wp.length(cur_means[id1] - cur_means[id2]) 24 | loss[tid] = jon_loss(curdist - dists[tid], alpha, 0.001) * con_weight * .001 25 | 26 | 27 | class ATAPLoss: 28 | touch_radius: float = .0015 29 | N: int = 500 30 | loss_mult: float = .2 31 | loss_alpha: float = 1.0 #rule: for jointed, use 1.0 alpha, for non-jointed use 0.1 ish 32 | def __init__(self, sms_model: POGSModel, group_masks: List[torch.Tensor], group_labels: torch.Tensor, 
dataset_scale: float = 1.0): 33 | """ 34 | Initializes the data structure to compute the loss between groups touching 35 | """ 36 | self.touch_radius = self.touch_radius * dataset_scale 37 | print(f"Touch radius is {self.touch_radius}") 38 | self.sms_model = sms_model 39 | self.group_masks = group_masks 40 | self.group_labels = group_labels 41 | self.nn_info = [] 42 | for grp in self.group_masks: 43 | with torch.no_grad(): 44 | dists, ids, match_ids, group_ids1, group_ids = self._radius_nn(grp, self.touch_radius) 45 | self.nn_info.append((dists, ids, match_ids, group_ids1, group_ids)) 46 | print(f"Group {len(self.nn_info)} has {len(ids)} neighbors") 47 | self.dists = torch.cat([x[0] for x in self.nn_info]).cuda() 48 | self.ids = torch.cat([x[1] for x in self.nn_info]).cuda().int() 49 | self.match_ids = torch.cat([x[2] for x in self.nn_info]).cuda().int() 50 | self.group_ids1 = torch.cat([x[3] for x in self.nn_info]).cuda().int() 51 | self.group_ids2 = torch.cat([x[4] for x in self.nn_info]).cuda().int() 52 | self.num_pairs = torch.cat([torch.tensor(len(x[1])).repeat(len(x[1])) for x in self.nn_info]).cuda().float() 53 | 54 | 55 | def __call__(self, connectivity_weights: torch.Tensor): 56 | """ 57 | Computes the loss between groups touching 58 | connectivity_weights: a tensor of shape (num_groups,num_groups) representing the weights between each group 59 | 60 | returns: a differentiable loss 61 | """ 62 | if len(self.group_masks) == 1: 63 | return torch.tensor(0.0,device='cuda') 64 | if self.dists.shape[0] == 0: 65 | return torch.tensor(0.0,device='cuda') 66 | assert connectivity_weights.shape == (len(self.group_masks),len(self.group_masks)), "connectivity weights must be a square matrix of size num_groups" 67 | loss = wp.empty(self.dists.shape[0], dtype=wp.float32, requires_grad=True, device='cuda') 68 | wp.launch( 69 | dim = self.dists.shape[0], 70 | kernel = atap_loss, 71 | inputs = [wp.from_torch(self.sms_model.gauss_params['means'],dtype=wp.vec3),wp.from_torch(self.dists), 72 | wp.from_torch(self.ids),wp.from_torch(self.match_ids),wp.from_torch(self.group_ids1), 73 | wp.from_torch(self.group_ids2),wp.from_torch(connectivity_weights),loss, self.loss_alpha] 74 | ) 75 | return (wp.to_torch(loss)/self.num_pairs).sum()*self.loss_mult 76 | 77 | 78 | def _radius_nn(self, group_mask: torch.Tensor, r: float): 79 | """ 80 | returns the nearest neighbors to gaussians in a group within a certain radius (and outside that group) 81 | returns -1 indices for neighbors outside the radius or within the same group 82 | """ 83 | global_group_ids = torch.zeros(self.sms_model.num_points,dtype=torch.long,device='cuda') 84 | for i,grp in enumerate(self.group_masks): 85 | global_group_ids[grp] = i 86 | from cuml.neighbors import NearestNeighbors 87 | model = NearestNeighbors(n_neighbors=self.N) 88 | means = self.sms_model.means.detach().cpu().numpy() 89 | model.fit(means) 90 | dists, match_ids = model.kneighbors(means) 91 | dists, match_ids = torch.tensor(dists,dtype=torch.float32,device='cuda'),torch.tensor(match_ids,dtype=torch.long,device='cuda') 92 | dists, match_ids = dists[group_mask], match_ids[group_mask] 93 | # filter matches outside the radius 94 | match_ids[dists>r] = -1 95 | # filter out ones within same group mask 96 | match_ids[group_mask[match_ids]] = -1 97 | ids = torch.arange(self.sms_model.num_points,dtype=torch.long,device='cuda')[group_mask].unsqueeze(-1).repeat(1,self.N) 98 | #flatten all the ids/dists/match_ids 99 | ids = ids[match_ids!=-1].flatten() 100 | dists = 
dists[match_ids!=-1].flatten() 101 | match_ids = match_ids[match_ids!=-1].flatten() 102 | return dists, ids, match_ids, global_group_ids[ids], global_group_ids[match_ids] -------------------------------------------------------------------------------- /pogs/tracking/data/ZED2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/tracking/data/ZED2.stl -------------------------------------------------------------------------------- /pogs/tracking/data/ZEDM.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/tracking/data/ZEDM.stl -------------------------------------------------------------------------------- /pogs/tracking/motion.py: -------------------------------------------------------------------------------- 1 | from ur5py.ur5 import UR5Robot 2 | import numpy as np 3 | from autolab_core import RigidTransform 4 | 5 | WRIST_TO_CAM = RigidTransform.load("/home/lifelong/sms/sms/ur5_interface/ur5_interface/calibration_outputs/wrist_to_cam.tf") 6 | 7 | class Motion: 8 | def __init__(self, robot: UR5Robot): 9 | self.robot = robot 10 | # robot = UR5Robot(gripper=1) 11 | self.clear_tcp() 12 | 13 | home_joints = np.array([0.30947089195251465, -1.2793572584735315, -2.035713497792379, -1.388848606740133, 1.5713528394699097, 0.34230729937553406]) 14 | robot.move_joint(home_joints,vel=1.0,acc=0.1) 15 | world_to_wrist = robot.get_pose() 16 | world_to_wrist.from_frame = "wrist" 17 | world_to_cam = world_to_wrist * WRIST_TO_CAM 18 | proper_world_to_cam_rotation = np.array([[0,1,0],[1,0,0],[0,0,-1]]) 19 | self.home_world_to_cam = RigidTransform(rotation=proper_world_to_cam_rotation,translation=world_to_cam.translation,from_frame='cam',to_frame='world') 20 | self.home_world_to_wrist = self.home_world_to_cam * WRIST_TO_CAM.inverse() 21 | 22 | self.robot.move_pose(self.home_world_to_wrist,vel=1.0,acc=0.1) 23 | 24 | 25 | def clear_tcp(self): 26 | tool_to_wrist = RigidTransform() 27 | tool_to_wrist.translation = np.array([0, 0, 0]) 28 | tool_to_wrist.from_frame = "tool" 29 | tool_to_wrist.to_frame = "wrist" 30 | self.robot.set_tcp(tool_to_wrist) 31 | 32 | -------------------------------------------------------------------------------- /pogs/tracking/observation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Callable, Generic, TypeVar 3 | from nerfstudio.cameras.cameras import Cameras 4 | from torchvision.transforms.functional import resize 5 | from pogs.tracking.utils import * 6 | from pogs.tracking.utils2 import * 7 | from copy import deepcopy 8 | 9 | T = TypeVar('T') 10 | class Future(Generic[T]): 11 | """ 12 | A simple wrapper for deferred execution of a callable until retrieved 13 | """ 14 | def __init__(self,callable): 15 | self.callable = callable 16 | self.executed = False 17 | 18 | def retrieve(self): 19 | if not self.executed: 20 | self.result = self.callable() 21 | self.executed = True 22 | return self.result 23 | 24 | class Frame: 25 | rasterize_resolution: int = 500 26 | camera: Cameras 27 | rgb: torch.Tensor 28 | metric_depth: bool 29 | _depth: Future[torch.Tensor] 30 | _dino_feats: Future[torch.Tensor] 31 | # _hand_mask: Future[torch.Tensor] 32 | 33 | @property 34 | def depth(self): 35 | return self._depth.retrieve() 36 | 37 | @property 38 | def 
dino_feats(self): 39 | return self._dino_feats.retrieve() 40 | 41 | # @property 42 | # def hand_mask(self): 43 | # return self._hand_mask.retrieve() 44 | 45 | @property 46 | def mask(self): 47 | return self._mask.retrieve() 48 | 49 | def __init__(self, rgb: torch.Tensor, camera: Cameras, dino_fn: Callable, metric_depth_img: Optional[torch.Tensor], 50 | xmin: Optional[float] = None, xmax: Optional[float] = None, ymin: Optional[float] = None, ymax: Optional[float] = None): 51 | 52 | self.camera = deepcopy(camera.to('cuda')) 53 | 54 | self._dino_fn = dino_fn 55 | self.rgb = resize( 56 | rgb.permute(2, 0, 1), 57 | (camera.height, camera.width), 58 | antialias=True, 59 | ).permute(1, 2, 0) 60 | self.metric_depth = metric_depth_img is not None 61 | self.obj_mask = None 62 | 63 | 64 | @torch.no_grad() 65 | def _get_depth(): 66 | if metric_depth_img is not None: 67 | depth = metric_depth_img 68 | else: 69 | raise FileNotFoundError 70 | depth = resize( 71 | depth.unsqueeze(0), 72 | (camera.height, camera.width), 73 | antialias=True, 74 | ).squeeze().unsqueeze(-1) 75 | return depth 76 | self._depth = Future(_get_depth) 77 | @torch.no_grad() 78 | def _get_dino(): 79 | dino_feats = dino_fn( 80 | rgb.permute(2, 0, 1).unsqueeze(0) 81 | ).squeeze() 82 | dino_feats = resize( 83 | dino_feats.permute(2, 0, 1), 84 | (camera.height, camera.width), 85 | antialias=True, 86 | ).permute(1, 2, 0) 87 | return dino_feats 88 | self._dino_feats = Future(_get_dino) 89 | # @torch.no_grad() 90 | # def _get_hand_mask(): 91 | # hand_mask = get_hand_mask((self.rgb * 255).to(torch.uint8)) 92 | # hand_mask = ( 93 | # torch.nn.functional.max_pool2d( 94 | # hand_mask[None, None], 3, padding=1, stride=1 95 | # ).squeeze() 96 | # == 0.0 97 | # ) 98 | # return hand_mask 99 | # self._hand_mask = Future(_get_hand_mask) 100 | @torch.no_grad() 101 | def _get_mask(): 102 | obj_mask = resize( 103 | self.obj_mask.unsqueeze(0), 104 | (camera.height, camera.width), 105 | antialias=True, 106 | ).squeeze(0) 107 | return obj_mask 108 | self._mask = Future(_get_mask) 109 | self.xmin, self.xmax, self.ymin, self.ymax = xmin, xmax, ymin, ymax 110 | 111 | 112 | 113 | class PosedObservation: 114 | """ 115 | Class for computing relevant data products for a frame and storing them 116 | """ 117 | max_roi_resolution: int = 490 118 | _frame: Frame 119 | _raw_rgb: torch.Tensor 120 | _original_camera: Cameras 121 | _original_depth: Optional[torch.Tensor] = None 122 | _roi_frames: Optional[List[Frame]] = None 123 | 124 | def __init__(self, rgb: torch.Tensor, camera: Cameras, dino_fn: Callable, metric_depth_img: Optional[torch.Tensor] = None): 125 | """ 126 | Initialize the frame 127 | 128 | rgb: HxWx3 tensor of the rgb image, normalized to [0,1] 129 | camera: Cameras object for the camera intrinsics and extrisics to render the frame at 130 | dino_fn: callable taking in 3HxW RGB image and outputting dino features CxHxW 131 | metric_depth_img: HxWx1 tensor of metric depth, if desired 132 | """ 133 | assert rgb.shape[0] == camera.height and rgb.shape[1] == camera.width, f"Input image should be the same size as the camera, got {rgb.shape} and {camera.height}x{camera.width}" 134 | self._dino_fn = dino_fn 135 | assert rgb.shape[-1] == 3, rgb.shape 136 | self._raw_rgb = rgb 137 | if metric_depth_img is not None: 138 | self._original_depth = metric_depth_img 139 | self._original_camera = deepcopy(camera.to('cuda')) 140 | cam = deepcopy(camera.to('cuda')) 141 | 142 | self._frame = Frame(rgb, cam, dino_fn, metric_depth_img) 143 | self._roi_frames = [] 144 | 
self._obj_masks = None 145 | 146 | 147 | @property 148 | def frame(self): 149 | return self._frame 150 | 151 | @property 152 | def roi_frames(self): 153 | if len(self._roi_frames) == 0: 154 | raise RuntimeError("ROIs not set") 155 | return self._roi_frames 156 | 157 | def add_roi(self, xmin, xmax, ymin, ymax): 158 | assert xmin < xmax and ymin < ymax 159 | assert xmin >= 0 and ymin >= 0 160 | assert xmax <= 1.0 and ymax <= 1.0, "xmin and ymin should be normalized" 161 | # convert normalized to pixels in original image 162 | xmin,xmax,ymin,ymax = int(xmin*(self._original_camera.width-1)), int(xmax*(self._original_camera.width-1)),\ 163 | int(ymin*(self._original_camera.height-1)), int(ymax*(self._original_camera.height-1)) 164 | # adjust these value to be multiples of 14, dino patch size 165 | xlen = ((xmax - xmin)//14) * 14 166 | ylen = ((ymax - ymin)//14) * 14 167 | xmax = xmin + xlen 168 | ymax = ymin + ylen 169 | rgb = self._raw_rgb[ymin:ymax, xmin:xmax].clone() 170 | camera = crop_camera(self._original_camera, xmin, xmax, ymin, ymax) 171 | if max(camera.width.item(),camera.height.item()) > self.max_roi_resolution: 172 | camera.rescale_output_resolution(self.max_roi_resolution/max(camera.width.item(),camera.height.item())) 173 | depth = self._original_depth[ymin:ymax, xmin:xmax].clone().squeeze(-1) 174 | self._roi_frames.append(Frame(rgb, camera, self._dino_fn, depth, xmin, xmax, ymin, ymax)) 175 | 176 | def update_roi(self, idx, xmin, xmax, ymin, ymax): 177 | assert len(self._roi_frames) > idx 178 | assert xmin < xmax and ymin < ymax 179 | assert xmin >= 0 and ymin >= 0 180 | assert xmax <= 1.0 and ymax <= 1.0, "xmin and ymin should be normalized" 181 | # convert normalized to pixels in original image 182 | xmin,xmax,ymin,ymax = int(xmin*(self._original_camera.width-1)), int(xmax*(self._original_camera.width-1)),\ 183 | int(ymin*(self._original_camera.height-1)), int(ymax*(self._original_camera.height-1)) 184 | # adjust these value to be multiples of 14, dino patch size 185 | xlen = ((xmax - xmin)//14) * 14 186 | ylen = ((ymax - ymin)//14) * 14 187 | xmax = xmin + xlen 188 | ymax = ymin + ylen 189 | rgb = self._raw_rgb[ymin:ymax, xmin:xmax].clone() 190 | camera = crop_camera(self._original_camera, xmin, xmax, ymin, ymax) 191 | if max(camera.width.item(),camera.height.item()) > self.max_roi_resolution: 192 | camera.rescale_output_resolution(self.max_roi_resolution/max(camera.width.item(),camera.height.item())) 193 | depth = self._original_depth[ymin:ymax, xmin:xmax].clone().squeeze(-1) 194 | 195 | self._roi_frames[idx] = Frame(rgb, camera, self._dino_fn, depth, xmin, xmax, ymin, ymax) 196 | if len(self._obj_masks) > 0: 197 | 198 | self._roi_frames[idx].obj_mask = self._obj_masks[idx].squeeze(0)[ymin:ymax, xmin:xmax].clone() -------------------------------------------------------------------------------- /pogs/tracking/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """Rigid transforms implemented in PyTorch, ported from jaxlie.""" 2 | 3 | from . 
import utils as utils 4 | from ._base import MatrixLieGroup as MatrixLieGroup 5 | from ._base import SEBase as SEBase 6 | from ._base import SOBase as SOBase 7 | from ._se3 import SE3 as SE3 8 | from ._so3 import SO3 as SO3 -------------------------------------------------------------------------------- /pogs/tracking/transforms/_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload 3 | 4 | import numpy as onp 5 | import torch 6 | from torch import Tensor 7 | from typing_extensions import Self, final, override 8 | 9 | GroupType = TypeVar("GroupType", bound="MatrixLieGroup") 10 | SEGroupType = TypeVar("SEGroupType", bound="SEBase") 11 | 12 | 13 | class MatrixLieGroup(abc.ABC): 14 | """Interface definition for matrix Lie groups.""" 15 | 16 | # Class properties. 17 | # > These will be set in `_utils.register_lie_group()`. 18 | 19 | matrix_dim: ClassVar[int] 20 | """Dimension of square matrix output from `.as_matrix()`.""" 21 | 22 | parameters_dim: ClassVar[int] 23 | """Dimension of underlying parameters, `.parameters()`.""" 24 | 25 | tangent_dim: ClassVar[int] 26 | """Dimension of tangent space.""" 27 | 28 | space_dim: ClassVar[int] 29 | """Dimension of coordinates that can be transformed.""" 30 | 31 | def __init__( 32 | # Notes: 33 | # - For the constructor signature to be consistent with subclasses, `parameters` 34 | # should be marked as positional-only. But this isn't possible in Python 3.7. 35 | # - This method is implicitly overriden by the dataclass decorator and 36 | # should _not_ be marked abstract. 37 | self, 38 | parameters: Tensor, 39 | ): 40 | """Construct a group object from its underlying parameters.""" 41 | raise NotImplementedError() 42 | 43 | # Shared implementations. 44 | 45 | @overload 46 | def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... 47 | 48 | @overload 49 | def __matmul__(self, other: Tensor) -> Tensor: ... 50 | 51 | def __matmul__( 52 | self: GroupType, other: Union[GroupType, Tensor] 53 | ) -> Union[GroupType, Tensor]: 54 | """Overload for the `@` operator. 55 | 56 | Switches between the group action (`.apply()`) and multiplication 57 | (`.multiply()`) based on the type of `other`. 58 | """ 59 | if isinstance(other, (onp.ndarray, Tensor)): 60 | return self.apply(target=other) 61 | elif isinstance(other, MatrixLieGroup): 62 | assert self.space_dim == other.space_dim 63 | return self.multiply(other=other) 64 | else: 65 | assert False, f"Invalid argument type for `@` operator: {type(other)}" 66 | 67 | # Factory. 68 | 69 | @classmethod 70 | @abc.abstractmethod 71 | def identity( 72 | cls: Type[GroupType], device: Union[torch.device, str], dtype: torch.dtype 73 | ) -> GroupType: 74 | """Returns identity element. 75 | 76 | Returns: 77 | Identity element. 78 | """ 79 | 80 | @classmethod 81 | @abc.abstractmethod 82 | def from_matrix(cls: Type[GroupType], matrix: Tensor) -> GroupType: 83 | """Get group member from matrix representation. 84 | 85 | Args: 86 | matrix: Matrix representaiton. 87 | 88 | Returns: 89 | Group member. 90 | """ 91 | 92 | # Accessors. 93 | 94 | @abc.abstractmethod 95 | def as_matrix(self) -> Tensor: 96 | """Get transformation as a matrix. Homogeneous for SE groups.""" 97 | 98 | @abc.abstractmethod 99 | def parameters(self) -> Tensor: 100 | """Get underlying representation.""" 101 | 102 | # Operations. 
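    # Illustrative sketch of how the operations below compose (assuming a concrete
    # subclass such as the SE3/SO3 implementations in this package):
    #   T = SE3.exp(tau)          # tau: (..., tangent_dim) tangent vector
    #   T.log()                   # ~= tau, the inverse of exp
    #   T @ p                     # same as T.apply(p) for a point tensor p
    #   T1 @ T2                   # same as T1.multiply(T2)
    #   T @ T.inverse()           # ~= identity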
103 | 104 | @abc.abstractmethod 105 | def apply(self, target: Tensor) -> Tensor: 106 | """Applies group action to a point. 107 | 108 | Args: 109 | target: Point to transform. 110 | 111 | Returns: 112 | Transformed point. 113 | """ 114 | 115 | @abc.abstractmethod 116 | def multiply(self: Self, other: Self) -> Self: 117 | """Composes this transformation with another. 118 | 119 | Returns: 120 | self @ other 121 | """ 122 | 123 | @classmethod 124 | @abc.abstractmethod 125 | def exp(cls: Type[GroupType], tangent: Tensor) -> GroupType: 126 | """Computes `expm(wedge(tangent))`. 127 | 128 | Args: 129 | tangent: Tangent vector to take the exponential of. 130 | 131 | Returns: 132 | Output. 133 | """ 134 | 135 | @abc.abstractmethod 136 | def log(self) -> Tensor: 137 | """Computes `vee(logm(transformation matrix))`. 138 | 139 | Returns: 140 | Output. Shape should be `(tangent_dim,)`. 141 | """ 142 | 143 | @abc.abstractmethod 144 | def adjoint(self) -> Tensor: 145 | """Computes the adjoint, which transforms tangent vectors between tangent 146 | spaces. 147 | 148 | More precisely, for a transform `GroupType`: 149 | ``` 150 | GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType 151 | ``` 152 | 153 | In robotics, typically used for transforming twists, wrenches, and Jacobians 154 | across different reference frames. 155 | 156 | Returns: 157 | Output. Shape should be `(tangent_dim, tangent_dim)`. 158 | """ 159 | 160 | @abc.abstractmethod 161 | def inverse(self: GroupType) -> GroupType: 162 | """Computes the inverse of our transform. 163 | 164 | Returns: 165 | Output. 166 | """ 167 | 168 | @abc.abstractmethod 169 | def normalize(self: GroupType) -> GroupType: 170 | """Normalize/projects values and returns. 171 | 172 | Returns: 173 | GroupType: Normalized group member. 174 | """ 175 | 176 | # @classmethod 177 | # @abc.abstractmethod 178 | # def sample_uniform(cls: Type[GroupType], key: Tensor) -> GroupType: 179 | # """Draw a uniform sample from the group. Translations (if applicable) are in the 180 | # range [-1, 1]. 181 | # 182 | # Args: 183 | # key: PRNG key, as returned by `jax.random.PRNGKey()`. 184 | # 185 | # Returns: 186 | # Sampled group member. 187 | # """ 188 | 189 | def get_batch_axes(self) -> Tuple[int, ...]: 190 | """Return any leading batch axes in contained parameters. If an array of shape 191 | `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will 192 | return `(100,)`.""" 193 | return self.parameters().shape[:-1] 194 | 195 | 196 | class SOBase(MatrixLieGroup): 197 | """Base class for special orthogonal groups.""" 198 | 199 | 200 | ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) 201 | 202 | 203 | class SEBase(Generic[ContainedSOType], MatrixLieGroup): 204 | """Base class for special Euclidean groups. 205 | 206 | Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional 207 | translation vector. 208 | """ 209 | 210 | # SE-specific interface. 211 | 212 | @classmethod 213 | @abc.abstractmethod 214 | def from_rotation_and_translation( 215 | cls: Type[SEGroupType], 216 | rotation: ContainedSOType, 217 | translation: Tensor, 218 | ) -> SEGroupType: 219 | """Construct a rigid transform from a rotation and a translation. 220 | 221 | Args: 222 | rotation: Rotation term. 223 | translation: Translation term. 224 | 225 | Returns: 226 | Constructed transformation. 
227 | """ 228 | 229 | @final 230 | @classmethod 231 | def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: 232 | return cls.from_rotation_and_translation( 233 | rotation=rotation, 234 | translation=rotation.parameters().new_zeros( 235 | (*rotation.parameters().shape[:-1], cls.space_dim), 236 | dtype=rotation.parameters().dtype, 237 | ), 238 | ) 239 | 240 | @abc.abstractmethod 241 | def rotation(self) -> ContainedSOType: 242 | """Returns a transform's rotation term.""" 243 | 244 | @abc.abstractmethod 245 | def translation(self) -> Tensor: 246 | """Returns a transform's translation term.""" 247 | 248 | # Overrides. 249 | 250 | @final 251 | @override 252 | def apply(self, target: Tensor) -> Tensor: 253 | return self.rotation() @ target + self.translation() # type: ignore 254 | 255 | @final 256 | @override 257 | def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: 258 | return type(self).from_rotation_and_translation( 259 | rotation=self.rotation() @ other.rotation(), 260 | translation=(self.rotation() @ other.translation()) + self.translation(), 261 | ) 262 | 263 | @final 264 | @override 265 | def inverse(self: SEGroupType) -> SEGroupType: 266 | R_inv = self.rotation().inverse() 267 | return type(self).from_rotation_and_translation( 268 | rotation=R_inv, 269 | translation=-(R_inv @ self.translation()), 270 | ) 271 | 272 | @final 273 | @override 274 | def normalize(self: SEGroupType) -> SEGroupType: 275 | return type(self).from_rotation_and_translation( 276 | rotation=self.rotation().normalize(), 277 | translation=self.translation(), 278 | ) -------------------------------------------------------------------------------- /pogs/tracking/transforms/_se3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import cast 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from typing_extensions import Union, override 10 | 11 | from . import _base 12 | from ._so3 import SO3 13 | from .utils import get_epsilon, register_lie_group 14 | 15 | 16 | def _skew(omega: Tensor) -> Tensor: 17 | """ 18 | Returns the skew-symmetric form of a length-3 vector. 19 | :param omega (*, 3) 20 | :returns (*, 3, 3) 21 | """ 22 | 23 | wx, wy, wz = omega.unbind(dim=-1) 24 | o = torch.zeros_like(wx) 25 | return torch.stack( 26 | [o, -wz, wy, wz, o, -wx, -wy, wx, o], 27 | dim=-1, 28 | ).reshape(*wx.shape, 3, 3) 29 | 30 | 31 | @register_lie_group( 32 | matrix_dim=4, 33 | parameters_dim=7, 34 | tangent_dim=6, 35 | space_dim=3, 36 | ) 37 | @dataclass(frozen=True) 38 | class SE3(_base.SEBase[SO3]): 39 | """Special Euclidean group for proper rigid transforms in 3D. 40 | 41 | Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization 42 | is `(vx, vy, vz, omega_x, omega_y, omega_z)`. 43 | """ 44 | 45 | # SE3-specific. 46 | 47 | wxyz_xyz: Tensor 48 | """Internal parameters. wxyz quaternion followed by xyz translation.""" 49 | 50 | @override 51 | def __repr__(self) -> str: 52 | quat = np.round(self.wxyz_xyz[..., :4].numpy(force=True), 5) 53 | trans = np.round(self.wxyz_xyz[..., 4:].numpy(force=True), 5) 54 | return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" 55 | 56 | # SE-specific. 
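    # Minimal usage sketch (illustrative only; values arbitrary, any float
    # device/dtype):
    #   R = SO3.from_z_radians(torch.tensor(0.5))
    #   T = SE3.from_rotation_and_translation(R, torch.tensor([0.1, 0.0, 0.0]))
    #   p = T @ torch.tensor([0.0, 1.0, 0.0])   # equivalent to T.apply(...)
    #   M = T.as_matrix()                       # (4, 4) homogeneous matrix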
57 | 58 | @classmethod 59 | @override 60 | def from_rotation_and_translation( 61 | cls, 62 | rotation: SO3, 63 | translation: Tensor, 64 | ) -> SE3: 65 | assert translation.shape[-1] == 3 66 | return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1)) 67 | 68 | @override 69 | def rotation(self) -> SO3: 70 | return SO3(wxyz=self.wxyz_xyz[..., :4]) 71 | 72 | @override 73 | def translation(self) -> Tensor: 74 | return self.wxyz_xyz[..., 4:] 75 | 76 | # Factory. 77 | 78 | @classmethod 79 | @override 80 | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3: 81 | return SE3( 82 | wxyz_xyz=torch.tensor( 83 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device=device, dtype=dtype 84 | ) 85 | ) 86 | 87 | @classmethod 88 | @override 89 | def from_matrix(cls, matrix: Tensor) -> SE3: 90 | assert matrix.shape[-2:] == (4, 4) 91 | # Currently assumes bottom row is [0, 0, 0, 1]. 92 | return SE3.from_rotation_and_translation( 93 | rotation=SO3.from_matrix(matrix[..., :3, :3]), 94 | translation=matrix[..., :3, 3], 95 | ) 96 | 97 | # Accessors. 98 | 99 | @override 100 | def as_matrix(self) -> Tensor: 101 | R = self.rotation().as_matrix() # (*, 3, 3) 102 | t = self.translation().unsqueeze(-1) # (*, 3, 1) 103 | dims = R.shape[:-2] 104 | bottom = ( 105 | torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) 106 | .reshape(*(1,) * len(dims), 1, 4) 107 | .repeat(*dims, 1, 1) 108 | ) 109 | return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2) 110 | 111 | @override 112 | def parameters(self) -> Tensor: 113 | return self.wxyz_xyz 114 | 115 | # Operations. 116 | 117 | @classmethod 118 | @override 119 | def exp(cls, tangent: Tensor) -> SE3: 120 | # Reference: 121 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 122 | 123 | # (x, y, z, omega_x, omega_y, omega_z) 124 | assert tangent.shape[-1] == 6 125 | 126 | rotation = SO3.exp(tangent[..., 3:]) 127 | 128 | theta_squared = torch.square(tangent[..., 3:]).sum(dim=-1) # (*) 129 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 130 | 131 | theta_squared_safe = cast( 132 | Tensor, 133 | torch.where( 134 | use_taylor, 135 | 1.0, # Any non-zero value should do here. 
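                # (The placeholder keeps the unused torch.where branch free of
                # NaN gradients under reverse-mode AD; see the matching shim in
                # log() below.)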
136 | theta_squared, 137 | ), 138 | ) 139 | del theta_squared 140 | theta_safe = torch.sqrt(theta_squared_safe) 141 | 142 | skew_omega = _skew(tangent[..., 3:]) 143 | dtype = skew_omega.dtype 144 | device = skew_omega.device 145 | V = torch.where( 146 | use_taylor[..., None, None], 147 | rotation.as_matrix(), 148 | ( 149 | torch.eye(3, device=device, dtype=dtype) 150 | + ((1.0 - torch.cos(theta_safe)) / (theta_squared_safe))[ 151 | ..., None, None 152 | ] 153 | * skew_omega 154 | + ( 155 | (theta_safe - torch.sin(theta_safe)) 156 | / (theta_squared_safe * theta_safe) 157 | )[..., None, None] 158 | * torch.einsum("...ij,...jk->...ik", skew_omega, skew_omega) 159 | ), 160 | ) 161 | 162 | return SE3.from_rotation_and_translation( 163 | rotation=rotation, 164 | translation=torch.einsum("...ij,...j->...i", V, tangent[..., :3]), 165 | ) 166 | 167 | 168 | @override 169 | def log(self) -> Tensor: 170 | # Reference: 171 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 172 | omega = self.rotation().log() 173 | theta_squared = torch.square(omega).sum(dim=-1) # (*) 174 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 175 | 176 | skew_omega = _skew(omega) 177 | 178 | # Shim to avoid NaNs in jnp.where branches, which cause failures for 179 | # reverse-mode AD. 180 | theta_squared_safe = torch.where( 181 | use_taylor, 182 | 1.0, # Any non-zero value should do here. 183 | theta_squared, 184 | ) 185 | del theta_squared 186 | theta_safe = torch.sqrt(theta_squared_safe) 187 | half_theta_safe = theta_safe / 2.0 188 | 189 | dtype = omega.dtype 190 | device = omega.device 191 | V_inv = torch.where( 192 | use_taylor[..., None, None], 193 | torch.eye(3, device=device, dtype=dtype) 194 | - 0.5 * skew_omega 195 | + torch.matmul(skew_omega, skew_omega) / 12.0, 196 | ( 197 | torch.eye(3, device=device, dtype=dtype) 198 | - 0.5 * skew_omega 199 | + ( 200 | 1.0 201 | - theta_safe 202 | * torch.cos(half_theta_safe) 203 | / (2.0 * torch.sin(half_theta_safe)) 204 | )[..., None, None] 205 | / theta_squared_safe[..., None, None] 206 | * torch.matmul(skew_omega, skew_omega) 207 | ), 208 | ) 209 | return torch.cat( 210 | [torch.einsum("...ij,...j->...i", V_inv, self.translation()), omega], dim=-1 211 | ) 212 | 213 | @override 214 | def adjoint(self) -> Tensor: 215 | R = self.rotation().as_matrix() 216 | dims = R.shape[:-2] 217 | # (*, 6, 6) 218 | return torch.cat( 219 | [ 220 | torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1), 221 | torch.cat([torch.zeros((*dims, 3, 3)), R], dim=-1), 222 | ], 223 | dim=-2, 224 | ) -------------------------------------------------------------------------------- /pogs/tracking/transforms/_so3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from typing_extensions import override 10 | 11 | from . import _base 12 | from .utils import get_epsilon, register_lie_group 13 | 14 | 15 | @register_lie_group( 16 | matrix_dim=3, 17 | parameters_dim=4, 18 | tangent_dim=3, 19 | space_dim=3, 20 | ) 21 | @dataclass(frozen=True) 22 | class SO3(_base.SOBase): 23 | """Special orthogonal group for 3D rotations. 24 | 25 | Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is 26 | `(omega_x, omega_y, omega_z)`. 27 | """ 28 | 29 | # SO3-specific. 
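    # Note: parameters are stored scalar-first (wxyz). Quaternions in xyzw order
    # (e.g. SciPy / ROS conventions) should go through from_quaternion_xyzw()
    # below rather than the dataclass constructor.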
30 | 31 | wxyz: Tensor 32 | """Internal parameters. `(w, x, y, z)` quaternion.""" 33 | 34 | @override 35 | def __repr__(self) -> str: 36 | wxyz = np.round(self.wxyz.numpy(force=True), 5) 37 | return f"{self.__class__.__name__}(wxyz={wxyz})" 38 | 39 | @staticmethod 40 | def from_x_radians(theta: Tensor) -> SO3: 41 | """Generates a x-axis rotation. 42 | 43 | Args: 44 | angle: X rotation, in radians. 45 | 46 | Returns: 47 | Output. 48 | """ 49 | zeros = torch.zeros_like(theta) 50 | return SO3.exp(torch.stack([theta, zeros, zeros], dim=-1)) 51 | 52 | @staticmethod 53 | def from_y_radians(theta: Tensor) -> SO3: 54 | """Generates a y-axis rotation. 55 | 56 | Args: 57 | angle: Y rotation, in radians. 58 | 59 | Returns: 60 | Output. 61 | """ 62 | zeros = torch.zeros_like(theta) 63 | return SO3.exp(torch.stack([zeros, theta, zeros], dim=-1)) 64 | 65 | @staticmethod 66 | def from_z_radians(theta: Tensor) -> SO3: 67 | """Generates a z-axis rotation. 68 | 69 | Args: 70 | angle: Z rotation, in radians. 71 | 72 | Returns: 73 | Output. 74 | """ 75 | zeros = torch.zeros_like(theta) 76 | return SO3.exp(torch.stack([zeros, zeros, theta], dim=-1)) 77 | 78 | @staticmethod 79 | def from_rpy_radians( 80 | roll: Tensor, 81 | pitch: Tensor, 82 | yaw: Tensor, 83 | ) -> SO3: 84 | """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot 85 | convention. 86 | 87 | Args: 88 | roll: X rotation, in radians. Applied first. 89 | pitch: Y rotation, in radians. Applied second. 90 | yaw: Z rotation, in radians. Applied last. 91 | 92 | Returns: 93 | Output. 94 | """ 95 | return ( 96 | SO3.from_z_radians(yaw) 97 | @ SO3.from_y_radians(pitch) 98 | @ SO3.from_x_radians(roll) 99 | ) 100 | 101 | @staticmethod 102 | def from_quaternion_xyzw(xyzw: Tensor) -> SO3: 103 | """Construct a rotation from an `xyzw` quaternion. 104 | 105 | Note that `wxyz` quaternions can be constructed using the default dataclass 106 | constructor. 107 | 108 | Args: 109 | xyzw: xyzw quaternion. Shape should be (4,). 110 | 111 | Returns: 112 | Output. 113 | """ 114 | assert xyzw.shape == (4,) 115 | return SO3(torch.roll(xyzw, shifts=1, dims=-1)) 116 | 117 | def as_quaternion_xyzw(self) -> Tensor: 118 | """Grab parameters as xyzw quaternion.""" 119 | return torch.roll(self.wxyz, shifts=-1, dims=-1) 120 | 121 | # Factory. 
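    # Illustrative sketch for the factories below (R_mat is a hypothetical
    # (..., 3, 3) rotation matrix):
    #   q = SO3.from_matrix(R_mat)
    #   q.as_matrix()              # reproduces R_mat up to numerical precision
    # from_matrix returns one of the two equivalent unit quaternions (q and -q
    # encode the same rotation).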
122 | 123 | @classmethod 124 | @override 125 | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SO3: 126 | return SO3(wxyz=torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)) 127 | 128 | @classmethod 129 | @override 130 | def from_matrix(cls, matrix: Tensor) -> SO3: 131 | assert matrix.shape[-2:] == (3, 3) 132 | 133 | # Modified from: 134 | # > "Converting a Rotation Matrix to a Quaternion" from Mike Day 135 | # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf 136 | 137 | def case0(m): 138 | t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] 139 | q = torch.stack( 140 | [ 141 | m[..., 2, 1] - m[..., 1, 2], 142 | t, 143 | m[..., 1, 0] + m[..., 0, 1], 144 | m[..., 0, 2] + m[..., 2, 0], 145 | ], 146 | dim=-1, 147 | ) 148 | return t, q 149 | 150 | def case1(m): 151 | t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] 152 | q = torch.stack( 153 | [ 154 | m[..., 0, 2] - m[..., 2, 0], 155 | m[..., 1, 0] + m[..., 0, 1], 156 | t, 157 | m[..., 2, 1] + m[..., 1, 2], 158 | ], 159 | dim=-1, 160 | ) 161 | return t, q 162 | 163 | def case2(m): 164 | t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] 165 | q = torch.stack( 166 | [ 167 | m[..., 1, 0] - m[..., 0, 1], 168 | m[..., 0, 2] + m[..., 2, 0], 169 | m[..., 2, 1] + m[..., 1, 2], 170 | t, 171 | ], 172 | dim=-1, 173 | ) 174 | return t, q 175 | 176 | def case3(m): 177 | t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] 178 | q = torch.stack( 179 | [ 180 | t, 181 | m[..., 2, 1] - m[..., 1, 2], 182 | m[..., 0, 2] - m[..., 2, 0], 183 | m[..., 1, 0] - m[..., 0, 1], 184 | ], 185 | dim=-1, 186 | ) 187 | return t, q 188 | 189 | # Compute four cases, then pick the most precise one. 190 | # Probably worth revisiting this! 191 | case0_t, case0_q = case0(matrix) 192 | case1_t, case1_q = case1(matrix) 193 | case2_t, case2_q = case2(matrix) 194 | case3_t, case3_q = case3(matrix) 195 | 196 | cond0 = matrix[..., 2, 2] < 0 197 | cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] 198 | cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] 199 | 200 | t = torch.where( 201 | cond0, 202 | torch.where(cond1, case0_t, case1_t), 203 | torch.where(cond2, case2_t, case3_t), 204 | ) 205 | q = torch.where( 206 | cond0[..., None], 207 | torch.where(cond1[..., None], case0_q, case1_q), 208 | torch.where(cond2[..., None], case2_q, case3_q), 209 | ) 210 | return SO3(wxyz=q * 0.5 / torch.sqrt(t[..., None])) 211 | 212 | # Accessors. 213 | 214 | @override 215 | def as_matrix(self) -> Tensor: 216 | norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True) 217 | qvec = self.wxyz * torch.sqrt(2.0 / norm_sq) # (*, 4) 218 | Q = torch.einsum("...i,...j->...ij", qvec, qvec) # (*, 4, 4) 219 | return torch.stack( 220 | [ 221 | 1.0 - Q[..., 2, 2] - Q[..., 3, 3], 222 | Q[..., 1, 2] - Q[..., 3, 0], 223 | Q[..., 1, 3] + Q[..., 2, 0], 224 | Q[..., 1, 2] + Q[..., 3, 0], 225 | 1.0 - Q[..., 1, 1] - Q[..., 3, 3], 226 | Q[..., 2, 3] - Q[..., 1, 0], 227 | Q[..., 1, 3] - Q[..., 2, 0], 228 | Q[..., 2, 3] + Q[..., 1, 0], 229 | 1.0 - Q[..., 1, 1] - Q[..., 2, 2], 230 | ], 231 | dim=-1, 232 | ).reshape(*qvec.shape[:-1], 3, 3) 233 | 234 | @override 235 | def parameters(self) -> Tensor: 236 | return self.wxyz 237 | 238 | # Operations. 239 | 240 | @override 241 | def apply(self, target: Tensor) -> Tensor: 242 | assert target.shape[-1] == 3 243 | 244 | # Compute using quaternion multiplys. 
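        # i.e. evaluate q * p_quat * q^{-1} and keep the vector part. The point is
        # padded with a scalar of 1 rather than the 0 of a "pure" quaternion; the
        # extra q * q^{-1} term only shifts the (discarded) scalar component, so
        # the vector part of the result is still the rotated point.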
245 | padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1) 246 | out = self.multiply(SO3(wxyz=padded_target).multiply(self.inverse())) 247 | return out.wxyz[..., 1:] 248 | 249 | @override 250 | def multiply(self, other: SO3) -> SO3: 251 | w0, x0, y0, z0 = self.wxyz.unbind(dim=-1) 252 | w1, x1, y1, z1 = other.wxyz.unbind(dim=-1) 253 | wxyz2 = torch.stack( 254 | [ 255 | -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, 256 | x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, 257 | -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, 258 | x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, 259 | ], 260 | dim=-1, 261 | ) 262 | 263 | return SO3(wxyz=wxyz2) 264 | 265 | @classmethod 266 | @override 267 | def exp(cls, tangent: Tensor) -> SO3: 268 | # Reference: 269 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 270 | 271 | assert tangent.shape[-1] == 3 272 | 273 | theta_squared = torch.square(tangent).sum(dim=-1) # (*) 274 | theta_pow_4 = theta_squared * theta_squared 275 | use_taylor = theta_squared < get_epsilon(tangent.dtype) 276 | 277 | safe_theta = torch.sqrt( 278 | torch.where( 279 | use_taylor, 280 | torch.ones_like(theta_squared), # Any constant value should do here. 281 | theta_squared, 282 | ) 283 | ) 284 | safe_half_theta = 0.5 * safe_theta 285 | 286 | real_factor = torch.where( 287 | use_taylor, 288 | 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, 289 | torch.cos(safe_half_theta), 290 | ) 291 | 292 | imaginary_factor = torch.where( 293 | use_taylor, 294 | 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, 295 | torch.sin(safe_half_theta) / safe_theta, 296 | ) 297 | 298 | return SO3( 299 | wxyz=torch.cat( 300 | [ 301 | real_factor[..., None], 302 | imaginary_factor[..., None] * tangent, 303 | ], 304 | dim=-1, 305 | ) 306 | ) 307 | 308 | @override 309 | def log(self) -> Tensor: 310 | # Reference: 311 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 312 | 313 | w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) # (*, 1), (*, 3) 314 | norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True) # (*, 1) 315 | use_taylor = norm_sq < get_epsilon(norm_sq.dtype) 316 | 317 | norm_safe = torch.sqrt( 318 | torch.where( 319 | use_taylor, 320 | torch.ones_like(norm_sq), # Any non-zero value should do here. 321 | norm_sq, 322 | ) 323 | ) 324 | w_safe = torch.where(use_taylor, w, torch.ones_like(w)) 325 | atan_n_over_w = torch.atan2( 326 | torch.where(w < 0, -norm_safe, norm_safe), 327 | torch.abs(w), 328 | ) 329 | atan_factor = torch.where( 330 | use_taylor, 331 | 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, 332 | torch.where( 333 | torch.abs(w) < get_epsilon(w.dtype), 334 | torch.where(w > 0, 1.0, -1.0) * torch.pi / norm_safe, 335 | 2.0 * atan_n_over_w / norm_safe, 336 | ), 337 | ) 338 | 339 | return atan_factor * xyz 340 | 341 | @override 342 | def adjoint(self) -> Tensor: 343 | return self.as_matrix() 344 | 345 | @override 346 | def inverse(self) -> SO3: 347 | # Negate complex terms. 
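        # For a unit quaternion the conjugate (w, -x, -y, -z) is the inverse
        # rotation; inputs are assumed normalized (see normalize() below), so no
        # division by the squared norm is needed.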
348 | w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) 349 | return SO3(wxyz=torch.cat([w, -xyz], dim=-1)) 350 | 351 | @override 352 | def normalize(self) -> SO3: 353 | return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True)) -------------------------------------------------------------------------------- /pogs/tracking/transforms/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import get_epsilon, register_lie_group 2 | 3 | __all__ = ["get_epsilon", "register_lie_group"] -------------------------------------------------------------------------------- /pogs/tracking/transforms/utils/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable, Type, TypeVar 2 | 3 | import torch 4 | 5 | if TYPE_CHECKING: 6 | from .._base import MatrixLieGroup 7 | 8 | 9 | T = TypeVar("T", bound="MatrixLieGroup") 10 | 11 | 12 | def get_epsilon(dtype: torch.dtype) -> float: 13 | """Helper for grabbing type-specific precision constants. 14 | 15 | Args: 16 | dtype: Datatype. 17 | 18 | Returns: 19 | Output float. 20 | """ 21 | return { 22 | torch.float32: 1e-5, 23 | torch.float64: 1e-10, 24 | }[dtype] 25 | 26 | 27 | def register_lie_group( 28 | *, 29 | matrix_dim: int, 30 | parameters_dim: int, 31 | tangent_dim: int, 32 | space_dim: int, 33 | ) -> Callable[[Type[T]], Type[T]]: 34 | """Decorator for registering Lie group dataclasses. 35 | 36 | Sets dimensionality class variables. 37 | """ 38 | 39 | def _wrap(cls: Type[T]) -> Type[T]: 40 | # Register dimensions as class attributes. 41 | cls.matrix_dim = matrix_dim 42 | cls.parameters_dim = parameters_dim 43 | cls.tangent_dim = tangent_dim 44 | cls.space_dim = space_dim 45 | 46 | return cls 47 | 48 | return _wrap -------------------------------------------------------------------------------- /pogs/tracking/tri_zed.py: -------------------------------------------------------------------------------- 1 | import pyzed.sl as sl 2 | from typing import Optional, Tuple 3 | import torch 4 | import numpy as np 5 | from threading import Lock 6 | import plotly 7 | from plotly import express as px 8 | import trimesh 9 | from pathlib import Path 10 | # from raftstereo.raft_stereo import * 11 | from autolab_core import RigidTransform 12 | import torch 13 | import torch.nn.functional as tfn 14 | import os 15 | import numpy as np 16 | import pathlib 17 | from matplotlib import pyplot as plt 18 | torch._C._jit_set_profiling_executor(False) 19 | 20 | def is_tensor(data): 21 | return type(data) == torch.Tensor 22 | 23 | def is_tuple(data): 24 | return isinstance(data, tuple) 25 | 26 | def is_list(data): 27 | return isinstance(data, list) or isinstance(data, torch.nn.ModuleList) 28 | 29 | def is_dict(data): 30 | return isinstance(data, dict) or isinstance(data, torch.nn.ModuleDict) 31 | 32 | def is_seq(data): 33 | return is_tuple(data) or is_list(data) 34 | 35 | def iterate1(func): 36 | """Decorator to iterate over a list (first argument)""" 37 | def inner(var, *args, **kwargs): 38 | if is_seq(var): 39 | return [func(v, *args, **kwargs) for v in var] 40 | elif is_dict(var): 41 | return {key: func(val, *args, **kwargs) for key, val in var.items()} 42 | else: 43 | return func(var, *args, **kwargs) 44 | return inner 45 | 46 | 47 | @iterate1 48 | def interpolate(tensor, size, scale_factor, mode): 49 | if size is None and scale_factor is None: 50 | return tensor 51 | if is_tensor(size): 52 | size = size.shape[-2:] 53 | return 
tfn.interpolate( 54 | tensor, size=size, scale_factor=scale_factor, 55 | recompute_scale_factor=False, mode=mode, 56 | align_corners=None, 57 | ) 58 | 59 | 60 | def resize_input( 61 | rgb: torch.Tensor, 62 | intrinsics: torch.Tensor = None, 63 | resize: tuple = None 64 | ): 65 | """Resizes input data 66 | 67 | Args: 68 | rgb (torch.Tensor): input image (B,3,H,W) 69 | intrinsics (torch.Tensor): camera intrinsics (B,3,3) 70 | resize (tuple, optional): resize shape. Defaults to None. 71 | 72 | Returns: 73 | rgb: resized image (B,3,h,w) 74 | intrinsics: resized intrinsics (B,3,3) 75 | """ 76 | 77 | # Don't resize if not requested 78 | if resize is None: 79 | if intrinsics is None: 80 | return rgb 81 | else: 82 | return rgb, intrinsics 83 | # Resize rgb 84 | orig_shape = [float(v) for v in rgb.shape[-2:]] 85 | rgb = interpolate(rgb, mode="bilinear", scale_factor=None, size=resize) 86 | # Return only rgb if there are no intrinsics 87 | if intrinsics is None: 88 | return rgb 89 | # Resize intrinsics 90 | shape = [float(v) for v in rgb.shape[-2:]] 91 | intrinsics = intrinsics.clone() 92 | intrinsics[:, 0] *= shape[1] / orig_shape[1] 93 | intrinsics[:, 1] *= shape[0] / orig_shape[0] 94 | # return resized input 95 | return rgb, intrinsics 96 | 97 | def format_image(rgb): 98 | return torch.tensor(rgb.transpose(2,0,1)[None]).to(torch.float32).cuda() / 255.0 99 | class StereoModel(torch.nn.Module): 100 | """Learned Stereo model. 101 | 102 | Takes as input two images plus intrinsics and outputs a metrically scaled depth map. 103 | 104 | Taken from: https://github.com/ToyotaResearchInstitute/mmt_stereo_inference 105 | Paper here: https://arxiv.org/pdf/2109.11644.pdf 106 | Authors: Krishna Shankar, Mark Tjersland, Jeremy Ma, Kevin Stone, Max Bajracharya 107 | 108 | Pre-trained checkpoint here: s3://tri-ml-models/efm/depth/stereo.pt 109 | 110 | Args: 111 | cfg (Config): configuration file to initialize the model 112 | ckpt (str, optional): checkpoint path to load a pre-trained model. Defaults to None. 113 | baseline (float): Camera baseline. Defaults to 0.12 (ZED baseline) 114 | """ 115 | 116 | def __init__(self, ckpt: str = None): 117 | super().__init__() 118 | # Initialize model 119 | self.model = torch.jit.load(ckpt).cuda() 120 | self.model.eval() 121 | 122 | def inference( 123 | self, 124 | baseline: float, 125 | rgb_left: torch.Tensor, 126 | rgb_right: torch.Tensor, 127 | intrinsics: torch.Tensor, 128 | resize: tuple = None, 129 | ): 130 | """Performs inference on input data 131 | 132 | Args: 133 | rgb_left (torch.Tensor): input float32 image (B,3,H,W) 134 | rgb_right (torch.Tensor): input float32 image (B,3,H,W) 135 | intrinsics (torch.Tensor): camera intrinsics (B,3,3) 136 | resize (tuple, optional): resize shape. Defaults to None. 
137 | 138 | Returns: 139 | depth: output depth map (B,1,H,W) 140 | """ 141 | 142 | rgb_left, intrinsics = resize_input( 143 | rgb=rgb_left, intrinsics=intrinsics, resize=resize 144 | ) 145 | rgb_right = resize_input(rgb=rgb_right, resize=resize) 146 | 147 | with torch.no_grad(): 148 | output, _ = self.model(rgb_left, rgb_right) 149 | 150 | disparity_sparse = output["disparity_sparse"] 151 | mask = disparity_sparse != 0 152 | depth = torch.zeros_like(disparity_sparse) 153 | depth[mask] = baseline * intrinsics[0, 0, 0] / disparity_sparse[mask] 154 | # depth = baseline * intrinsics[0, 0, 0] / output["disparity"] 155 | rgb = (rgb_left.squeeze(0).permute(1,2,0).cpu().detach().numpy()*255).astype(np.uint8) 156 | return depth, output["disparity"], disparity_sparse,rgb 157 | 158 | 159 | class Zed(): 160 | width: int 161 | """Width of the rgb/depth images.""" 162 | height: int 163 | """Height of the rgb/depth images.""" 164 | raft_lock: Lock 165 | """Lock for the camera, for raft-stereo depth!""" 166 | 167 | zed_mesh: trimesh.Trimesh 168 | """Trimesh of the ZED camera.""" 169 | cam_to_zed: RigidTransform 170 | """Transform from left camera to ZED camera base.""" 171 | 172 | def __init__(self, flip_mode, resolution, fps, cam_id=None, recording_file=None, start_time=0.0): 173 | init = sl.InitParameters() 174 | if cam_id is not None: 175 | init.set_from_serial_number(cam_id) 176 | self.cam_id = cam_id 177 | self.width = None 178 | self.debug_ = False 179 | self.height = None 180 | self.init_res = None 181 | # Set camera flip mode 182 | if flip_mode: 183 | init.camera_image_flip = sl.FLIP_MODE.ON 184 | else: 185 | init.camera_image_flip = sl.FLIP_MODE.OFF 186 | 187 | if recording_file is not None: 188 | init.set_from_svo_file(recording_file) 189 | 190 | # Configure camera resolution 191 | if resolution == '720p': 192 | init.camera_resolution = sl.RESOLUTION.HD720 193 | self.height = 720 194 | self.width = 1280 195 | self.init_res = 1280 196 | elif resolution == '1080p': 197 | init.camera_resolution = sl.RESOLUTION.HD1080 198 | self.height = 1080 199 | self.width = 1920 200 | self.init_res = 1920 201 | elif resolution == '2k': 202 | init.camera_resolution = sl.RESOLUTION.HD2k 203 | self.height = 1242 204 | self.width = 2208 205 | self.init_res = 2208 206 | else: 207 | print("Only 720p, 1080p, and 2k supported by Zed") 208 | exit() 209 | # Disable native ZED depth computation (we'll use RAFT-Stereo instead) 210 | init.depth_mode = sl.DEPTH_MODE.NONE 211 | init.sdk_verbose = 1 212 | init.camera_fps = fps 213 | self.cam = sl.Camera() 214 | init.camera_disable_self_calib = True 215 | status = self.cam.open(init) 216 | if recording_file is not None: 217 | fps = self.cam.get_camera_information().camera_configuration.fps 218 | self.cam.set_svo_position(int(start_time*fps)) 219 | if status != sl.ERROR_CODE.SUCCESS: #Ensure the camera has opened succesfully 220 | print("Camera Open : "+repr(status)+". 
Exit program.") 221 | exit() 222 | else: 223 | print("Opened camera") 224 | res = sl.Resolution() 225 | res.width = self.width 226 | res.height = self.height 227 | left_cx = self.get_K(cam="left")[0, 2] 228 | right_cx = self.get_K(cam="right")[0, 2] 229 | self.cx_diff = right_cx - left_cx # /1920 230 | self.f_ = self.get_K(cam="left")[0,0] 231 | self.cx_ = left_cx 232 | self.cy_ = self.get_K(cam="left")[1,2] 233 | self.baseline_ = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.get_translation().get()[0] / 1000.0 234 | self.intrinsics_ = torch.tensor([[ 235 | [self.f_,0,self.cx_], 236 | [0,self.f_,self.cy_], 237 | [0,0,1] 238 | ]]).to(torch.float32).cuda() 239 | # Create lock for raft -- gpu threading messes up CUDA memory state, with curobo... 240 | self.raft_lock = Lock() 241 | self.dir_path = pathlib.Path(__file__).parent.resolve() 242 | self.stereo_ckpt = os.path.join(self.dir_path,'models/stereo_20230724.pt') #We use stereo model from this paper: https://arxiv.org/abs/2109.11644. However, you can sub this in for any realtime stereo model (including the default Zed model). 243 | with self.raft_lock: 244 | self.model = StereoModel(self.stereo_ckpt) 245 | 246 | # left_cx = self.get_K(cam='left')[0,2] 247 | # right_cx = self.get_K(cam='right')[0,2] 248 | # self.cx_diff = (right_cx-left_cx) 249 | 250 | 251 | 252 | # For visualiation. 253 | 254 | zedM_path = Path(__file__).parent / Path("data/ZEDM.stl") 255 | zed2_path = Path(__file__).parent / Path("data/ZED2.stl") 256 | self.zedM_mesh = trimesh.load(str(zedM_path)) 257 | self.zed2_mesh = trimesh.load(str(zed2_path)) 258 | # assert isinstance(zed_mesh, trimesh.Trimesh) 259 | self.zed_mesh = self.zed2_mesh 260 | self.cam_to_zed = RigidTransform( 261 | rotation=RigidTransform.quaternion_from_axis_angle( 262 | np.array([1, 0, 0]) * (np.pi / 2) 263 | ), 264 | translation=np.array([0.06, 0.042, -0.035]), 265 | ) 266 | 267 | def prime_get_frame(self): 268 | res = sl.Resolution() 269 | res.width = self.width 270 | res.height = self.height 271 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS: 272 | left_rgb = sl.Mat() 273 | right_rgb = sl.Mat() 274 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT) 275 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT) 276 | self.height = self.height - self.height % 32 277 | self.width = self.width - self.width % 32 278 | left_cropped = np.flip(left_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy() 279 | right_cropped = np.flip(right_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy() 280 | with self.raft_lock: 281 | tridepth, disparity, disparity_sparse,cropped_rgb = self.model.inference(rgb_left=format_image(left_cropped),rgb_right=format_image(right_cropped),intrinsics=self.intrinsics_,baseline=self.baseline_) 282 | return left_cropped,right_cropped 283 | else: 284 | print("Couldn't grab frame") 285 | 286 | def get_frame( 287 | self, depth=True 288 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 289 | res = sl.Resolution() 290 | res.width = self.width 291 | res.height = self.height 292 | r = self.width/self.init_res 293 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS: 294 | left_rgb = sl.Mat() 295 | right_rgb = sl.Mat() 296 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res) 297 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res) 298 | self.height = self.height - self.height % 32 299 | self.width = self.width - self.width % 32 300 | left_cropped = 
np.flip(left_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy() 301 | right_cropped = np.flip(right_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy() 302 | 303 | with self.raft_lock: 304 | tridepth, disparity, disparity_sparse,cropped_rgb = self.model.inference(rgb_left=format_image(left_cropped),rgb_right=format_image(right_cropped),intrinsics=self.intrinsics_,baseline=self.baseline_) 305 | if(self.debug_): 306 | import pdb 307 | pdb.set_trace() 308 | plt.imshow(tridepth.detach().cpu().numpy()[0,0],cmap='jet') 309 | plt.savefig('/home/lifelong/prime_raft.png') 310 | import pdb 311 | pdb.set_trace() 312 | return torch.from_numpy(left_cropped).cuda(), torch.from_numpy(right_cropped).cuda(), tridepth[0,0] 313 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED: 314 | print("End of recording file") 315 | return None,None,None 316 | else: 317 | raise RuntimeError("Could not grab frame") 318 | 319 | def get_K(self,cam='left') -> np.ndarray: 320 | calib = self.cam.get_camera_information().camera_configuration.calibration_parameters 321 | if cam=='left': 322 | intrinsics = calib.left_cam 323 | else: 324 | intrinsics = calib.right_cam 325 | r = self.width/self.init_res 326 | K = np.array([[intrinsics.fx*r, 0, intrinsics.cx*r], 327 | [0, intrinsics.fy*r, intrinsics.cy*r], 328 | [0, 0, 1]]) 329 | return K 330 | 331 | def get_stereo_transform(self): 332 | transform = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m 333 | transform[:3,3] /= 1000#convert to meters 334 | return transform 335 | 336 | def start_record(self, out_path): 337 | recordingParameters = sl.RecordingParameters() 338 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264 339 | recordingParameters.video_filename = out_path 340 | err = self.cam.enable_recording(recordingParameters) 341 | 342 | def stop_record(self): 343 | self.cam.disable_recording() 344 | 345 | @staticmethod 346 | def plotly_render(frame) -> plotly.graph_objs.Figure: 347 | fig = px.imshow(frame) 348 | fig.update_layout( 349 | margin=dict(l=0, r=0, t=0, b=0), 350 | showlegend=False, 351 | yaxis_visible=False, 352 | yaxis_showticklabels=False, 353 | xaxis_visible=False, 354 | xaxis_showticklabels=False, 355 | ) 356 | return fig 357 | 358 | @staticmethod 359 | def project_depth( 360 | rgb: torch.Tensor, 361 | depth: torch.Tensor, 362 | K: torch.Tensor, 363 | depth_threshold: float = 1.0, 364 | subsample: int = 4, 365 | ) -> Tuple[np.ndarray, np.ndarray]: 366 | """Deproject RGBD image to point cloud, using provided intrinsics. 
367 | Also threshold/subsample pointcloud for visualization speed.""" 368 | 369 | img_wh = rgb.shape[:2][::-1] 370 | 371 | grid = ( 372 | torch.stack( 373 | torch.meshgrid( 374 | torch.arange(img_wh[0], device="cuda"), 375 | torch.arange(img_wh[1], device="cuda"), 376 | indexing="xy", 377 | ), 378 | 2, 379 | ) 380 | + 0.5 381 | ) 382 | 383 | homo_grid = torch.concat( 384 | [grid, torch.ones((grid.shape[0], grid.shape[1], 1), device="cuda")], 385 | dim=2 386 | ).reshape(-1, 3) 387 | local_dirs = torch.matmul(torch.linalg.inv(K),homo_grid.T).T 388 | points = (local_dirs * depth.reshape(-1,1)).float() 389 | points = points.reshape(-1,3) 390 | 391 | mask = depth.reshape(-1, 1) <= depth_threshold 392 | points = points.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy() 393 | colors = rgb.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy() 394 | 395 | return (points, colors) 396 | 397 | 398 | import tyro 399 | def main(name: str) -> None: 400 | # def main() -> None: 401 | 402 | import torch 403 | from viser import ViserServer 404 | # zed = Zed(recording_file="exps/eyeglasses/2024-06-06_014947/traj.svo2") 405 | #jzed = Zed(recording_file="test.svo2") 406 | # import pdb; pdb.set_trace() 407 | zed = Zed(recording_file = "exps/scissors/2024-06-06_155342/traj.svo2") 408 | # zed.start_record(f"/home/chungmin/Documents/please2/toad/motion_vids/{name}.svo2") 409 | import os 410 | # os.makedirs(out_dir,exist_ok=True) 411 | i = 0 412 | 413 | 414 | #code for visualizing poincloud 415 | import viser 416 | from matplotlib import pyplot as plt 417 | import viser.transforms as tf 418 | v = ViserServer() 419 | gui_reset_up = v.add_gui_button( 420 | "Reset up direction", 421 | hint="Set the camera control 'up' direction to the current camera's 'up'.", 422 | ) 423 | 424 | @gui_reset_up.on_click 425 | def _(event: viser.GuiEvent) -> None: 426 | client = event.client 427 | assert client is not None 428 | client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( 429 | [0.0, -1.0, 0.0] 430 | ) 431 | while True: 432 | left,right,depth = zed.get_frame() 433 | left = left.cpu().numpy() 434 | depth = depth.cpu().numpy() 435 | # import matplotlib.pyplot as plt 436 | # plt.imshow(left) 437 | # plt.show() 438 | K = zed.get_K() 439 | T_world_camera = np.eye(4) 440 | 441 | img_wh = left.shape[:2][::-1] 442 | 443 | grid = ( 444 | np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5 445 | ) 446 | 447 | homo_grid = np.concatenate([grid,np.ones((grid.shape[0],grid.shape[1],1))],axis=2).reshape(-1,3) 448 | local_dirs = np.matmul(np.linalg.inv(K),homo_grid.T).T 449 | points = (local_dirs * depth.reshape(-1,1)).astype(np.float32) 450 | points = points.reshape(-1,3) 451 | v.add_point_cloud("points", points = points.reshape(-1,3), colors=left.reshape(-1,3),point_size=.001) 452 | 453 | if __name__ == "__main__": 454 | tyro.cli(main) -------------------------------------------------------------------------------- /pogs/tracking/utils.py: -------------------------------------------------------------------------------- 1 | import warp as wp 2 | import torch 3 | from nerfstudio.cameras.cameras import Cameras 4 | from pogs.tracking.transforms import SE3, SO3 5 | 6 | def extrapolate_poses(p1_7v, p2_7v, lam, rot_scale_factor=3.0, thresh=True): 7 | ext_7v = [] 8 | for i in range(len(p2_7v)): 9 | r1 = SO3(p1_7v[i,3:]) 10 | t1 = SE3.from_rotation_and_translation(r1, p1_7v[i,:3]) 11 | r2 = SO3(p2_7v[i,3:]) 12 | t2 = SE3.from_rotation_and_translation(r2, p2_7v[i,:3]) 13 | t_2_1 = t1.inverse() @ t2 14 | 
delta_pos = t_2_1.translation()*lam 15 | delta_rot = SO3.exp((t_2_1.rotation().log() * lam * rot_scale_factor)) 16 | if thresh and delta_pos.norm().item() < 0.05: # Threshold for small deltas to avoid oscillations 17 | new_t = t2 18 | else: 19 | new_t = (t2 @ SE3.from_rotation_and_translation(delta_rot, delta_pos)) 20 | ext_7v.append(new_t.wxyz_xyz.roll(3,dims=-1)) 21 | return torch.stack(ext_7v) 22 | 23 | def zero_optim_state(optimizer:torch.optim.Adam): 24 | # import pdb; pdb.set_trace() 25 | param = optimizer.param_groups[0]["params"][0] 26 | param_state = optimizer.state[param] 27 | if "max_exp_avg_sq" in param_state: 28 | # for amsgrad 29 | param_state["max_exp_avg_sq"] = torch.zeros(param_state["max_exp_avg_sq"].shape, device=param_state["max_exp_avg_sq"].device) 30 | if "exp_avg" in param_state: 31 | param_state["exp_avg"] = torch.zeros(param_state["exp_avg"].shape, device=param_state["exp_avg"].device) 32 | param_state["exp_avg_sq"] = torch.zeros(param_state["exp_avg_sq"].shape, device=param_state["exp_avg_sq"].device) 33 | 34 | def replace_in_optim(optimizer:torch.optim.Adam, new_params): 35 | """replaces the parameters in the optimizer""" 36 | param = optimizer.param_groups[0]["params"][0] 37 | param_state = optimizer.state[param] 38 | 39 | del optimizer.state[param] 40 | optimizer.state[new_params[0]] = param_state 41 | optimizer.param_groups[0]["params"] = new_params 42 | del param 43 | 44 | @wp.func 45 | def poses_7vec_to_transform(poses: wp.array(dtype=float, ndim=2), i: int): 46 | """ 47 | Kernel helper for converting x y z qw qx qy qz to a wp.Transformation 48 | """ 49 | position = wp.vector(poses[i,0], poses[i,1], poses[i,2]) 50 | quaternion = wp.quaternion(poses[i,4], poses[i,5], poses[i,6], poses[i,3]) 51 | return wp.transformation(position, quaternion) 52 | 53 | @wp.kernel 54 | def apply_to_model( 55 | # init_o2w: wp.array(dtype=float, ndim=2), 56 | init_p2ws: wp.array(dtype=float, ndim=2), 57 | # o_delta: wp.array(dtype=float, ndim=2), 58 | p_deltas: wp.array(dtype=float, ndim=2), 59 | group_labels: wp.array(dtype=int), 60 | means: wp.array(dtype=wp.vec3), 61 | quats: wp.array(dtype=float, ndim=2), 62 | #outputs 63 | means_out: wp.array(dtype=wp.vec3), 64 | quats_out: wp.array(dtype=float, ndim=2), 65 | ): 66 | """ 67 | Kernel for applying the transforms to a gaussian splat 68 | 69 | [removed] init_o2w: 1x7 tensor of initial object to world poses 70 | init_p2ws: Nx7 tensor of initial pose to object poses 71 | [removed] o_delta: Nx7 tensor of object pose deltas represented as objnew_to_objoriginal 72 | p_deltas: Nx7 tensor of pose deltas represented as partnew_to_partoriginal 73 | group_labels: N, tensor of group labels (0->K-1) for K groups 74 | means: Nx3 tensor of means 75 | quats: Nx4 tensor of quaternions (wxyz) 76 | means_out: Nx3 tensor of output means 77 | quats_out: Nx4 tensor of output quaternions (wxyz) 78 | """ 79 | tid = wp.tid() 80 | group_id = group_labels[tid] 81 | # o2w_T = poses_7vec_to_transform(init_o2w,0) 82 | p2w_T = poses_7vec_to_transform(init_p2ws,group_id) 83 | # odelta_T = poses_7vec_to_transform(o_delta,0) 84 | pdelta_T = poses_7vec_to_transform(p_deltas,group_id) 85 | g2w_T = wp.transformation(means[tid], wp.quaternion(quats[tid, 1], quats[tid, 2], quats[tid, 3], quats[tid, 0])) 86 | g2p_T = wp.transform_inverse(p2w_T) * g2w_T 87 | new_g2w_T = p2w_T * pdelta_T * g2p_T 88 | means_out[tid] = wp.transform_get_translation(new_g2w_T) 89 | new_quat = wp.transform_get_rotation(new_g2w_T) 90 | quats_out[tid, 0] = new_quat[3] #w 91 | quats_out[tid, 
1] = new_quat[0] #x 92 | quats_out[tid, 2] = new_quat[1] #y 93 | quats_out[tid, 3] = new_quat[2] #z 94 | 95 | def identity_7vec(device='cuda'): 96 | """ 97 | Returns a 7-tensor of identity pose 98 | """ 99 | return torch.tensor([[0, 0, 0, 1, 0, 0, 0]], dtype=torch.float32, device=device) 100 | 101 | def normalized_quat_to_rotmat(quat): 102 | """ 103 | Converts a quaternion to a 3x3 rotation matrix 104 | """ 105 | assert quat.shape[-1] == 4, quat.shape 106 | w, x, y, z = torch.unbind(quat, dim=-1) 107 | mat = torch.stack( 108 | [ 109 | 1 - 2 * (y**2 + z**2), 110 | 2 * (x * y - w * z), 111 | 2 * (x * z + w * y), 112 | 2 * (x * y + w * z), 113 | 1 - 2 * (x**2 + z**2), 114 | 2 * (y * z - w * x), 115 | 2 * (x * z - w * y), 116 | 2 * (y * z + w * x), 117 | 1 - 2 * (x**2 + y**2), 118 | ], 119 | dim=-1, 120 | ) 121 | return mat.reshape(quat.shape[:-1] + (3, 3)) 122 | 123 | def torch_posevec_to_mat(posevecs): 124 | """ 125 | Converts a Nx7-tensor to Nx4x4 matrix 126 | 127 | posevecs: Nx7 tensor of pose vectors 128 | returns: Nx4x4 tensor of transformation matrices 129 | """ 130 | assert posevecs.shape[-1] == 7, posevecs.shape 131 | assert len(posevecs.shape) == 2, posevecs.shape 132 | out = torch.eye(4, device=posevecs.device).unsqueeze(0).expand(posevecs.shape[0], -1, -1).clone() # clone(): the in-place writes below need real per-row memory, not an expanded view 133 | out[:, :3, 3] = posevecs[:, :3] 134 | out[:, :3, :3] = normalized_quat_to_rotmat(posevecs[:, 3:]) 135 | return out 136 | 137 | def mnn_matcher(feat_a, feat_b): 138 | """ 139 | Returns mutual nearest neighbors between two sets of features 140 | 141 | feat_a: NxD 142 | feat_b: MxD 143 | return: K, K (indices in feat_a and feat_b) 144 | """ 145 | device = feat_a.device 146 | sim = feat_a.mm(feat_b.t()) 147 | nn12 = torch.max(sim, dim=1)[1] 148 | nn21 = torch.max(sim, dim=0)[1] 149 | ids1 = torch.arange(0, sim.shape[0], device=device) 150 | mask = ids1 == nn21[nn12] 151 | return ids1[mask], nn12[mask] 152 | 153 | def crop_camera(camera: Cameras, xmin, xmax, ymin, ymax): 154 | height = torch.tensor(ymax - ymin,device='cuda').view(1,1).int() 155 | width = torch.tensor(xmax - xmin,device='cuda').view(1,1).int() 156 | cx = torch.tensor(camera.cx.clone() - xmin,device='cuda').view(1,1) 157 | cy = torch.tensor(camera.cy.clone() - ymin,device='cuda').view(1,1) 158 | fx = camera.fx.clone() 159 | fy = camera.fy.clone() 160 | return Cameras(camera.camera_to_worlds.clone(), fx, fy, cx, cy, width, height) -------------------------------------------------------------------------------- /pogs/tracking/utils2.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation 6 | import moviepy as mpy 7 | import wandb 8 | from pogs.tracking.observation import PosedObservation, Frame 9 | from torchvision.transforms.functional import to_pil_image 10 | from pogs.pogs import POGSModel 11 | import time 12 | import cv2 13 | 14 | 15 | def generate_videos(frames_dict, fps=30, config_path=None): 16 | import datetime 17 | timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 18 | for key in frames_dict.keys(): 19 | frames = frames_dict[key] 20 | if len(frames)>1: 21 | if frames[0].max() > 1: 22 | frames = [f for f in frames] 23 | else: 24 | frames = [f*255 for f in frames] 25 | clip = mpy.ImageSequenceClip(frames, fps=fps) 26 | if config_path is None: 27 | clip.write_videofile(f"{timestr}/{key}.mp4", codec="libx264") 28 | else: 29 | path =
config_path.joinpath(f"{timestr}") 30 | if not path.exists(): 31 | path.mkdir(parents=True) 32 | clip.write_videofile(str(path.joinpath(f"{key}.mp4")), codec="libx264") 33 | try: 34 | wandb.log({f"{key}": wandb.Video(str(path.joinpath(f"{key}.mp4")))}) 35 | except: 36 | pass 37 | return timestr 38 | 39 | 40 | def overlay(image, mask, color, alpha, resize=None): 41 | """Combines image and its segmentation mask into a single image. 42 | https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay 43 | 44 | Params: 45 | image: Training image. np.ndarray, 46 | mask: Segmentation mask. np.ndarray, 47 | color: Color for segmentation mask rendering. tuple[int, int, int] = (255, 0, 0) 48 | alpha: Segmentation mask's transparency. float = 0.5, 49 | resize: If provided, both image and its mask are resized before blending them together. 50 | tuple[int, int] = (1024, 1024)) 51 | 52 | Returns: 53 | image_combined: The combined image. np.ndarray 54 | 55 | """ 56 | color = color[::-1] 57 | colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) 58 | colored_mask = np.moveaxis(colored_mask, 0, -1) 59 | masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color) 60 | image_overlay = masked.filled() 61 | 62 | if resize is not None: 63 | image = cv2.resize(image.transpose(1, 2, 0), resize) 64 | image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize) 65 | 66 | image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0) 67 | 68 | return image_combined -------------------------------------------------------------------------------- /pogs/tracking/zed.py: -------------------------------------------------------------------------------- 1 | import pyzed.sl as sl 2 | from typing import Optional, Tuple 3 | import torch 4 | import numpy as np 5 | from threading import Lock 6 | import plotly 7 | from plotly import express as px 8 | import trimesh 9 | from pathlib import Path 10 | # from raftstereo.raft_stereo import * 11 | from autolab_core import RigidTransform 12 | from matplotlib import pyplot as plt 13 | class Zed(): 14 | width: int 15 | """Width of the rgb/depth images.""" 16 | height: int 17 | """Height of the rgb/depth images.""" 18 | raft_lock: Lock 19 | """Lock for the camera, for raft-stereo depth!""" 20 | 21 | zed_mesh: trimesh.Trimesh 22 | """Trimesh of the ZED camera.""" 23 | cam_to_zed: RigidTransform 24 | """Transform from left camera to ZED camera base.""" 25 | 26 | def __init__(self, recording_file = None, start_time = 0.0): 27 | init = sl.InitParameters() 28 | if recording_file is not None: 29 | init.set_from_svo_file(recording_file) 30 | # disable depth 31 | init.camera_image_flip = sl.FLIP_MODE.ON 32 | init.depth_mode=sl.DEPTH_MODE.NONE 33 | init.camera_resolution = sl.RESOLUTION.HD1080 34 | init.sdk_verbose = 1 35 | init.camera_fps = 30 36 | else: 37 | init.camera_resolution = sl.RESOLUTION.HD720 38 | init.sdk_verbose = 1 39 | init.camera_fps = 30 40 | # flip camera 41 | # init.camera_image_flip = sl.FLIP_MODE.ON 42 | init.depth_mode=sl.DEPTH_MODE.NONE 43 | init.depth_minimum_distance = 100#millimeters 44 | self.init_res = 1920 if init.camera_resolution == sl.RESOLUTION.HD1080 else 1280 45 | print("INIT RES",self.init_res) 46 | self.debug_ = False 47 | self.width = 1280 48 | self.height = 720 49 | self.cam = sl.Camera() 50 | status = self.cam.open(init) 51 | if recording_file is not None: 52 | fps = self.cam.get_camera_information().camera_configuration.fps 53 | self.cam.set_svo_position(int(start_time*fps)) 54 | if status != 
sl.ERROR_CODE.SUCCESS: # Ensure the camera has opened successfully 55 | print("Camera Open : "+repr(status)+". Exit program.") 56 | exit() 57 | else: 58 | print("Opened camera") 59 | 60 | left_cx = self.get_K(cam="left")[0, 2] 61 | right_cx = self.get_K(cam="right")[0, 2] 62 | self.cx_diff = right_cx - left_cx # /1920 63 | self.f_ = self.get_K(cam="left")[0,0] 64 | self.cx_ = left_cx 65 | self.cy_ = self.get_K(cam="left")[1,2] 66 | 67 | # Create lock for raft -- gpu threading messes up CUDA memory state with curobo... 68 | self.raft_lock = Lock() 69 | with self.raft_lock: 70 | self.model = create_raft() 71 | 72 | # left_cx = self.get_K(cam='left')[0,2] 73 | # right_cx = self.get_K(cam='right')[0,2] 74 | # self.cx_diff = (right_cx-left_cx) 75 | 76 | # For visualization. 77 | zed_path = Path(__file__).parent / Path("data/ZEDM.stl") 78 | zed_mesh = trimesh.load(str(zed_path)) 79 | assert isinstance(zed_mesh, trimesh.Trimesh) 80 | self.zed_mesh = zed_mesh 81 | self.cam_to_zed = RigidTransform( 82 | rotation=RigidTransform.quaternion_from_axis_angle( 83 | np.array([1, 0, 0]) * (np.pi / 2) 84 | ), 85 | translation=np.array([0.06, 0.042, -0.035]), # Numbers are for ZED 2 86 | ) 87 | 88 | def get_frame( 89 | self, depth=True 90 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 91 | res = sl.Resolution() 92 | res.width = self.width 93 | res.height = self.height 94 | r = self.width/self.init_res 95 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS: 96 | left_rgb = sl.Mat() 97 | right_rgb = sl.Mat() 98 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res) 99 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res) 100 | left,right = torch.from_numpy(np.flip(left_rgb.get_data()[...,:3],axis=2).copy()).cuda(), torch.from_numpy(np.flip(right_rgb.get_data()[...,:3],axis=2).copy()).cuda() 101 | if depth: 102 | left_torch,right_torch = left.permute(2,0,1),right.permute(2,0,1) 103 | with self.raft_lock: 104 | flow = raft_inference(left_torch,right_torch,self.model) 105 | fx = self.get_K()[0,0] 106 | depth = fx*self.get_stereo_transform()[0,3]/(flow.abs()+self.cx_diff) 107 | if(self.debug_): 108 | plt.imshow(depth.detach().cpu().numpy(),cmap='jet') 109 | plt.savefig('/home/lifelong/justin_raft.png') 110 | import pdb 111 | pdb.set_trace() 112 | else: 113 | depth = None 114 | return left, right, depth 115 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED: 116 | print("End of recording file") 117 | return None,None,None 118 | else: 119 | raise RuntimeError("Could not grab frame") 120 | 121 | def get_K(self,cam='left') -> np.ndarray: 122 | calib = self.cam.get_camera_information().camera_configuration.calibration_parameters 123 | if cam=='left': 124 | intrinsics = calib.left_cam 125 | else: 126 | intrinsics = calib.right_cam 127 | r = self.width/self.init_res 128 | K = np.array([[intrinsics.fx*r, 0, intrinsics.cx*r], [0, intrinsics.fy*r, intrinsics.cy*r], [0, 0, 1]]) 129 | return K 130 | 131 | def get_stereo_transform(self): 132 | transform = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m 133 | transform[:3,3] /= 1000 # convert to meters 134 | return transform 135 | 136 | def start_record(self, out_path): 137 | recordingParameters = sl.RecordingParameters() 138 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264 139 | recordingParameters.video_filename = out_path 140 | err = self.cam.enable_recording(recordingParameters) 141 | 142 | def stop_record(self): 143 |
self.cam.disable_recording() 144 | 145 | @staticmethod 146 | def plotly_render(frame) -> plotly.graph_objs.Figure: 147 | fig = px.imshow(frame) 148 | fig.update_layout( 149 | margin=dict(l=0, r=0, t=0, b=0), 150 | showlegend=False, 151 | yaxis_visible=False, 152 | yaxis_showticklabels=False, 153 | xaxis_visible=False, 154 | xaxis_showticklabels=False, 155 | ) 156 | return fig 157 | 158 | @staticmethod 159 | def project_depth( 160 | rgb: torch.Tensor, 161 | depth: torch.Tensor, 162 | K: torch.Tensor, 163 | depth_threshold: float = 1.0, 164 | subsample: int = 4, 165 | ) -> Tuple[np.ndarray, np.ndarray]: 166 | """Deproject RGBD image to point cloud, using provided intrinsics. 167 | Also threshold/subsample pointcloud for visualization speed.""" 168 | 169 | img_wh = rgb.shape[:2][::-1] 170 | 171 | grid = ( 172 | torch.stack( 173 | torch.meshgrid( 174 | torch.arange(img_wh[0], device="cuda"), 175 | torch.arange(img_wh[1], device="cuda"), 176 | indexing="xy", 177 | ), 178 | 2, 179 | ) 180 | + 0.5 181 | ) 182 | 183 | homo_grid = torch.concat( 184 | [grid, torch.ones((grid.shape[0], grid.shape[1], 1), device="cuda")], 185 | dim=2 186 | ).reshape(-1, 3) 187 | local_dirs = torch.matmul(torch.linalg.inv(K),homo_grid.T).T 188 | points = (local_dirs * depth.reshape(-1,1)).float() 189 | points = points.reshape(-1,3) 190 | 191 | mask = depth.reshape(-1, 1) <= depth_threshold 192 | points = points.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy() 193 | colors = rgb.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy() 194 | 195 | return (points, colors) 196 | 197 | 198 | import tyro 199 | def main(name: str) -> None: 200 | # def main() -> None: 201 | 202 | import torch 203 | from viser import ViserServer 204 | # zed = Zed(recording_file="exps/eyeglasses/2024-06-06_014947/traj.svo2") 205 | #jzed = Zed(recording_file="test.svo2") 206 | # import pdb; pdb.set_trace() 207 | zed = Zed(recording_file = "exps/scissors/2024-06-06_155342/traj.svo2") 208 | # zed.start_record(f"/home/chungmin/Documents/please2/toad/motion_vids/{name}.svo2") 209 | import os 210 | # os.makedirs(out_dir,exist_ok=True) 211 | i = 0 212 | 213 | 214 | #code for visualizing poincloud 215 | import viser 216 | from matplotlib import pyplot as plt 217 | import viser.transforms as tf 218 | v = ViserServer() 219 | gui_reset_up = v.add_gui_button( 220 | "Reset up direction", 221 | hint="Set the camera control 'up' direction to the current camera's 'up'.", 222 | ) 223 | 224 | @gui_reset_up.on_click 225 | def _(event: viser.GuiEvent) -> None: 226 | client = event.client 227 | assert client is not None 228 | client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( 229 | [0.0, -1.0, 0.0] 230 | ) 231 | while True: 232 | left,right,depth = zed.get_frame() 233 | left = left.cpu().numpy() 234 | depth = depth.cpu().numpy() 235 | # import matplotlib.pyplot as plt 236 | # plt.imshow(left) 237 | # plt.show() 238 | K = zed.get_K() 239 | T_world_camera = np.eye(4) 240 | 241 | img_wh = left.shape[:2][::-1] 242 | 243 | grid = ( 244 | np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5 245 | ) 246 | 247 | homo_grid = np.concatenate([grid,np.ones((grid.shape[0],grid.shape[1],1))],axis=2).reshape(-1,3) 248 | local_dirs = np.matmul(np.linalg.inv(K),homo_grid.T).T 249 | points = (local_dirs * depth.reshape(-1,1)).astype(np.float32) 250 | points = points.reshape(-1,3) 251 | v.add_point_cloud("points", points = points.reshape(-1,3), colors=left.reshape(-1,3),point_size=.001) 252 | 253 | if __name__ == "__main__": 254 | 
tyro.cli(main) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pogs" 3 | version = "0.1.1" 4 | requires-python = ">=3.10" 5 | 6 | dependencies=[ 7 | "Pillow", 8 | "jaxtyping", 9 | "rich", 10 | "open-clip-torch", 11 | "numpy==1.26.4", 12 | "torchtyping", 13 | "autolab_core", 14 | "moviepy", 15 | "kornia", 16 | "iopath", 17 | "transformers==4.44.0", 18 | "typeguard>=4.0.0", 19 | "awscli" 20 | ] 21 | [tool.setuptools] 22 | include-package-data = true 23 | 24 | [tool.setuptools.packages.find] 25 | include = ["pogs*"] 26 | 27 | [project.entry-points.'nerfstudio.method_configs'] 28 | pogs = 'pogs.pogs_config:pogs_method' -------------------------------------------------------------------------------- /scripts/dino_pca_visualization.py: -------------------------------------------------------------------------------- 1 | """ Usage example: 2 | python dino_pca_visualization.py --image_path shelf_iron.png 3 | """ 4 | 5 | import PIL.Image as Image 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from sklearn.decomposition import PCA 9 | from pogs.data.utils.dino_dataloader import DinoDataloader, get_img_resolution 10 | import tyro 11 | import rich 12 | import torch 13 | from pogs.data.utils.dino_extractor import ViTExtractor 14 | from pogs.data.utils.feature_dataloader import FeatureDataloader 15 | from tqdm import tqdm 16 | from torchvision import transforms 17 | from typing import Tuple 18 | 19 | def main( 20 | image_path: str = "shelf_iron.png", 21 | dino_model_type: str = "dinov2_vitl14", 22 | dino_stride: int = 14, 23 | device: str = "cuda", 24 | keep_cuda: bool = True, 25 | ): 26 | extractor = ViTExtractor(dino_model_type, dino_stride) 27 | image = Image.open(image_path) 28 | image = np.array(image) 29 | if image.dtype == np.uint8: 30 | image = np.array(image) / 255.0 31 | 32 | if image.dtype == np.float64: 33 | image = image.astype(np.float32) 34 | h, w = get_img_resolution(image.shape[0], image.shape[1]) 35 | preprocess = transforms.Compose([ 36 | transforms.Resize((h,w),antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), 37 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 38 | ]) 39 | 40 | preproc_image = preprocess(torch.from_numpy(image).permute(2,0,1).unsqueeze(0)).to(device) 41 | dino_embeds = [] 42 | for image in tqdm(preproc_image, desc="dino", total=1, leave=False): 43 | with torch.no_grad(): 44 | descriptors = extractor.model.get_intermediate_layers(image.unsqueeze(0),reshape=True)[0].squeeze().permute(1,2,0)/10 45 | if keep_cuda: 46 | dino_embeds.append(descriptors) 47 | else: 48 | dino_embeds.append(descriptors.cpu().detach()) 49 | out = dino_embeds[0] 50 | 51 | patch_h = out.shape[0] 52 | patch_w = out.shape[1] 53 | total_features = out.squeeze(0).squeeze(0).reshape(-1, out.shape[-1]) 54 | pca = PCA(n_components=3) 55 | pca.fit(total_features.cpu().numpy()) 56 | pca_features = pca.transform(total_features.cpu().numpy()) 57 | 58 | # visualize PCA components for finding a proper threshold 59 | # 3 histograms for 3 components 60 | plt.subplot(2, 2, 1) 61 | plt.hist(pca_features[:, 0]) 62 | plt.subplot(2, 2, 2) 63 | plt.hist(pca_features[:, 1]) 64 | plt.subplot(2, 2, 3) 65 | plt.hist(pca_features[:, 2]) 66 | plt.savefig(f"{image_path}_pca_hist.png") 67 | 68 | plt.clf() 69 | 70 | # Visualize PCA components 71 | pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 
0].min()) / \ 72 | (pca_features[:, 0].max() - pca_features[:, 0].min()) 73 | 74 | plt.imshow(pca_features[0 : patch_h*patch_w, 0].reshape(patch_h, patch_w), cmap='gist_rainbow') 75 | plt.axis('off') 76 | 77 | plt.savefig(f"{image_path}_pca.png", bbox_inches='tight', pad_inches = 0) 78 | plt.margins(0,0) 79 | plt.show() 80 | 81 | if __name__ == "__main__": 82 | tyro.cli(main) -------------------------------------------------------------------------------- /scripts/shelf_iron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png -------------------------------------------------------------------------------- /scripts/shelf_iron.png_pca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png_pca.png -------------------------------------------------------------------------------- /scripts/shelf_iron.png_pca_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png_pca_hist.png --------------------------------------------------------------------------------
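A minimal usage sketch for the `mnn_matcher` helper defined in `pogs/tracking/utils.py` above. The random, unit-normalized descriptors here stand in for real per-patch features (for example, the DINO embeddings produced by `scripts/dino_pca_visualization.py`); the only assumption is that the `pogs` package and its dependencies are importable.

```python
# Illustrative sketch only: random descriptors stand in for real DINO features.
import torch
from pogs.tracking.utils import mnn_matcher  # assumes pogs (and its deps) are installed

feat_a = torch.nn.functional.normalize(torch.randn(128, 64), dim=-1)  # N x D descriptors
feat_b = torch.nn.functional.normalize(torch.randn(96, 64), dim=-1)   # M x D descriptors

ids_a, ids_b = mnn_matcher(feat_a, feat_b)
# ids_a[k] (into feat_a) and ids_b[k] (into feat_b) are mutual nearest neighbors
# under the dot-product similarity computed inside mnn_matcher.
print(ids_a.shape, ids_b.shape)
```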