├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── media
│   ├── POGS_servoing.gif
│   └── POGS_teaser.gif
├── pogs
│   ├── calibration_outputs
│   │   └── world_to_extrinsic_zed_for_grasping_down.tf
│   ├── camera
│   │   ├── capture_utils.py
│   │   └── zed_stereo.py
│   ├── configs
│   │   ├── base_config.py
│   │   ├── camera_config.yaml
│   │   └── conda_compile_pointnet_tfops.sh
│   ├── contact_graspnet_wrapper
│   │   ├── prime_config_utils.py
│   │   ├── prime_inference.py
│   │   └── prime_visualization_utils.py
│   ├── data
│   │   ├── depth_dataset.py
│   │   ├── full_images_datamanager.py
│   │   └── utils
│   │       ├── detic_dataloader.py
│   │       ├── dino_dataloader.py
│   │       ├── dino_extractor.py
│   │       ├── feature_dataloader.py
│   │       ├── patch_embedding_dataloader.py
│   │       └── pyramid_embedding_dataloader.py
│   ├── encoders
│   │   ├── image_encoder.py
│   │   └── openclip_encoder.py
│   ├── field_components
│   │   └── gaussian_fieldheadnames.py
│   ├── fields
│   │   └── gaussian_field.py
│   ├── grasping
│   │   ├── generate_grasps_ply.py
│   │   └── results
│   │       └── predictions_global.ply.npz
│   ├── model_components
│   │   └── losses.py
│   ├── pogs.py
│   ├── pogs_config.py
│   ├── pogs_pipeline.py
│   ├── scripts
│   │   ├── calibrate_cameras.py
│   │   ├── scene_capture.py
│   │   ├── track_main_demo.py
│   │   └── track_main_online_demo.py
│   └── tracking
│       ├── atap_loss.py
│       ├── data
│       │   ├── ZED2.stl
│       │   └── ZEDM.stl
│       ├── motion.py
│       ├── observation.py
│       ├── optim.py
│       ├── rigid_group_optimizer.py
│       ├── toad_object.py
│       ├── transforms
│       │   ├── __init__.py
│       │   ├── _base.py
│       │   ├── _se3.py
│       │   ├── _so3.py
│       │   └── utils
│       │       ├── __init__.py
│       │       └── _utils.py
│       ├── tri_zed.py
│       ├── utils.py
│       ├── utils2.py
│       └── zed.py
├── pyproject.toml
└── scripts
    ├── dino_pca_visualization.py
    ├── shelf_iron.png
    ├── shelf_iron.png_pca.png
    └── shelf_iron.png_pca_hist.png
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ckpt
2 |
3 | *.egg-info
4 |
5 | *.ipynb
6 |
7 | *.log
8 |
9 | *.model
10 |
11 | *.pth
12 |
13 | *.pyc
14 |
15 | *.npy
16 |
17 | *.ply
18 |
19 | __pycache__/
20 |
21 | pogs/sample_data/*
22 |
23 | pogs/configs/__pycache__/*
24 | pogs/data/__pycache__/*
25 | pogs/data/utils/__pycache__/*
26 | pogs/encoders/__pycache__/*
27 |
28 | pogs/scripts/outputs/*
29 |
30 | pogs/ur5_interface/ur5_interface/data/*
31 | pogs/ur5_interface/ur5_interface/scripts/outputs/*
32 | outputs/*
33 |
34 | pogs/tracking/__pycache__/*
35 | pogs/ur5_interface/ur5_interface/outputs/*
36 | pogs/ur5_interface/ur5_interface/scripts/results/*
37 |
38 | data/*
39 | media/POGS.mp4
40 | outputs/*
41 | pogs/tracking/models/
42 | *.mp4
43 | *.ipynb
44 | .vscode/
45 | pogs/data/utils/datasets/
46 | pogs/calibration_outputs/world_to_extrinsic_zed.tf
47 | pogs/calibration_outputs/wrist_to_zed_mini.tf
48 | pogs/pogs/grasping/results
49 | pogs/results
50 | pogs/robot_description
51 | pogs/scripts/goto_click_extrinsic_cam.py
52 | pogs/scripts/goto_click_wrist_cam.py
53 | pogs/scripts/test_grasp_visualization.py
54 | pogs/scripts/results
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pogs/dependencies/detectron2"]
2 | path = pogs/dependencies/detectron2
3 | url = https://github.com/facebookresearch/detectron2.git
4 | [submodule "pogs/dependencies/Detic"]
5 | path = pogs/dependencies/Detic
6 | url = https://github.com/facebookresearch/Detic.git
7 | [submodule "pogs/dependencies/ur5py"]
8 | path = pogs/dependencies/ur5py
9 | url = https://github.com/kushtimusPrime/ur5py
10 | [submodule "pogs/dependencies/raftstereo"]
11 | path = pogs/dependencies/raftstereo
12 | url = https://github.com/kushtimusPrime/RAFT-Stereo
13 | [submodule "pogs/dependencies/contact_graspnet"]
14 | path = pogs/dependencies/contact_graspnet
15 | url = https://github.com/NVlabs/contact_graspnet.git
16 | [submodule "pogs/dependencies/nerfstudio"]
17 | path = pogs/dependencies/nerfstudio
18 | url = https://github.com/nerfstudio-project/nerfstudio.git
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Justin Yu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Persistent Object Gaussian Splat (POGS) for Tracking Human and Robot Manipulation of Irregularly Shaped Objects
2 |
3 |
4 |
5 | [[Website]](https://berkeleyautomation.github.io/POGS/)
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | This repository contains the official implementation for [POGS](https://berkeleyautomation.github.io/POGS/).
20 |
21 | Tested on Python 3.10 and CUDA 11.8, using conda.
22 |
23 | ## Installation
24 | 1. Create a conda environment and install the relevant packages
25 | ```
26 | conda create --name pogs_env -y python=3.10
27 | conda activate pogs_env
28 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
29 |
30 | pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu118
31 |
32 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
33 | pip install jaxtyping rich
34 | pip install gsplat --index-url https://docs.gsplat.studio/whl/pt20cu118
35 | pip install warp-lang
36 | ```
37 |
38 | 2. [`cuml`](https://docs.rapids.ai/install) is required (for global clustering).
39 | The best way to install it is with pip: `pip install --extra-index-url=https://pypi.nvidia.com cudf-cu11==25.4.* cuml-cu11==25.4.*`
40 |
41 | 3. Install POGS!
42 | ```
43 | git clone https://github.com/uynitsuj/POGS.git --recurse-submodules
44 | cd POGS
45 | python -m pip install -e .
46 | python -m pip install pogs/dependencies/nerfstudio/
47 | pip install fast_simplification==0.1.9
48 | pip install numpy==1.26.4
49 | ns-install-cli
50 | ```
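As a quick optional check, the `pogs` method should now be registered with the Nerfstudio CLI:
```
ns-train pogs --help
```
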
51 | ### Robot Interaction Code Installation (UR5 Specific)
52 |
53 | There is also a physical robot component using the UR5 and ZED cameras. To install the relevant libraries:
54 | #### ur5py
55 | ```
56 | pip install ur_rtde==1.4.2 cowsay opt-einsum pyvista autolab-core
57 | pip install -e ~/POGS/pogs/dependencies/ur5py
58 | ```
59 |
60 | #### RAFT-Stereo
61 | ```
62 | cd ~/POGS/pogs/dependencies/raftstereo
63 | bash download_models.sh
64 | pip install -e .
65 | ```
66 |
67 | #### Contact-Graspnet
68 | Contact-GraspNet relies on some older libraries, so we couldn't merge everything into one conda environment. Instead, we create a separate conda environment for it and call it in a subprocess (a minimal sketch of that call follows the install commands below).
69 | ```
70 | conda deactivate
71 | conda create --name contact_graspnet_env python=3.8
72 | conda activate contact_graspnet_env
73 | conda install -c conda-forge cudatoolkit=11.2
74 | conda install -c conda-forge cudnn=8.2
75 | # If you don't have CUDA installed at /usr/local/cuda, you can install it in your conda env and then run these two lines
76 | conda install -c conda-forge cudatoolkit-dev
77 | export CUDA_HOME=/path/to/anaconda/envs/contact_graspnet_env/bin/nvcc
78 | pip install tensorflow==2.5 tensorflow-gpu==2.5
79 | pip install opencv-python-headless pyyaml pyrender tqdm mayavi
80 | pip install open3d==0.10.0 typing-extensions==3.7.4 trimesh==3.8.12 configobj==5.0.6 matplotlib==3.3.2 pyside2==5.11.0 scikit-image==0.19.0 numpy==1.19.2 scipy==1.9.1 vtk==9.3.1
81 | # If you have CUDA installed at /usr/local/cuda, run these lines
82 | cd ~/POGS/pogs/dependencies/contact_graspnet
83 | sh compile_pointnet_tfops.sh
84 | # If you have CUDA installed in your conda env, run these lines
85 | cd ~/POGS/pogs/configs
86 | cp conda_compile_pointnet_tfops.sh ~/POGS/pogs/dependencies/contact_graspnet/
87 | cd ~/POGS/pogs/dependencies/contact_graspnet
88 | sh conda_compile_pointnet_tfops.sh
89 | pip install autolab-core
90 | ```
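For reference, the hand-off between environments at runtime is just a subprocess call from `pogs_env`. Below is a minimal sketch of that call (the script path and flags are placeholders; see `pogs/contact_graspnet_wrapper/` for the actual entry point and arguments):
```
import subprocess

# Run Contact-GraspNet inside its own conda environment via `conda run`,
# exchanging inputs/outputs through files on disk (paths here are placeholders).
subprocess.run(
    [
        "conda", "run", "-n", "contact_graspnet_env",
        "python", "pogs/contact_graspnet_wrapper/prime_inference.py",
        "--np_path", "pogs/grasping/results/input_scene.npy",
    ],
    check=True,
)
```
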
91 |
92 | #### Download Models and Data
93 | ##### Model
94 | Download trained models from [here](https://drive.google.com/drive/folders/1tBHKf60K8DLM5arm-Chyf7jxkzOr5zGl?usp=sharing) and copy them into the `checkpoints/` folder.
95 | ##### Test data
96 | Download the test data from [here](https://drive.google.com/drive/folders/1TqpM2wHAAo0j3i1neu3Xeru3_WnsYQnx?usp=sharing) and copy it into the `test_data/` folder.
97 |
98 | ## Usage
99 | ### Calibrate wrist-mounted and third-person cameras
100 | Before training/tracking a POGS, make sure the wrist-mounted camera and the third-person view camera are calibrated. We use an ArUco marker for the calibration (a minimal detection sketch follows the command below).
101 | ```
102 | conda activate pogs_env
103 | cd ~/POGS/pogs/scripts
104 | python calibrate_cameras.py
105 | ```
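For context, the core of ArUco pose estimation with OpenCV looks roughly like the sketch below (assuming OpenCV >= 4.7; the dictionary, marker size, and intrinsics are placeholders, and `calibrate_cameras.py` may differ in its details):
```
import cv2
import numpy as np

def marker_pose(gray, K, dist, marker_length=0.05):
    """Return (rvec, tvec) of the first detected ArUco marker in the camera frame, or None."""
    detector = cv2.aruco.ArucoDetector(
        cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_50),
        cv2.aruco.DetectorParameters(),
    )
    corners, ids, _ = detector.detectMarkers(gray)
    if ids is None:
        return None
    # Marker corners in the marker frame (TL, TR, BR, BL), matching detectMarkers' corner order
    half = marker_length / 2.0
    obj_pts = np.array(
        [[-half, half, 0], [half, half, 0], [half, -half, 0], [-half, -half, 0]],
        dtype=np.float32,
    )
    _, rvec, tvec = cv2.solvePnP(obj_pts, corners[0].reshape(-1, 2), K, dist)
    return rvec, tvec
```
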
106 |
107 | ### Scene Capture
108 | Script used to perform a hemisphere capture of the tabletop scene with the robot. We used a manually defined trajectory, but you can also put the robot in "teach" mode to record one (an illustrative pose-generation sketch follows the command below).
109 | ```
110 | conda activate pogs_env
111 | cd ~/POGS/pogs/scripts
112 | python scene_capture.py --scene DATA_NAME
113 | ```
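For intuition, a hemisphere trajectory can be built by placing camera positions on a ring above the workspace and orienting each one toward a fixed target with `point_at` from `pogs/camera/capture_utils.py`. The values below are illustrative, not the exact trajectory we used:
```
import numpy as np
from pogs.camera.capture_utils import point_at  # import path assumed from this repo's layout

target = np.array([0.45, 0.0, 0.05])   # hypothetical look-at point on the table (robot base frame)
radius, height = 0.35, 0.40            # hypothetical ring radius and capture height in meters

poses = []
for az in np.linspace(0.0, 2.0 * np.pi, num=12, endpoint=False):
    cam_t = target + np.array([radius * np.cos(az), radius * np.sin(az), height])
    poses.append(point_at(cam_t, target))  # RigidTransform with the camera z-axis aimed at the target
```
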
114 |
115 | ### Train POGS
116 | Script used to train a POGS for 4000 steps.
117 | ```
118 | conda activate pogs_env
119 | ns-train pogs --data /path/to/data/folder
120 | ```
121 | Once the POGS has completed training, follow the steps below to define and save the object clusters.
122 | 1. Hit the Cluster Scene button.
123 | 2. Clustering takes 10-20 seconds; afterwards, you should see your objects as distinct clusters. If not, hit Toggle RGB/Cluster and try clustering the scene again with a different Cluster Eps (lower normally works better).
124 | 3. Once you have your scene clustered, hit Toggle RGB/Cluster.
125 | 4. Then, hit Click and click on your desired object (a green ball will appear on the object).
126 | 5. Hit Crop to Click, and it should isolate the object.
127 | 6. A draggable coordinate frame will pop up to indicate the object's origin; drag it to where you want it. (For our experiments, this frame is what we aligned against for object reset and tool servoing.)
128 | 7. Hit Add Crop to Group List.
129 | 8. Repeat steps 4-7 for all objects in the scene.
130 | 9. Hit View Crop Group List.
131 | Once training finishes, note the config file path and checkpoint directory printed in the terminal; you will need them below.
132 |
133 | ### Run POGS for grasping
134 | Script that lets you use a trained POGS to track an object online and grasp it.
135 | ```
136 | conda activate pogs_env
137 | python ~/POGS/pogs/scripts/track_main_online_demo.py --config_path /path/to/config/yml
138 | ```
139 |
140 | ## Bibtex
141 | If you find POGS useful for your work, please consider citing:
142 | ```
143 | @article{yu2025pogs,
144 | author = {Yu, Justin and Hari, Kush and El-Refai, Karim and Dalil, Arnav and Kerr, Justin and Kim, Chung-Min and Cheng, Richard and Irshad, Muhammad Z. and Goldberg, Ken},
145 | title = {Persistent Object Gaussian Splat (POGS) for Tracking Human and Robot Manipulation of Irregularly Shaped Objects},
146 | journal = {ICRA},
147 | year = {2025},
148 | }
149 | ```
150 |
--------------------------------------------------------------------------------
/media/POGS_servoing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/media/POGS_servoing.gif
--------------------------------------------------------------------------------
/media/POGS_teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/media/POGS_teaser.gif
--------------------------------------------------------------------------------
/pogs/calibration_outputs/world_to_extrinsic_zed_for_grasping_down.tf:
--------------------------------------------------------------------------------
1 | zed_extrinsic_for_grasping
2 | world
3 | 0.0 0.0 0.0
4 | 0.0 -1.0 0.0
5 | -1.0 0.0 0.0
6 | 0.0 0.0 -1.0
--------------------------------------------------------------------------------
/pogs/camera/capture_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from autolab_core import RigidTransform, CameraIntrinsics
3 | from autolab_core.transformations import euler_matrix, euler_from_matrix
4 | from PIL import Image
5 | import json
6 | from scipy.optimize import minimize
7 | from typing import List
8 |
9 | def estimate_cam2rob(H_chess_cams: List[RigidTransform], H_rob_worlds: List[RigidTransform]):
10 | '''
11 | Estimates transform between camera and robot end-effector frame using least-squares optimization.
12 | This implements the hand-eye calibration procedure using multiple observations.
13 |
14 | Parameters:
15 | -----------
16 | H_chess_cams : List[RigidTransform]
17 | List of transformations from chessboard frame to camera frame for each observation
18 | H_rob_worlds : List[RigidTransform]
19 | List of transformations from robot end-effector frame to world frame for each observation
20 |
21 | Returns:
22 | --------
23 | H_cam_rob : RigidTransform
24 | Estimated transformation from camera frame to robot end-effector frame
25 | H_chess_world : RigidTransform
26 | Estimated transformation from chessboard frame to world frame
27 | '''
28 | def residual(x):
29 | '''
30 | Computes the residual error for the current estimate of transformations.
31 |
32 | Parameters:
33 | -----------
34 | x : numpy.ndarray
35 | Parameter vector formatted as [x, y, z, rx, ry, rz] for chessboard-to-world
36 | followed by [x, y, z, rx, ry, rz] for camera-to-robot
37 |
38 | Returns:
39 | --------
40 | float
41 | Sum of position and orientation errors across all observations
42 | '''
43 | err = 0
44 | # Extract transformations from parameter vector
45 | H_chess_world = RigidTransform(translation=x[:3],
46 | rotation=euler_matrix(x[3], x[4], x[5])[:3, :3],
47 | from_frame='chess', to_frame='world')
48 | H_cam_rob = RigidTransform(translation=x[6:9],
49 | rotation=euler_matrix(x[9], x[10], x[11])[:3, :3],
50 | from_frame='cam', to_frame='rob')
51 |
52 | # Compute error across all observations
53 | for H_chess_cam, H_rob_world in zip(H_chess_cams, H_rob_worlds):
54 | # Estimate chessboard-to-world transform using current parameters
55 | H_chess_world_est = H_rob_world * H_cam_rob * H_chess_cam
56 |
57 | # Compute translation error
58 | err += np.linalg.norm(H_chess_world.translation - H_chess_world_est.translation)
59 |
60 | # Compute rotation error using Euler angles
61 | rot_diff = H_chess_world.rotation @ np.linalg.inv(H_chess_world_est.rotation)
62 | eul_diff = euler_from_matrix(rot_diff)
63 | err += np.linalg.norm(eul_diff)
64 |
65 | print(err)
66 | return err
67 |
68 | # Initialize parameters with zeros
69 | x0 = np.zeros(12)
70 |
71 | # Perform optimization using Sequential Least Squares Programming
72 | res = minimize(residual, x0, method='SLSQP')
73 | print(res)
74 |
75 | # Alert if optimization did not converge
76 | if not res.success:
77 | input("Optimization was not successful, press enter to acknowledge")
78 |
79 | # Extract optimized parameters
80 | x = res.x
81 | H_chess_world = RigidTransform(translation=x[:3],
82 | rotation=euler_matrix(x[3], x[4], x[5])[:3, :3],
83 | from_frame='chess', to_frame='world')
84 | H_cam_rob = RigidTransform(translation=x[6:9],
85 | rotation=euler_matrix(x[9], x[10], x[11])[:3, :3],
86 | from_frame='cam', to_frame='rob')
87 |
88 | return H_cam_rob, H_chess_world
89 |
90 |
91 | def point_at(cam_t, obstacle_t, extra_R=np.eye(3)):
92 | '''
93 | Computes a transformation that orients a camera to point at a specific 3D point.
94 |
95 | Parameters:
96 | -----------
97 | cam_t : numpy.ndarray
98 | 3D position of the camera/gripper in world coordinates
99 | obstacle_t : numpy.ndarray
100 | 3D position of the target location to point the camera at
101 | extra_R : numpy.ndarray, optional
102 | Additional rotation matrix to apply to the camera frame, useful for fine-tuning orientation
103 |
104 | Returns:
105 | --------
106 | RigidTransform
107 | Transformation representing camera pose pointing at the target
108 | '''
109 | # Compute the direction vector from camera to target
110 | dir = obstacle_t - cam_t
111 | z_axis = dir / np.linalg.norm(dir)
112 |
113 | # Compute the x-axis perpendicular to z-axis and world z-axis
114 | # Note: change the sign if camera positioning is difficult
115 | x_axis_dir = -np.cross(np.array((0, 0, 1)), z_axis)
116 |
117 | # Handle special case when camera direction is aligned with world z-axis
118 | if np.linalg.norm(x_axis_dir) < 1e-10:
119 | x_axis_dir = np.array((0, 1, 0))
120 | x_axis = x_axis_dir / np.linalg.norm(x_axis_dir)
121 |
122 | # Complete the right-handed coordinate system with y-axis
123 | y_axis_dir = np.cross(z_axis, x_axis)
124 | y_axis = y_axis_dir / np.linalg.norm(y_axis_dir)
125 |
126 | # Create rotation matrix from axes and apply extra rotation
127 | # Post-multiply to rotate the camera with respect to itself
128 | R = RigidTransform.rotation_from_axes(x_axis, y_axis, z_axis) @ extra_R
129 |
130 | # Create and return the complete transformation
131 | H = RigidTransform(translation=cam_t, rotation=R, from_frame='camera', to_frame='base_link')
132 | return H
133 |
134 | def save_data(imgs, poses, savedir, intr: CameraIntrinsics):
135 | '''
136 | Saves a collection of images and camera poses in the NeRF dataset format.
137 |
138 | Parameters:
139 | -----------
140 | imgs : List[numpy.ndarray]
141 | List of images captured from each viewpoint
142 | poses : List[RigidTransform]
143 | List of camera poses corresponding to each image
144 | savedir : str
145 | Directory path to save the dataset
146 | intr : CameraIntrinsics
147 | Camera intrinsic parameters
148 |
149 | Notes:
150 | ------
151 | The NeRF format includes:
152 | - Individual image files
153 | - A transforms.json file containing camera parameters and poses
154 | - Special conventions for coordinate systems (flipped y and z axes)
155 | '''
156 | import os
157 | os.makedirs(savedir, exist_ok=True)
158 |
159 | # Initialize data dictionary for transforms.json
160 | data_dict = dict()
161 | data_dict['frames'] = []
162 |
163 | # Add camera intrinsic parameters
164 | data_dict['fl_x'] = intr.fx
165 | data_dict['fl_y'] = intr.fy
166 | data_dict['cx'] = intr.cx
167 | data_dict['cy'] = intr.cy
168 | data_dict['h'] = imgs[0].shape[0]
169 | data_dict['w'] = imgs[0].shape[1]
170 |
171 | # NeRF-specific parameters
172 | data_dict['aabb_scale'] = 2
173 | data_dict['scale'] = 1.2
174 |
175 | pil_images = []
176 | for i, (im, p) in enumerate(zip(imgs, poses)):
177 | # Remove alpha channel if present
178 | if im.shape[2] == 4:
179 | im = im[..., :3]
180 |
181 | # Save image file
182 | img = Image.fromarray(im)
183 | pil_images.append(img)
184 | img.save(f'{savedir}/img{i}.jpg')
185 |
186 | # Convert pose to NeRF format (flip y and z axes)
187 | mat = p.matrix
188 | mat[:3, 1] *= -1
189 | mat[:3, 2] *= -1
190 |
191 | # Add frame information to data dictionary
192 | frame = {'file_path': f'img{i}.jpg', 'transform_matrix': mat.tolist()}
193 | data_dict['frames'].append(frame)
194 |
195 | # Save transforms.json file
196 | with open(f"{savedir}/transforms.json", 'w') as fp:
197 | json.dump(data_dict, fp)
198 |
199 | def load_data(savedir):
200 | '''
201 | Loads a NeRF dataset from a directory.
202 |
203 | Parameters:
204 | -----------
205 | savedir : str
206 | Directory path containing the NeRF dataset
207 |
208 | Returns:
209 | --------
210 | Tuple[List[numpy.ndarray], List[RigidTransform]]
211 | Tuple containing (images, camera_poses)
212 |
213 | Notes:
214 | ------
215 | This function reverses the coordinate system conversions applied in save_data
216 | to restore the original camera poses.
217 | '''
218 | import os
219 | if not os.path.exists(savedir):
220 | raise FileNotFoundError(f'{savedir} does not exist')
221 |
222 | # Load transforms.json file
223 | with open(f"{savedir}/transforms.json", 'r') as fp:
224 | data_dict = json.load(fp)
225 |
226 | poses = []
227 | imgs = []
228 |
229 | # Process each frame in the dataset
230 | for frame in data_dict['frames']:
231 | # Load image from file
232 | img = Image.open(f'{savedir}/{frame["file_path"]}')
233 | imgs.append(np.array(img))
234 |
235 | # Convert pose matrix from NeRF format back to original format
236 | mat = np.array(frame['transform_matrix'])
237 | mat[:3, 1] *= -1
238 | mat[:3, 2] *= -1
239 |
240 | # Create RigidTransform from matrix
241 | poses.append(RigidTransform(*RigidTransform.rotation_and_translation_from_matrix(mat)))
242 |
243 | return imgs, poses
--------------------------------------------------------------------------------
/pogs/camera/zed_stereo.py:
--------------------------------------------------------------------------------
1 | import pyzed.sl as sl
2 | import numpy as np
3 | from autolab_core import CameraIntrinsics, PointCloud, RgbCloud
4 | from raftstereo.raft_stereo import *
5 | from raftstereo.utils.utils import InputPadder
6 | import argparse
7 | from dataclasses import dataclass, field
8 | from typing import List
9 | import os
10 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
11 |
12 | # Configuration class for the RAFT stereo model parameters
13 | class RAFTConfig:
14 | # Path to the pre-trained RAFT stereo model checkpoint
15 | restore_ckpt: str = str(os.path.join(DIR_PATH,'../dependencies/raftstereo/models/raftstereo-middlebury.pth'))
16 | # Hidden dimensions for network layers
17 | hidden_dims: List[int] = [128]*3
18 | # Correlation implementation type
19 | corr_implementation: str = "reg"
20 | # Whether to use a shared backbone for both images
21 | shared_backbone: bool = False
22 | # Correlation pyramid levels for feature matching
23 | corr_levels: int = 4
24 | # Radius of correlation lookups
25 | corr_radius: int = 4
26 | # Number of downsampling layers
27 | n_downsample: int = 2
28 | # Normalization type for context network
29 | context_norm: str = "batch"
30 | # Whether to use slow-fast GRU for iterative updates
31 | slow_fast_gru: bool = False
32 | # Enable mixed precision for faster computation and memory efficiency
33 | mixed_precision: bool = False
34 | # Number of GRU layers in the update block
35 | n_gru_layers: int = 3
36 |
37 | # ZED stereo camera interface with RAFT stereo depth estimation
38 | class Zed:
39 | """
40 | ZED camera wrapper for stereo image capture and depth estimation using RAFT-Stereo algorithm.
41 | Provides functions for camera initialization, image capture, depth estimation, and point cloud generation.
42 | """
43 | def __init__(self, flip_mode, resolution, fps, cam_id=None, recording_file=None, start_time=0.0):
44 | """
45 | Initialize the ZED camera with specified parameters.
46 |
47 | Args:
48 | flip_mode (bool): Whether to flip the camera image
49 | resolution (str): Camera resolution ('720p', '1080p', or '2k')
50 | fps (int): Frames per second
51 | cam_id (int, optional): Camera serial number for multi-camera setups
52 | recording_file (str, optional): Path to SVO recording file for playback
53 | start_time (float, optional): Start time in seconds for SVO playback
54 | """
55 | init = sl.InitParameters()
56 | if cam_id is not None:
57 | init.set_from_serial_number(cam_id)
58 | self.cam_id = cam_id
59 | self.height_ = None
60 | self.width_ = None
61 |
62 | # Set camera flip mode
63 | if flip_mode:
64 | init.camera_image_flip = sl.FLIP_MODE.ON
65 | else:
66 | init.camera_image_flip = sl.FLIP_MODE.OFF
67 |
68 | # Configure for SVO file playback if provided
69 | if recording_file is not None:
70 | init.set_from_svo_file(recording_file)
71 |
72 | # Configure camera resolution
73 | if resolution == '720p':
74 | init.camera_resolution = sl.RESOLUTION.HD720
75 | self.height_ = 720
76 | self.width_ = 1280
77 | elif resolution == '1080p':
78 | init.camera_resolution = sl.RESOLUTION.HD1080
79 | self.height_ = 1080
80 | self.width_ = 1920
81 | elif resolution == '2k':
82 |             init.camera_resolution = sl.RESOLUTION.HD2K
83 | self.height_ = 1242
84 | self.width_ = 2208
85 | else:
86 | print("Only 720p, 1080p, and 2k supported by Zed")
87 | exit()
88 |
89 | # Disable native ZED depth computation (we'll use RAFT-Stereo instead)
90 | init.depth_mode = sl.DEPTH_MODE.NONE
91 | init.sdk_verbose = 1
92 | init.camera_fps = fps
93 | self.cam = sl.Camera()
94 | init.camera_disable_self_calib = True
95 |
96 | # Open the camera with the specified parameters
97 | status = self.cam.open(init)
98 | self.recording_file = recording_file
99 | self.start_time = start_time
100 |
101 | # Set SVO playback position if applicable
102 | if recording_file is not None:
103 | fps = self.cam.get_camera_information().camera_configuration.fps
104 | self.cam.set_svo_position(int(start_time * fps))
105 |
106 | # Check if camera opened successfully
107 | if status != sl.ERROR_CODE.SUCCESS:
108 | print("Camera Open : " + repr(status) + ". Exit program.")
109 | exit()
110 | else:
111 | print("Opened camera")
112 |
113 | # Calculate stereo parameters for depth estimation
114 | left_cx = self.get_K(cam="left")[0, 2]
115 | right_cx = self.get_K(cam="right")[0, 2]
116 | self.cx_diff = right_cx - left_cx # Horizontal principal point difference for disparity calculation
117 | self.f_ = self.get_K(cam="left")[0,0] # Focal length
118 | self.cx_ = left_cx # Principal point x-coordinate
119 | self.cy_ = self.get_K(cam="left")[1,2] # Principal point y-coordinate
120 | self.Tx_ = self.get_stereo_transform()[0,3] # Baseline (translation between cameras)
121 |
122 | # RAFT-Stereo parameters
123 | self.valid_iters_ = 32 # Number of iterations for RAFT-Stereo flow estimation
124 | self.padder_ = InputPadder(torch.empty(1,3,self.height_,self.width_).shape, divis_by=32)
125 | self.model = self.create_raft() # Initialize RAFT-Stereo model
126 |
127 | def create_raft(self):
128 | """
129 | Initialize and load the RAFT-Stereo model for depth estimation.
130 |
131 | Returns:
132 | torch.nn.Module: Loaded RAFT-Stereo model in evaluation mode
133 | """
134 | raft_args = RAFTConfig()
135 | model = torch.nn.DataParallel(RAFTStereo(raft_args), device_ids=[0])
136 | model.load_state_dict(torch.load(raft_args.restore_ckpt))
137 |
138 | model = model.module
139 | model = model.to('cuda')
140 | model = model.eval()
141 | return model
142 |
143 | def load_image_raft(self, im):
144 | """
145 | Prepare image for RAFT-Stereo model inference.
146 |
147 | Args:
148 | im (numpy.ndarray): RGB image array
149 |
150 | Returns:
151 | torch.Tensor: Processed image tensor on GPU
152 | """
153 | img = torch.from_numpy(im).permute(2,0,1).float()
154 | return img[None].to('cuda')
155 |
156 | def get_depth_image_and_pointcloud(self, left_img, right_img, from_frame):
157 | """
158 | Compute depth image and point cloud from stereo images using RAFT-Stereo.
159 |
160 | Args:
161 | left_img (numpy.ndarray): Left RGB image
162 | right_img (numpy.ndarray): Right RGB image
163 | from_frame (str): Coordinate frame name for point cloud
164 |
165 | Returns:
166 | tuple: (depth_image, points, rgbs) containing depth map and colored point cloud
167 | """
168 | with torch.no_grad():
169 | # Prepare images for RAFT model
170 | image1 = self.load_image_raft(left_img)
171 | image2 = self.load_image_raft(right_img)
172 | image1, image2 = self.padder_.pad(image1, image2)
173 |
174 | # Run RAFT-Stereo to compute disparity (negative of optical flow)
175 | _, flow_up = self.model(image1, image2, iters=self.valid_iters_, test_mode=True)
176 | flow_up = self.padder_.unpad(flow_up).squeeze()
177 |
178 | flow_up_np = -flow_up.detach().cpu().numpy().squeeze()
179 |
180 | # Convert disparity to depth using the stereo camera equation:
181 | # depth = (focal_length * baseline) / disparity
182 | depth_image = (self.f_ * self.Tx_) / abs(flow_up_np + self.cx_diff)
183 | rows, cols = depth_image.shape
184 | y, x = np.meshgrid(range(rows), range(cols), indexing="ij")
185 |
186 | # Convert depth image to x,y,z point cloud using the pinhole camera model
187 | Z = depth_image # Depth values
188 | # X = (x - cx) * Z / fx
189 | X = (x - self.cx_) * Z / self.f_
190 | # Y = (y - cy) * Z / fy
191 | Y = (y - self.cy_) * Z / self.f_
192 | points = np.stack((X,Y,Z), axis=-1)
193 | rgbs = left_img
194 |
195 | # Remove points with zero depth (invalid or background points)
196 | non_zero_indices = np.all(points != [0, 0, 0], axis=-1)
197 | points = points[non_zero_indices]
198 | rgbs = rgbs[non_zero_indices]
199 |
200 | # Format as PointCloud and RgbCloud objects for compatibility with autolab_core
201 | points = points.reshape(-1,3)
202 | rgbs = rgbs.reshape(-1,3)
203 | points = PointCloud(points.T, from_frame)
204 | rgbs = RgbCloud(rgbs.T, from_frame)
205 | return depth_image, points, rgbs
206 |
207 | def get_frame(self, depth=True, cam="left"):
208 | """
209 | Capture a frame from the ZED camera and optionally compute depth.
210 |
211 | Args:
212 | depth (bool): Whether to compute depth
213 | cam (str): Which camera to use as reference ("left" or "right")
214 |
215 | Returns:
216 | tuple: (left_image, right_image, depth_map) or None if frame capture failed
217 | """
218 | res = sl.Resolution()
219 | res.width = self.width_
220 | res.height = self.height_
221 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS:
222 | # Retrieve stereo images
223 | left_rgb = sl.Mat()
224 | right_rgb = sl.Mat()
225 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res)
226 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res)
227 | left, right = (
228 | torch.from_numpy(
229 | np.flip(left_rgb.get_data()[..., :3], axis=2).copy()
230 | ).cuda(),
231 | torch.from_numpy(
232 | np.flip(right_rgb.get_data()[..., :3], axis=2).copy()
233 | ).cuda(),
234 | )
235 |
236 | # Compute depth if requested
237 | if depth:
238 | left_torch, right_torch = left.permute(2, 0, 1), right.permute(2, 0, 1)
239 |
240 | # Handle different reference views (left or right camera)
241 | if cam == "left":
242 | flow = raft_inference(left_torch, right_torch, self.model)
243 | else:
244 | right_torch = torch.flip(right_torch, dims=[2])
245 | left_torch = torch.flip(left_torch, dims=[2])
246 | flow = raft_inference(right_torch, left_torch, self.model)
247 |
248 | # Compute depth from disparity using stereo camera equation
249 | fx = self.get_K()[0, 0] # Focal length
250 | depth = (
251 | fx * self.get_stereo_transform()[0, 3] / (flow.abs() + self.cx_diff)
252 | )
253 |
254 | if cam != "left":
255 | depth = torch.flip(depth, dims=[1])
256 | else:
257 | depth = None
258 | return left, right, depth
259 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
260 | print("End of recording file")
261 | return None, None, None
262 | else:
263 | raise RuntimeError("Could not grab frame")
264 |
265 | def get_K(self, cam="left"):
266 | """
267 | Get camera intrinsic matrix (calibration matrix K).
268 |
269 | Args:
270 | cam (str): Which camera to use ("left" or "right")
271 |
272 | Returns:
273 | numpy.ndarray: 3x3 intrinsic matrix K
274 | """
275 | calib = (
276 | self.cam.get_camera_information().camera_configuration.calibration_parameters
277 | )
278 | if cam == "left":
279 | intrinsics = calib.left_cam
280 | else:
281 | intrinsics = calib.right_cam
282 | K = np.array(
283 | [
284 | [intrinsics.fx, 0, intrinsics.cx],
285 | [0, intrinsics.fy, intrinsics.cy],
286 | [0, 0, 1],
287 | ]
288 | )
289 | return K
290 |
291 | def get_intr(self, cam="left"):
292 | """
293 | Get camera intrinsics in autolab_core format.
294 |
295 | Args:
296 | cam (str): Which camera to use ("left" or "right")
297 |
298 | Returns:
299 | CameraIntrinsics: Camera intrinsics object
300 | """
301 | calib = (
302 | self.cam.get_camera_information().camera_configuration.calibration_parameters
303 | )
304 | if cam == "left":
305 | intrinsics = calib.left_cam
306 | else:
307 | intrinsics = calib.right_cam
308 | return CameraIntrinsics(
309 | frame="zed",
310 | fx=intrinsics.fx,
311 | fy=intrinsics.fy,
312 | cx=intrinsics.cx,
313 | cy=intrinsics.cy,
314 | width=1280,
315 | height=720,
316 | )
317 |
318 | def get_stereo_transform(self):
319 | """
320 | Get transformation matrix from left to right camera.
321 |
322 | Returns:
323 | numpy.ndarray: 4x4 transformation matrix (in meters)
324 | """
325 | transform = (
326 | self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m
327 | )
328 | transform[:3, 3] /= 1000 # Convert from millimeters to meters
329 | return transform
330 |
331 | def start_record(self, out_path):
332 | """
333 | Start recording to an SVO file.
334 |
335 | Args:
336 | out_path (str): Output file path
337 | """
338 | recordingParameters = sl.RecordingParameters()
339 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264
340 | recordingParameters.video_filename = out_path
341 | err = self.cam.enable_recording(recordingParameters)
342 |
343 | def stop_record(self):
344 | """Stop recording to SVO file."""
345 | self.cam.disable_recording()
346 |
347 | def get_rgb_depth(self, cam="left"):
348 | """
349 | Get RGB images and depth map.
350 |
351 | Args:
352 | cam (str): Which camera to use as reference ("left" or "right")
353 |
354 | Returns:
355 | tuple: (left_image, right_image, depth_map) as numpy arrays
356 | """
357 | left, right, depth = self.get_frame(cam=cam)
358 | return left.cpu().numpy(), right.cpu().numpy(), depth.cpu().numpy()
359 |
360 | def get_rgb(self, cam="left"):
361 | """
362 | Get only RGB images without computing depth.
363 |
364 | Args:
365 | cam (str): Which camera to use as reference ("left" or "right")
366 |
367 | Returns:
368 | tuple: (left_image, right_image) as numpy arrays
369 | """
370 | left, right, _ = self.get_frame(depth=False, cam=cam)
371 | return left.cpu().numpy(), right.cpu().numpy()
372 |
373 | def get_ns_intrinsics(self):
374 | """
375 | Get camera intrinsics in NeRF Studio format.
376 |
377 | Returns:
378 | dict: Camera intrinsics dictionary compatible with NeRF Studio
379 | """
380 | calib = (
381 | self.cam.get_camera_information().camera_configuration.calibration_parameters
382 | )
383 | calibration_parameters_l = calib.left_cam
384 | return {
385 | "w": self.width_,
386 | "h": self.height_,
387 | "fl_x": calibration_parameters_l.fx,
388 | "fl_y": calibration_parameters_l.fy,
389 | "cx": calibration_parameters_l.cx,
390 | "cy": calibration_parameters_l.cy,
391 | "k1": calibration_parameters_l.disto[0],
392 | "k2": calibration_parameters_l.disto[1],
393 | "p1": calibration_parameters_l.disto[3],
394 | "p2": calibration_parameters_l.disto[4],
395 | "camera_model": "OPENCV",
396 | }
397 |
398 | def get_zed_depth(self):
399 | """
400 | Get native ZED depth map (not using RAFT-Stereo).
401 | Note: This requires depth_mode to be enabled in InitParameters.
402 |
403 | Returns:
404 | numpy.ndarray: Depth map
405 | """
406 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS:
407 | depth = sl.Mat()
408 | self.cam.retrieve_measure(depth, sl.MEASURE.DEPTH)
409 | return depth.get_data()
410 | else:
411 | raise RuntimeError("Could not grab frame")
412 |
413 | def close(self):
414 | """Close the camera and release resources."""
415 | self.cam.close()
416 | self.cam = None
417 | print("Closed camera")
418 |
419 | def reopen(self):
420 | """Reopen the camera with previously specified parameters."""
421 | if self.cam is None:
422 | init = sl.InitParameters()
423 | # Configuration for SVO file playback
424 | if self.recording_file is not None:
425 | init.set_from_svo_file(self.recording_file)
426 | init.camera_image_flip = sl.FLIP_MODE.OFF
427 | init.depth_mode = sl.DEPTH_MODE.NONE
428 | init.camera_resolution = sl.RESOLUTION.HD1080
429 | init.sdk_verbose = 1
430 | init.camera_fps = 15
431 | # Configuration for live camera
432 | else:
433 | init.camera_resolution = sl.RESOLUTION.HD720
434 | init.sdk_verbose = 1
435 | init.camera_fps = 15
436 | init.camera_image_flip = sl.FLIP_MODE.OFF
437 | init.depth_mode = sl.DEPTH_MODE.NONE
438 | init.depth_minimum_distance = 100 # millimeters
439 |
440 | self.cam = sl.Camera()
441 | init.camera_disable_self_calib = True
442 | status = self.cam.open(init)
443 |
444 | # Set SVO playback position if applicable
445 | if self.recording_file is not None:
446 | fps = self.cam.get_camera_information().camera_configuration.fps
447 | self.cam.set_svo_position(int(self.start_time * fps))
448 |
449 | # Check if camera opened successfully
450 | if status != sl.ERROR_CODE.SUCCESS:
451 | print("Camera Open : " + repr(status) + ". Exit program.")
452 | exit()
453 | else:
454 | print("Opened camera")
455 | print(
456 | "Current Exposure is set to: ",
457 | self.cam.get_camera_settings(sl.VIDEO_SETTINGS.EXPOSURE),
458 | )
459 |
460 | # Reinitialize RAFT model and stereo parameters
461 | self.model = self.create_raft()
462 | left_cx = self.get_K(cam="left")[0, 2]
463 | right_cx = self.get_K(cam="right")[0, 2]
464 | self.cx_diff = right_cx - left_cx
--------------------------------------------------------------------------------
/pogs/configs/base_config.py:
--------------------------------------------------------------------------------
1 |
2 | """Base Configs"""
3 |
4 |
5 | from __future__ import annotations
6 |
7 | from dataclasses import dataclass, field
8 | from pathlib import Path
9 | from typing import Any, List, Literal, Optional, Tuple, Type
10 |
11 | # Pretty printing class
12 | class PrintableConfig:
13 | """Printable Config defining str function"""
14 |
15 | def __str__(self):
16 | lines = [self.__class__.__name__ + ":"]
17 | for key, val in vars(self).items():
18 | if isinstance(val, Tuple):
19 | flattened_val = "["
20 | for item in val:
21 | flattened_val += str(item) + "\n"
22 | flattened_val = flattened_val.rstrip("\n")
23 | val = flattened_val + "]"
24 | lines += f"{key}: {str(val)}".split("\n")
25 | return "\n ".join(lines)
26 |
27 |
28 | # Base instantiate configs
29 | @dataclass
30 | class InstantiateConfig(PrintableConfig):
31 |     """Config class for instantiating the class specified in the _target attribute."""
32 |
33 | _target: Type
34 |
35 | def setup(self, **kwargs) -> Any:
36 | """Returns the instantiated object using the config."""
37 | return self._target(self, **kwargs)
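# Example (hypothetical names, not part of the original file): a subclass points `_target` at the
# class it instantiates, and `setup()` passes the config itself as the first constructor argument:
#
#   @dataclass
#   class MyModelConfig(InstantiateConfig):
#       _target: Type = field(default_factory=lambda: MyModel)
#       num_layers: int = 4
#
#   model = MyModelConfig().setup()   # equivalent to MyModel(MyModelConfig())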
--------------------------------------------------------------------------------
/pogs/configs/camera_config.yaml:
--------------------------------------------------------------------------------
1 | third_view_zed:
2 | exposure: 100
3 | flip_mode: false
4 | fps: 30
5 | gain: 31
6 | id: 22008760
7 | resolution: 1080p
8 | wrist_mounted_zed:
9 | exposure: 67
10 | flip_mode: true
11 | fps: 30
12 | gain: 28
13 | id: 16347230
14 | resolution: 720p
15 |
--------------------------------------------------------------------------------
/pogs/configs/conda_compile_pointnet_tfops.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Make sure the conda environment is activated so that CONDA_PREFIX is defined.
3 |
4 | # Replace the system paths with conda environment paths:
5 | CUDA_INCLUDE=" -I${CONDA_PREFIX}/include/"
6 | CUDA_LIB=" -L${CONDA_PREFIX}/lib/"
7 |
8 | # Get TensorFlow compile and link flags from the active Python:
9 | TF_CFLAGS=$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
10 | TF_LFLAGS=$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
11 |
12 | cd pointnet2/tf_ops/sampling
13 |
14 | nvcc -std=c++11 -c -o tf_sampling_g.cu.o tf_sampling_g.cu \
15 | ${CUDA_INCLUDE} ${TF_CFLAGS} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
16 |
17 | g++ -std=c++11 -shared -o tf_sampling_so.so tf_sampling.cpp \
18 | tf_sampling_g.cu.o ${CUDA_INCLUDE} ${TF_CFLAGS} -fPIC -lcudart ${TF_LFLAGS} ${CUDA_LIB}
19 |
20 | echo 'testing sampling'
21 | python3 tf_sampling.py
22 |
23 | cd ../grouping
24 |
25 | nvcc -std=c++11 -c -o tf_grouping_g.cu.o tf_grouping_g.cu \
26 | ${CUDA_INCLUDE} ${TF_CFLAGS} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
27 |
28 | g++ -std=c++11 -shared -o tf_grouping_so.so tf_grouping.cpp \
29 | tf_grouping_g.cu.o ${CUDA_INCLUDE} ${TF_CFLAGS} -fPIC -lcudart ${TF_LFLAGS} ${CUDA_LIB}
30 |
31 | echo 'testing grouping'
32 | python3 tf_grouping_op_test.py
33 |
34 | cd ../3d_interpolation
35 |
36 | g++ -std=c++11 tf_interpolate.cpp -o tf_interpolate_so.so -shared -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2
37 |
38 | echo 'testing interpolate'
39 | python3 tf_interpolate_op_test.py
40 |
--------------------------------------------------------------------------------
/pogs/contact_graspnet_wrapper/prime_config_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | def recursive_key_value_assign(d,ks,v):
5 | """
6 | Recursive value assignment to a nested dict
7 |
8 | Arguments:
9 | d {dict} -- dict
10 | ks {list} -- list of hierarchical keys
11 | v {value} -- value to assign
12 | """
13 |
14 | if len(ks) > 1:
15 | recursive_key_value_assign(d[ks[0]],ks[1:],v)
16 | elif len(ks) == 1:
17 | d[ks[0]] = v
18 |
19 | def load_config(checkpoint_dir, batch_size=None, max_epoch=None, data_path=None, arg_configs=[], save=False):
20 | """
21 | Loads yaml config file and overwrites parameters with function arguments and --arg_config parameters
22 |
23 | Arguments:
24 | checkpoint_dir {str} -- Checkpoint directory where config file was copied to
25 |
26 | Keyword Arguments:
27 | batch_size {int} -- [description] (default: {None})
28 | max_epoch {int} -- "epochs" (number of scenes) to train (default: {None})
29 | data_path {str} -- path to scenes with contact grasp data (default: {None})
30 | arg_configs {list} -- Overwrite config parameters by hierarchical command line arguments (default: {[]})
31 | save {bool} -- Save overwritten config file (default: {False})
32 |
33 | Returns:
34 | [dict] -- Config
35 | """
36 |
37 | config_path = os.path.join(checkpoint_dir, 'config.yaml')
38 | config_path = config_path if os.path.exists(config_path) else os.path.join(os.path.dirname(__file__),'config.yaml')
39 | with open(config_path,'r') as f:
40 | global_config = yaml.safe_load(f)
41 |
42 | for conf in arg_configs:
43 | k_str, v = conf.split(':')
44 | try:
45 | v = eval(v)
46 | except:
47 | pass
48 | ks = [int(k) if k.isdigit() else k for k in k_str.split('.')]
49 |
50 | recursive_key_value_assign(global_config, ks, v)
51 |
52 | if batch_size is not None:
53 | global_config['OPTIMIZER']['batch_size'] = int(batch_size)
54 | if max_epoch is not None:
55 | global_config['OPTIMIZER']['max_epoch'] = int(max_epoch)
56 | if data_path is not None:
57 | global_config['DATA']['data_path'] = data_path
58 |
59 | global_config['DATA']['classes'] = None
60 |
61 | if save:
62 | with open(os.path.join(checkpoint_dir, 'config.yaml'),'w') as f:
63 | yaml.dump(global_config, f)
64 |
65 | return global_config
66 |
67 |
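# Example (hypothetical values, not part of the upstream file): each `arg_configs` entry is a
# dot-separated key path and a `:`-separated value, parsed by recursive_key_value_assign above:
#
#   cfg = load_config(checkpoint_dir,
#                     arg_configs=['OPTIMIZER.batch_size:4', 'DATA.data_path:/tmp/scenes'])
#
# overwrites cfg['OPTIMIZER']['batch_size'] and cfg['DATA']['data_path'] before returning cfg.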
--------------------------------------------------------------------------------
/pogs/contact_graspnet_wrapper/prime_visualization_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mayavi.mlab as mlab
3 | import matplotlib.pyplot as plt
4 | from matplotlib import cm
5 |
6 | import mesh_utils
7 |
8 | def plot_mesh(mesh, cam_trafo=np.eye(4), mesh_pose=np.eye(4)):
9 | """
10 |     Plots the mesh at mesh_pose, transformed into camera coordinates via cam_trafo
11 |
12 | Arguments:
13 | mesh {trimesh.base.Trimesh} -- input mesh, e.g. gripper
14 |
15 | Keyword Arguments:
16 | cam_trafo {np.ndarray} -- 4x4 transformation from world to camera coords (default: {np.eye(4)})
17 | mesh_pose {np.ndarray} -- 4x4 transformation from mesh to world coords (default: {np.eye(4)})
18 | """
19 |
20 | homog_mesh_vert = np.pad(mesh.vertices, (0, 1), 'constant', constant_values=(0, 1))
21 | mesh_cam = homog_mesh_vert.dot(mesh_pose.T).dot(cam_trafo.T)[:,:3]
22 | mlab.triangular_mesh(mesh_cam[:, 0],
23 | mesh_cam[:, 1],
24 | mesh_cam[:, 2],
25 | mesh.faces,
26 | colormap='Blues',
27 | opacity=0.5)
28 |
29 | def plot_coordinates(t,r, tube_radius=0.005):
30 | """
31 | plots coordinate frame
32 |
33 | Arguments:
34 | t {np.ndarray} -- translation vector
35 | r {np.ndarray} -- rotation matrix
36 |
37 | Keyword Arguments:
38 | tube_radius {float} -- radius of the plotted tubes (default: {0.005})
39 | """
40 | mlab.plot3d([t[0],t[0]+0.2*r[0,0]], [t[1],t[1]+0.2*r[1,0]], [t[2],t[2]+0.2*r[2,0]], color=(1,0,0), tube_radius=tube_radius, opacity=1)
41 | mlab.plot3d([t[0],t[0]+0.2*r[0,1]], [t[1],t[1]+0.2*r[1,1]], [t[2],t[2]+0.2*r[2,1]], color=(0,1,0), tube_radius=tube_radius, opacity=1)
42 | mlab.plot3d([t[0],t[0]+0.2*r[0,2]], [t[1],t[1]+0.2*r[1,2]], [t[2],t[2]+0.2*r[2,2]], color=(0,0,1), tube_radius=tube_radius, opacity=1)
43 |
44 | def show_image(rgb, segmap):
45 | """
46 | Overlay rgb image with segmentation and imshow segment
47 |
48 | Arguments:
49 | rgb {np.ndarray} -- color image
50 | segmap {np.ndarray} -- integer segmap of same size as rgb
51 | """
52 | plt.figure()
53 | figManager = plt.get_current_fig_manager()
54 |
55 | plt.ion()
56 | plt.show()
57 |
58 | if rgb is not None:
59 | plt.imshow(rgb)
60 | if segmap is not None:
61 | cmap = plt.get_cmap('rainbow')
62 | cmap.set_under(alpha=0.0)
63 | plt.imshow(segmap, cmap=cmap, alpha=0.5, vmin=0.0001)
64 | plt.draw()
65 | plt.pause(0.001)
66 |
67 | def visualize_grasps(full_pc, pred_grasps_cam, scores, plot_opencv_cam=False, pc_colors=None, gripper_openings=None, gripper_width=0.08):
68 | """Visualizes colored point cloud and predicted grasps. If given, colors grasps by segmap regions.
69 | Thick grasp is most confident per segment. For scene point cloud predictions, colors grasps according to confidence.
70 |
71 | Arguments:
72 | full_pc {np.ndarray} -- Nx3 point cloud of the scene
73 | pred_grasps_cam {dict[int:np.ndarray]} -- Predicted 4x4 grasp trafos per segment or for whole point cloud
74 | scores {dict[int:np.ndarray]} -- Confidence scores for grasps
75 |
76 | Keyword Arguments:
77 | plot_opencv_cam {bool} -- plot camera coordinate frame (default: {False})
78 | pc_colors {np.ndarray} -- Nx3 point cloud colors (default: {None})
79 | gripper_openings {dict[int:np.ndarray]} -- Predicted grasp widths (default: {None})
80 |         gripper_width {float} -- If gripper_openings is None, plot all grasps at this width (default: {0.08})
81 | """
82 |
83 | print('Visualizing...takes time')
84 | cm = plt.get_cmap('rainbow')
85 | cm2 = plt.get_cmap('gist_rainbow')
86 |
87 | fig = mlab.figure('Pred Grasps')
88 | mlab.view(azimuth=180, elevation=180, distance=0.2)
89 | draw_pc_with_colors(full_pc, pc_colors)
90 | colors = [cm(1. * i/len(pred_grasps_cam))[:3] for i in range(len(pred_grasps_cam))]
91 | colors2 = {k:cm2(0.5*np.max(scores[k]))[:3] for k in pred_grasps_cam if np.any(pred_grasps_cam[k])}
92 |
93 | if plot_opencv_cam:
94 | plot_coordinates(np.zeros(3,),np.eye(3,3))
95 | for i,k in enumerate(pred_grasps_cam):
96 | if np.any(pred_grasps_cam[k]):
97 | gripper_openings_k = np.ones(len(pred_grasps_cam[k]))*gripper_width if gripper_openings is None else gripper_openings[k]
98 | if len(pred_grasps_cam) > 1:
99 | draw_grasps(pred_grasps_cam[k], np.eye(4), color=colors[i], gripper_openings=gripper_openings_k)
100 | draw_grasps([pred_grasps_cam[k][np.argmax(scores[k])]], np.eye(4), color=colors2[k],
101 | gripper_openings=[gripper_openings_k[np.argmax(scores[k])]], tube_radius=0.0025)
102 | else:
103 | colors3 = [cm2(0.5*score)[:3] for score in scores[k]]
104 | draw_grasps(pred_grasps_cam[k], np.eye(4), colors=colors3, gripper_openings=gripper_openings_k)
105 | mlab.show()
106 |
107 | def draw_pc_with_colors(pc, pc_colors=None, single_color=(0.3,0.3,0.3), mode='2dsquare', scale_factor=0.0018):
108 | """
109 | Draws colored point clouds
110 |
111 | Arguments:
112 | pc {np.ndarray} -- Nx3 point cloud
113 | pc_colors {np.ndarray} -- Nx3 point cloud colors
114 |
115 | Keyword Arguments:
116 | single_color {tuple} -- single color for point cloud (default: {(0.3,0.3,0.3)})
117 |         mode {str} -- primitive type to plot (default: {'2dsquare'})
118 |         scale_factor {float} -- Scale of primitives. Does not work for points. (default: {0.0018})
119 |
120 | """
121 |
122 | if pc_colors is None:
123 | mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2], color=single_color, scale_factor=scale_factor, mode=mode)
124 | else:
125 | #create direct grid as 256**3 x 4 array
126 | def create_8bit_rgb_lut():
127 | xl = np.mgrid[0:256, 0:256, 0:256]
128 | lut = np.vstack((xl[0].reshape(1, 256**3),
129 | xl[1].reshape(1, 256**3),
130 | xl[2].reshape(1, 256**3),
131 | 255 * np.ones((1, 256**3)))).T
132 | return lut.astype('int32')
133 |
134 | scalars = pc_colors[:,0]*256**2 + pc_colors[:,1]*256 + pc_colors[:,2]
135 | rgb_lut = create_8bit_rgb_lut()
136 | points_mlab = mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2], scalars, mode=mode, scale_factor=.0018)
137 | points_mlab.glyph.scale_mode = 'scale_by_vector'
138 | points_mlab.module_manager.scalar_lut_manager.lut._vtk_obj.SetTableRange(0, rgb_lut.shape[0])
139 | points_mlab.module_manager.scalar_lut_manager.lut.number_of_colors = rgb_lut.shape[0]
140 | points_mlab.module_manager.scalar_lut_manager.lut.table = rgb_lut
141 |
142 | def draw_grasps(grasps, cam_pose, gripper_openings, color=(0,1.,0), colors=None, show_gripper_mesh=False, tube_radius=0.0008):
143 | """
144 | Draws wireframe grasps from given camera pose and with given gripper openings
145 |
146 | Arguments:
147 | grasps {np.ndarray} -- Nx4x4 grasp pose transformations
148 | cam_pose {np.ndarray} -- 4x4 camera pose transformation
149 | gripper_openings {np.ndarray} -- Nx1 gripper openings
150 |
151 | Keyword Arguments:
152 | color {tuple} -- color of all grasps (default: {(0,1.,0)})
153 | colors {np.ndarray} -- Nx3 color of each grasp (default: {None})
154 | tube_radius {float} -- Radius of the grasp wireframes (default: {0.0008})
155 | show_gripper_mesh {bool} -- Renders the gripper mesh for one of the grasp poses (default: {False})
156 | """
157 |
158 | gripper = mesh_utils.create_gripper('panda')
159 | gripper_control_points = gripper.get_control_point_tensor(1, False, convex_hull=False).squeeze()
160 | mid_point = 0.5*(gripper_control_points[1, :] + gripper_control_points[2, :])
161 | grasp_line_plot = np.array([np.zeros((3,)), mid_point, gripper_control_points[1], gripper_control_points[3],
162 | gripper_control_points[1], gripper_control_points[2], gripper_control_points[4]])
163 |
164 | if show_gripper_mesh and len(grasps) > 0:
165 | plot_mesh(gripper.hand, cam_pose, grasps[0])
166 |
167 | all_pts = []
168 | connections = []
169 | index = 0
170 | N = 7
171 | for i,(g,g_opening) in enumerate(zip(grasps, gripper_openings)):
172 | gripper_control_points_closed = grasp_line_plot.copy()
173 | gripper_control_points_closed[2:,0] = np.sign(grasp_line_plot[2:,0]) * g_opening/2
174 |
175 | pts = np.matmul(gripper_control_points_closed, g[:3, :3].T)
176 | pts += np.expand_dims(g[:3, 3], 0)
177 | pts_homog = np.concatenate((pts, np.ones((7, 1))),axis=1)
178 | pts = np.dot(pts_homog, cam_pose.T)[:,:3]
179 |
180 | color = color if colors is None else colors[i]
181 |
182 | all_pts.append(pts)
183 | connections.append(np.vstack([np.arange(index, index + N - 1.5),
184 | np.arange(index + 1, index + N - .5)]).T)
185 | index += N
186 | # mlab.plot3d(pts[:, 0], pts[:, 1], pts[:, 2], color=color, tube_radius=tube_radius, opacity=1.0)
187 |
188 | # speeds up plot3d because only one vtk object
189 | all_pts = np.vstack(all_pts)
190 | connections = np.vstack(connections)
191 | src = mlab.pipeline.scalar_scatter(all_pts[:,0], all_pts[:,1], all_pts[:,2])
192 | src.mlab_source.dataset.lines = connections
193 | src.update()
194 |     lines = mlab.pipeline.tube(src, tube_radius=tube_radius, tube_sides=12)
195 | mlab.pipeline.surface(lines, color=color, opacity=1.0)
196 |
197 |
--------------------------------------------------------------------------------
/pogs/data/depth_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Depth dataset.
17 | """
18 |
19 | import json
20 | from pathlib import Path
21 | from typing import Dict, Union
22 |
23 | import numpy as np
24 | import torch
25 | from PIL import Image
26 | from rich.progress import track
27 |
28 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
29 | from nerfstudio.data.datasets.base_dataset import InputDataset
30 | from nerfstudio.data.utils.data_utils import get_depth_image_from_path
31 | from nerfstudio.model_components import losses
32 | from nerfstudio.utils.misc import torch_compile
33 | from nerfstudio.utils.rich_utils import CONSOLE
34 |
35 |
36 | class DepthDataset(InputDataset):
37 | """Dataset that returns images and depths. If no depths are found, then we generate them with Zoe Depth.
38 |
39 | Args:
40 | dataparser_outputs: description of where and how to read input images.
41 | scale_factor: The scaling factor for the dataparser outputs.
42 | """
43 |
44 | def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0):
45 | super().__init__(dataparser_outputs, scale_factor)
46 |         # if there are no depth images then we want to generate them all with Zoe Depth
47 | if len(dataparser_outputs.image_filenames) > 0 and (
48 | "depth_filenames" not in dataparser_outputs.metadata.keys()
49 | or dataparser_outputs.metadata["depth_filenames"] is None
50 | ):
51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 | CONSOLE.print("[bold yellow] No depth data found!")
53 | losses.FORCE_PSEUDODEPTH_LOSS = True
54 |             CONSOLE.print("[bold red] Using pseudodepth: forcing depth loss to be ranking loss.")
55 | cache = dataparser_outputs.image_filenames[0].parent / "depths.npy"
56 | # Note: this should probably be saved to disk as images, and then loaded with the dataparser.
57 | # That will allow multi-gpu training.
58 | if cache.exists():
59 |                 CONSOLE.print("[bold yellow] Loading pseudodepth from cache!")
60 | # load all the depths
61 | self.depths = np.load(cache)
62 | self.depths = torch.from_numpy(self.depths).to(device)
63 | else:
64 | CONSOLE.print("[bold yellow] No cache found...")
65 | dataparser_outputs.metadata["depth_filenames"] = None
66 | dataparser_outputs.metadata["depth_unit_scale_factor"] = 1.0
67 | self.metadata["depth_filenames"] = None
68 | self.metadata["depth_unit_scale_factor"] = 1.0
69 |
70 | self.depth_filenames = self.metadata["depth_filenames"]
71 | self.depth_unit_scale_factor = self.metadata["depth_unit_scale_factor"]
72 |
73 | def get_metadata(self, data: Dict) -> Dict:
74 | if self.depth_filenames is None:
75 | return {}
76 |
77 | filepath = self.depth_filenames[data["image_idx"]]
78 | height = int(self._dataparser_outputs.cameras.height[data["image_idx"]])
79 | width = int(self._dataparser_outputs.cameras.width[data["image_idx"]])
80 |
81 | # Scale depth images to meter units and also by scaling applied to cameras
82 | scale_factor = self.depth_unit_scale_factor * self._dataparser_outputs.dataparser_scale
83 | depth_image = get_depth_image_from_path(
84 | filepath=filepath, height=height, width=width, scale_factor=scale_factor
85 | )
86 |
87 | return {"depth_image": depth_image}
88 |
89 | def _find_transform(self, image_path: Path) -> Union[Path, None]:
90 | while image_path.parent != image_path:
91 | transform_path = image_path.parent / "transforms.json"
92 | if transform_path.exists():
93 | return transform_path
94 | image_path = image_path.parent
95 | return None
--------------------------------------------------------------------------------
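For reference, the depth scaling applied in DepthDataset.get_metadata is the product of the dataset's depth unit and the dataparser's scene scaling. A minimal sketch of that conversion (not part of the repository), with hypothetical values and file path:

# Sketch of the depth scaling used in DepthDataset.get_metadata (hypothetical values and path).
from pathlib import Path
from nerfstudio.data.utils.data_utils import get_depth_image_from_path

depth_unit_scale_factor = 1e-3   # e.g. depth PNGs stored in millimeters -> meters
dataparser_scale = 0.5           # whatever scaling the dataparser applied to the cameras
scale_factor = depth_unit_scale_factor * dataparser_scale

depth_image = get_depth_image_from_path(
    filepath=Path("depths/frame_0001.png"),  # hypothetical depth file
    height=1080,
    width=1920,
    scale_factor=scale_factor,
)
print(depth_image.shape)  # (1080, 1920, 1) tensor, depth in scene units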
/pogs/data/utils/detic_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 |
3 | import argparse
4 | import json
5 | import os
6 | import pickle
7 | import random
8 | from typing import List, Dict, Tuple
9 | import sys
10 | import time
11 | import cv2
12 | from pogs.data.utils.feature_dataloader import FeatureDataloader
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | from matplotlib.animation import FuncAnimation
16 | import matplotlib.colors as mcolors
17 | # from mpl_toolkits.mplot3d import Axes3D
18 | import open3d as o3d
19 | from PIL import Image
20 | import torch
21 | from tqdm import tqdm
22 |
23 | # Change the current working directory to 'Detic'
24 | dir_path = os.path.dirname(os.path.realpath(__file__))
25 | cwd = os.getcwd()
26 | os.chdir(dir_path+'/../../dependencies/Detic')
27 |
28 | # Setup detectron2 logger
29 | import detectron2
30 | from detectron2.utils.logger import setup_logger
31 | setup_logger()
32 |
33 | # import common libraries
34 | sys.path.insert(0, os.getcwd()+'/third_party/CenterNet2/')
35 |
36 | # import some common detectron2 utilities
37 | from detectron2 import model_zoo
38 | from detectron2.engine import DefaultPredictor
39 | from detectron2.config import get_cfg
40 | from detectron2.utils.visualizer import Visualizer
41 | from detectron2.data import MetadataCatalog, DatasetCatalog
42 |
43 | # Detic libraries
44 | from collections import defaultdict
45 | from centernet.config import add_centernet_config
46 | from pogs.dependencies.Detic.detic.config import add_detic_config
47 | from pogs.dependencies.Detic.detic.modeling.utils import reset_cls_test
48 | from sklearn.cluster import DBSCAN
49 | import matplotlib.patches as patches
50 | import torch.nn.functional as F
51 |
52 | os.chdir(cwd)
53 |
54 | class DeticDataloader(FeatureDataloader):
55 | def __init__(
56 | self,
57 | cfg: dict,
58 | device: torch.device,
59 | image_list: torch.Tensor = None,
60 | cache_path: str = None,
61 | ):
62 | # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63 | self.sam = cfg["sam"]
64 | self.cstm_vocab = cfg["custom_vocab"]
65 | self.downscale_factor = cfg["downscale_factor"]
66 | # image_list: torch.Tensor = None,
67 |         self.outs = [[]]  # per-image mask outputs, consumed as self.outs[0] in create()
68 | super().__init__(cfg, device, image_list, cache_path)
69 |
70 | def create(self, image_list = None):
71 | # os.makedirs(self.cache_path, exist_ok=True)
72 | # Build the detector and download our pretrained weights
73 | cfg = get_cfg()
74 | add_centernet_config(cfg)
75 | add_detic_config(cfg)
76 | cfg.merge_from_file(dir_path+'/../../dependencies/Detic/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml')
77 | cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth'
78 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
79 | cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'
80 | cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True # For better visualization purpose. Set to False for all classes.
81 | # cfg.MODEL.DEVICE='cpu' # uncomment this to use cpu-only mode.
82 | os.chdir(dir_path+'/../../dependencies/Detic')
83 | self.detic_predictor = DefaultPredictor(cfg)
84 | os.chdir(cwd)
85 | if self.sam == True:
86 | from segment_anything import sam_model_registry, SamPredictor
87 | sam_checkpoint = "../sam_model/sam_vit_h_4b8939.pth"
88 | model_type = "vit_h"
89 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
90 | sam.to(device=self.device)
91 | print('SAM + Detic on device: ', self.device)
92 | self.sam_predictor = SamPredictor(sam)
93 | if len(self.cstm_vocab) > 0:
94 | print("Using custom vocabulary with classes: ", self.cstm_vocab)
95 | self.custom_vocab(self.cstm_vocab)
96 | else:
97 | print("Using default vocabulary")
98 | os.chdir(dir_path+'/../../dependencies/Detic')
99 | self.default_vocab()
100 | os.chdir(cwd)
101 |
102 | if image_list is not None:
103 |
104 | start_time = time.time()
105 | for idx, img in enumerate(tqdm(image_list, desc="Detic Detector", leave=False)):
106 | H, W = img.shape[-2:]
107 | img = (img.permute(1, 2, 0).numpy()*255).astype(np.uint8)
108 |
109 | output = self.detic_predictor(img[:, :, ::-1])
110 | instances = output["instances"].to('cpu')
111 |
112 | boxes = instances.pred_boxes.tensor.numpy()
113 |
114 | masks = None
115 | components = torch.zeros(H, W)
116 |                 if self.sam and len(boxes) > 0:
117 |                     # Only run SAM if there are bboxes; when SAM is enabled but Detic finds
118 |                     # no boxes, fall through to the Detic-mask branch below.
119 |                     masks = self.SAM(img, boxes)
120 |                     for i in range(masks.shape[0]):
121 |                         if torch.sum(masks[i][0]) <= H*W/3.5:
122 |                             components[masks[i][0]] = i + 1
123 | else:
124 | masks = output['instances'].pred_masks.unsqueeze(1)
125 | for i in range(masks.shape[0]):
126 | if torch.sum(masks[i][0]) <= H*W/3.5:
127 | components[masks[i][0]] = i + 1
128 | bg_mask = (components == 0).to(self.device)
129 |
130 | # Erode all masks using 3x3 kernel
131 | eroded_masks = torch.conv2d(
132 | (~masks).float().cuda(),
133 | torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"),
134 | padding=1,
135 | )
136 | eroded_masks = ~(eroded_masks >= 2)
137 |
138 | # Filter out small masks
139 | filtered_idx = []
140 | for i in range(len(masks)):
141 | if masks[i].sum(dim=(1,2)) <= H*W/3.5:
142 | filtered_idx.append(i)
143 | filtered_masks = torch.cat([eroded_masks[filtered_idx], bg_mask.unsqueeze(0).unsqueeze(0)], dim=0).cpu().numpy()
144 |
145 | if self.downscale_factor > 1:
146 | scaled_height = H//self.downscale_factor
147 | scaled_width = W//self.downscale_factor
148 | filtered_masks = F.interpolate(torch.from_numpy(filtered_masks).to(float), (scaled_height, scaled_width), mode = 'nearest').to(bool).squeeze(1).view(-1, scaled_height*scaled_width)
149 | filtered_masks = filtered_masks.numpy()
150 |
151 | outputs = {
152 | # "vis": out,
153 | "boxes": boxes,
154 | "masks": masks,
155 | "masks_filtered": filtered_masks,
156 | # "class_idx": class_idx,
157 | # "class_name": class_name,
158 | # "clip_embeds": clip_embeds,
159 | "components": components,
160 | "scores" : output["instances"].scores,
161 | }
162 |
163 | self.outs[0].append(outputs['masks_filtered'])
164 |
165 | self.data = np.empty(len(image_list), dtype=object)
166 | self.data[:] = self.outs[0]
167 |
168 | print("Detic batch inference time: ", time.time() - start_time)
169 |
170 | # Overridden load method
171 | def load(self):
172 | cache_info_path = self.cache_path.with_suffix(".info")
173 | # print(cache_info_path)
174 | # print(cache_info_path.exists())
175 | # import pdb; pdb.set_trace()
176 | if not cache_info_path.exists():
177 | raise FileNotFoundError
178 | with open(cache_info_path, "r") as f:
179 | cfg = json.loads(f.read())
180 | if cfg != self.cfg:
181 | raise ValueError("Config mismatch")
182 | self.data = np.load(self.cache_path, allow_pickle=True)
183 |
184 | # Overridden save method
185 | def save(self):
186 | os.makedirs(self.cache_path.parent, exist_ok=True)
187 | cache_info_path = self.cache_path.with_suffix(".info")
188 | with open(cache_info_path, "w") as f:
189 | f.write(json.dumps(self.cfg))
190 | np.save(self.cache_path, self.data)
191 |
192 | def __call__(self, img_idx):
193 |         raise NotImplementedError  # per-image masks are accessed via self.data, not per-pixel calls
194 |
195 | def default_vocab(self):
196 | # detic_predictor = self.detic_predictor
197 | # Setup the model's vocabulary using build-in datasets
198 | BUILDIN_CLASSIFIER = {
199 | 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
200 | 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
201 | 'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
202 | 'coco': 'datasets/metadata/coco_clip_a+cname.npy',
203 | }
204 |
205 | BUILDIN_METADATA_PATH = {
206 | 'lvis': 'lvis_v1_val',
207 | 'objects365': 'objects365_v2_val',
208 | 'openimages': 'oid_val_expanded',
209 | 'coco': 'coco_2017_val',
210 | }
211 |
212 | vocabulary = 'lvis' # change to 'lvis', 'objects365', 'openimages', or 'coco'
213 | self.metadata = MetadataCatalog.get(BUILDIN_METADATA_PATH[vocabulary])
214 | classifier = BUILDIN_CLASSIFIER[vocabulary]
215 |
216 | num_classes = len(self.metadata.thing_classes)
217 | reset_cls_test(self.detic_predictor.model, classifier, num_classes)
218 |
219 | def custom_vocab(self, classes):
220 | os.chdir(dir_path+'/../../dependencies/Detic')
221 | self.metadata = MetadataCatalog.get("__unused2")
222 | os.chdir(cwd)
223 | self.metadata.thing_classes = classes
224 | classifier = self.get_clip_embeddings(self.metadata.thing_classes)
225 | num_classes = len(self.metadata.thing_classes)
226 | reset_cls_test(self.detic_predictor.model, classifier, num_classes)
227 |
228 | # Reset visualization threshold
229 | output_score_threshold = 0.3
230 | for cascade_stages in range(len(self.detic_predictor.model.roi_heads.box_predictor)):
231 | self.detic_predictor.model.roi_heads.box_predictor[cascade_stages].test_score_thresh = output_score_threshold
232 |
233 |     def get_clip_embeddings(self, vocabulary, prompt='a '):
234 |         if not hasattr(self, "text_encoder"):  # lazily build Detic's CLIP text encoder (never set in __init__; import path assumed from Detic's demo)
235 |             from pogs.dependencies.Detic.detic.modeling.text.text_encoder import build_text_encoder
236 |             self.text_encoder = build_text_encoder(pretrain=True).eval()
237 |         return self.text_encoder([prompt + x for x in vocabulary]).detach().permute(1, 0).contiguous().cpu()
238 |
239 | def SAM(self, im, boxes, class_idx = None, metadata = None):
240 | self.sam_predictor.set_image(im)
241 | input_boxes = torch.tensor(boxes, device=self.sam_predictor.device)
242 | transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(input_boxes, im.shape[:2])
243 | masks, _, _ = self.sam_predictor.predict_torch(
244 | point_coords=None,
245 | point_labels=None,
246 | boxes=transformed_boxes,
247 | multimask_output=False,
248 | )
249 | return masks
250 |
251 | def visualize_detic(self, output):
252 | output_im = output.get_image()[:, :, ::-1]
253 | cv2.imshow("Detic Predictions", output_im)
254 | cv2.waitKey(0)
255 | cv2.destroyAllWindows()
256 |
257 | def predict(self, im):
258 |         if im is None:
259 |             raise ValueError("Unable to read the image file")
260 |
261 | H, W = im.shape[:2]
262 |
263 | # Run model and show results
264 | start_time = time.time()
265 | output = self.detic_predictor(im[:, :, ::-1]) # Detic expects BGR images.
266 | print("Inference time: ", time.time() - start_time)
267 | v = Visualizer(im, self.metadata)
268 | out = v.draw_instance_predictions(output["instances"].to('cpu'))
269 | instances = output["instances"].to('cpu')
270 | boxes = instances.pred_boxes.tensor.numpy()
271 | # class_idx = instances.pred_classes.numpy()
272 | # class_name = [self.metadata.thing_classes[idx] for idx in class_idx]
273 | # clip_embeds = self.get_clip_embeddings(class_name)
274 |
275 | masks = None
276 | components = torch.zeros(H, W)
277 |         if self.sam and len(boxes) > 0:
278 |             # Only run SAM if there are bboxes; when SAM is enabled but Detic finds
279 |             # no boxes, fall through to the Detic-mask branch below.
280 |             masks = self.SAM(im, boxes)
281 |             for i in range(masks.shape[0]):
282 |                 if torch.sum(masks[i][0]) <= H*W/3.5:
283 |                     components[masks[i][0]] = i + 1
284 | else:
285 | masks = output['instances'].pred_masks.unsqueeze(1)
286 | for i in range(masks.shape[0]):
287 | if torch.sum(masks[i][0]) <= H*W/3.5:
288 | components[masks[i][0]] = i + 1
289 | bg_mask = (components == 0).to(self.device)
290 |
291 | # Filter out small masks
292 | filtered_idx = []
293 | for i in range(len(masks)):
294 | if masks[i].sum(dim=(1,2)) <= H*W/3.5:
295 | filtered_idx.append(i)
296 |         filtered_masks = torch.cat([masks[filtered_idx].to(self.device), bg_mask.unsqueeze(0).unsqueeze(0)], dim=0)  # masks may be on CPU when SAM is disabled; move them to match bg_mask
297 |
298 | # invert_masks = ~filtered_masks
299 | # # erode all masks using 3x3 kernel
300 | # eroded_masks = torch.conv2d(
301 | # invert_masks.float(),
302 | # torch.full((3, 3), 1.0).view(1, 1, 3, 3).to("cuda"),
303 | # padding=1,
304 | # )
305 | # filtered_masks = ~(eroded_masks >= 5).squeeze(1) # (num_masks, H, W)
306 |
307 |
308 | outputs = {
309 | "vis": out,
310 | "boxes": boxes,
311 | "masks": masks,
312 | "masks_filtered": filtered_masks,
313 | # "class_idx": class_idx,
314 | # "class_name": class_name,
315 | # "clip_embeds": clip_embeds,
316 | "components": components,
317 | "scores" : output["instances"].scores,
318 | }
319 | return outputs
320 |
321 |
322 | def show_mask(self, mask, ax, random_color=False):
323 | if random_color:
324 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
325 | else:
326 | color = np.array([30/255, 144/255, 255/255, 0.6])
327 | h, w = mask.shape[-2:]
328 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
329 | ax.imshow(mask_image)
330 |
331 |
332 | def show_box(self, box, ax):
333 | x0, y0 = box[0], box[1]
334 | w, h = box[2] - box[0], box[3] - box[1]
335 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
336 |
337 |
338 | def visualize_output(self, im, masks, input_boxes, classes, image_save_path, mask_only=False):
339 | plt.figure(figsize=(10, 10))
340 | plt.imshow(im)
341 | for mask in masks:
342 | self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
343 | if not mask_only:
344 | for box, class_name in zip(input_boxes, classes):
345 | self.show_box(box, plt.gca())
346 | x, y = box[:2]
347 | plt.gca().text(x, y - 5, class_name, color='white', fontsize=12, fontweight='bold', bbox=dict(facecolor='green', edgecolor='green', alpha=0.5))
348 | plt.axis('off')
349 | plt.savefig(image_save_path)
350 | #plt.show()
351 |
352 |
353 | def generate_colors(self, num_colors):
354 | hsv_colors = []
355 | for i in range(num_colors):
356 | hue = i / float(num_colors)
357 | hsv_colors.append((hue, 1.0, 1.0))
358 |
359 | return [mcolors.hsv_to_rgb(color) for color in hsv_colors]
--------------------------------------------------------------------------------
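DeticDataloader reads three keys from its cfg dict ("sam", "custom_vocab", "downscale_factor") and caches one set of filtered masks per image. A configuration sketch (not part of the repository), assuming a CUDA device since the mask erosion is hard-coded to CUDA; SAM is left disabled so no SAM checkpoint is needed, and the cache path is hypothetical:

# Hypothetical DeticDataloader configuration (CUDA assumed; Detic weights download on first use).
from pathlib import Path
import torch
from pogs.data.utils.detic_dataloader import DeticDataloader

cfg = {
    "sam": False,           # use Detic's own instance masks instead of SAM
    "custom_vocab": [],     # empty -> default LVIS vocabulary; a list of class names switches to CLIP text embeddings
    "downscale_factor": 2,  # masks are downsampled by this factor before caching
}
device = torch.device("cuda:0")
image_list = torch.rand(4, 3, 480, 640)  # (N, 3, H, W) images in [0, 1], on CPU for the .numpy() call in create()

loader = DeticDataloader(cfg, device, image_list, cache_path=Path("outputs/detic_masks.npy"))
masks_per_image = loader.data            # object array with the filtered (object + background) masks per image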
/pogs/data/utils/dino_dataloader.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import torch
4 | from pogs.data.utils.dino_extractor import ViTExtractor
5 | from pogs.data.utils.feature_dataloader import FeatureDataloader
6 | from tqdm import tqdm
7 | from torchvision import transforms
8 | from typing import Tuple
9 | import numpy as np
10 |
11 | #usually 1260 max size, 1050 for vit-L with ROI
12 | MAX_DINO_SIZE = 1260
13 | def get_img_resolution(H, W, p=14):
14 | if H
[lines 15-101 of dino_dataloader.py are missing from this dump; the stray fragment "torch.Tensor:" is the tail of the truncated method signature whose docstring and body follow below]
102 | """
103 | returns BxHxWxC
104 | """
105 | return self.data[img_ind].to(self.device)
--------------------------------------------------------------------------------
/pogs/data/utils/feature_dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import typing
4 | from abc import ABC, ABCMeta, abstractmethod
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class FeatureDataloader(ABC):
12 | def __init__(
13 | self,
14 | cfg: dict,
15 | device: torch.device,
16 | image_list: torch.Tensor, # (N, 3, H, W)
17 | cache_path: Path,
18 | ):
19 | self.cfg = cfg
20 | self.device = device
21 | self.cache_path = cache_path
22 | self.data = None # only expect data to be cached, nothing else
23 | self.try_load(image_list) # don't save image_list, avoid duplicates
24 |
25 | @abstractmethod
26 | def __call__(self, img_points):
27 | pass
28 |
29 | @abstractmethod
30 | def create(self, image_list: torch.Tensor):
31 | pass
32 |
33 | def load(self):
34 | cache_info_path = self.cache_path.with_suffix(".info")
35 | print(cache_info_path)
36 | import os
37 | print(os.getcwd())
38 | if not cache_info_path.exists():
39 | raise FileNotFoundError
40 | with open(cache_info_path, "r") as f:
41 | cfg = json.loads(f.read())
42 | if cfg != self.cfg:
43 | raise ValueError("Config mismatch")
44 | self.data = torch.from_numpy(np.load(self.cache_path)).to(self.device)
45 |
46 | def save(self):
47 | os.makedirs(self.cache_path.parent, exist_ok=True)
48 | cache_info_path = self.cache_path.with_suffix(".info")
49 | with open(cache_info_path, "w") as f:
50 | f.write(json.dumps(self.cfg))
51 | np.save(self.cache_path, self.data.numpy(force=True))
52 |
53 | def try_load(self, img_list: torch.Tensor):
54 | try:
55 | self.load()
56 | except (FileNotFoundError, ValueError):
57 | self.create(img_list)
58 | self.save()
--------------------------------------------------------------------------------
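FeatureDataloader defines a small caching contract: subclasses implement create() and __call__(), and try_load() either restores the cached array (when the saved .info config matches) or regenerates and saves it. A toy subclass illustrating that contract (purely illustrative, not part of POGS):

# Toy FeatureDataloader subclass: caches a per-image mean color as the "feature".
from pathlib import Path
import torch
from pogs.data.utils.feature_dataloader import FeatureDataloader

class MeanColorDataloader(FeatureDataloader):
    def create(self, image_list: torch.Tensor):
        # (N, 3, H, W) -> (N, 3) mean color per image; save() writes it to cache_path.
        self.data = image_list.mean(dim=(2, 3)).to(self.device)

    def __call__(self, img_points):
        # img_points: (B, 3) rows of (img_ind, x, y); this toy feature ignores x and y.
        return self.data[img_points[:, 0].long()]

images = torch.rand(4, 3, 32, 32)
loader = MeanColorDataloader({"name": "mean_color"}, torch.device("cpu"), images,
                             Path("outputs/mean_color_cache.npy"))
feats = loader(torch.tensor([[0, 5, 7], [2, 10, 3]]))  # (2, 3)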
/pogs/data/utils/patch_embedding_dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import torch
5 | from pogs.data.utils.feature_dataloader import FeatureDataloader
6 | from pogs.encoders.image_encoder import BaseImageEncoder
7 | from tqdm import tqdm
8 |
9 |
10 | class PatchEmbeddingDataloader(FeatureDataloader):
11 | def __init__(
12 | self,
13 | cfg: dict,
14 | device: torch.device,
15 | model: BaseImageEncoder,
16 | image_list: torch.Tensor = None,
17 | cache_path: str = None,
18 | ):
19 | assert "tile_ratio" in cfg
20 | assert "stride_ratio" in cfg
21 | assert "image_shape" in cfg
22 | assert "model_name" in cfg
23 |
24 | self.kernel_size = int(cfg["image_shape"][0] * cfg["tile_ratio"])
25 | self.stride = int(self.kernel_size * cfg["stride_ratio"])
26 | self.padding = self.kernel_size // 2
27 | self.center_x = (
28 | (self.kernel_size - 1) / 2
29 | - self.padding
30 | + self.stride
31 | * np.arange(
32 | np.floor((cfg["image_shape"][0] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
33 | )
34 | )
35 | self.center_y = (
36 | (self.kernel_size - 1) / 2
37 | - self.padding
38 | + self.stride
39 | * np.arange(
40 | np.floor((cfg["image_shape"][1] + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
41 | )
42 | )
43 | self.center_x = torch.from_numpy(self.center_x).half()
44 | self.center_y = torch.from_numpy(self.center_y).half()
45 | self.start_x = self.center_x[0].float()
46 | self.start_y = self.center_y[0].float()
47 |
48 | self.model = model
49 | self.embed_size = self.model.embedding_dim
50 | super().__init__(cfg, device, image_list, cache_path)
51 |
52 | def load(self):
53 | cache_info_path = self.cache_path.with_suffix(".info")
54 | if not cache_info_path.exists():
55 | raise FileNotFoundError
56 | with open(cache_info_path, "r") as f:
57 | cfg = json.loads(f.read())
58 | if cfg != self.cfg:
59 | raise ValueError("Config mismatch")
60 | self.data = torch.from_numpy(np.load(self.cache_path)).half().to(self.device)
61 |
62 | def create(self, image_list):
63 | assert self.model is not None, "model must be provided to generate features"
64 | assert image_list is not None, "image_list must be provided to generate features"
65 |
66 | unfold_func = torch.nn.Unfold(
67 | kernel_size=self.kernel_size,
68 | stride=self.stride,
69 | padding=self.padding,
70 | ).to(self.device)
71 |
72 | img_embeds = []
73 | for img in tqdm(image_list, desc="CLIP Embedding Images", leave=False):
74 | img_embeds.append(self._embed_clip_tiles(img.unsqueeze(0), unfold_func))
75 | self.data = torch.from_numpy(np.stack(img_embeds)).half().to(self.device)
76 |
77 | def __call__(self, img_points):
78 | # img_points: (B, 3) # (img_ind, x, y) (img_ind, row, col)
79 | # return: (B, 512)
80 | img_points = img_points.cpu()
81 | img_ind, img_points_x, img_points_y = img_points[:, 0], img_points[:, 1], img_points[:, 2]
82 |
83 | x_ind = torch.floor((img_points_x - (self.start_x)) / self.stride).long()
84 | y_ind = torch.floor((img_points_y - (self.start_y)) / self.stride).long()
85 | return self._interp_inds(img_ind, x_ind, y_ind, img_points_x, img_points_y)
86 |
87 | def _interp_inds(self, img_ind, x_ind, y_ind, img_points_x, img_points_y):
88 | img_ind = img_ind.to(self.data.device) # self.data is on cpu to save gpu memory, hence this line
89 | topleft = self.data[img_ind, x_ind, y_ind].to(self.device)
90 | topright = self.data[img_ind, x_ind + 1, y_ind].to(self.device)
91 | botleft = self.data[img_ind, x_ind, y_ind + 1].to(self.device)
92 | botright = self.data[img_ind, x_ind + 1, y_ind + 1].to(self.device)
93 |
94 | x_stride = self.stride
95 | y_stride = self.stride
96 | right_w = ((img_points_x - (self.center_x[x_ind])) / x_stride).to(self.device) # .half()
97 | top = torch.lerp(topleft, topright, right_w[:, None])
98 | bot = torch.lerp(botleft, botright, right_w[:, None])
99 |
100 | bot_w = ((img_points_y - (self.center_y[y_ind])) / y_stride).to(self.device) # .half()
101 | return torch.lerp(top, bot, bot_w[:, None])
102 |
103 | def _embed_clip_tiles(self, image, unfold_func):
104 | # image augmentation: slow-ish (0.02s for 600x800 image per augmentation)
105 | aug_imgs = torch.cat([image])
106 |
107 | tiles = unfold_func(aug_imgs).permute(2, 0, 1).reshape(-1, 3, self.kernel_size, self.kernel_size).to("cuda")
108 |
109 | with torch.no_grad():
110 | clip_embeds = self.model.encode_image(tiles)
111 | clip_embeds /= clip_embeds.norm(dim=-1, keepdim=True)
112 |
113 | clip_embeds = clip_embeds.reshape((self.center_x.shape[0], self.center_y.shape[0], -1))
114 | clip_embeds = torch.concat((clip_embeds, clip_embeds[:, [-1], :]), dim=1)
115 | clip_embeds = torch.concat((clip_embeds, clip_embeds[[-1], :, :]), dim=0)
116 | return clip_embeds.detach().cpu().numpy()
--------------------------------------------------------------------------------
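PatchEmbeddingDataloader tiles each image with a sliding CLIP window and answers pixel queries by bilinearly interpolating between tile centers. A usage sketch (not part of the repository), assuming the OpenCLIP encoder defined elsewhere in this repo, a CUDA device, and a hypothetical cache path:

# Hypothetical PatchEmbeddingDataloader usage (CUDA + open_clip assumed).
from pathlib import Path
import torch
from pogs.data.utils.patch_embedding_dataloader import PatchEmbeddingDataloader
from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig

device = torch.device("cuda:0")
encoder = OpenCLIPNetworkConfig(device="cuda:0").setup()
images = torch.rand(2, 3, 480, 640, device=device)  # (N, 3, H, W)

loader = PatchEmbeddingDataloader(
    cfg={
        "tile_ratio": 0.1,    # kernel size = 0.1 * image height
        "stride_ratio": 0.5,  # stride = half the kernel size
        "image_shape": [480, 640],
        "model_name": encoder.name,
    },
    device=device,
    model=encoder,
    image_list=images,
    cache_path=Path("outputs/clip_patch_cache.npy"),
)
# Each query row is (img_ind, x, y); the result is one interpolated CLIP embedding per row.
embeds = loader(torch.tensor([[0, 100, 200], [0, 240, 320], [1, 50, 60]]))  # (3, 512)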
/pogs/data/utils/pyramid_embedding_dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | from pogs.data.utils.feature_dataloader import FeatureDataloader
8 | from pogs.data.utils.patch_embedding_dataloader import PatchEmbeddingDataloader
9 | from pogs.encoders.image_encoder import BaseImageEncoder
10 | from tqdm import tqdm
11 |
12 |
13 | class PyramidEmbeddingDataloader(FeatureDataloader):
14 | def __init__(
15 | self,
16 | cfg: dict,
17 | device: torch.device,
18 | model: BaseImageEncoder,
19 | image_list: torch.Tensor = None,
20 | cache_path: str = None,
21 | ):
22 | assert "tile_size_range" in cfg
23 | assert "tile_size_res" in cfg
24 | assert "stride_scaler" in cfg
25 | assert "image_shape" in cfg
26 | assert "model_name" in cfg
27 |
28 | self.tile_sizes = torch.linspace(*cfg["tile_size_range"], cfg["tile_size_res"]).to(device)
29 | self.strider_scaler_list = [self._stride_scaler(tr.item(), cfg["stride_scaler"]) for tr in self.tile_sizes]
30 |
31 | self.model = model
32 | self.embed_size = self.model.embedding_dim
33 | self.data_dict = {}
34 | super().__init__(cfg, device, image_list, cache_path)
35 |
36 | def __call__(self, img_points, scale=None):
37 | if scale is None:
38 | return self._random_scales(img_points)
39 | else:
40 | return self._uniform_scales(img_points, scale)
41 |
42 | def _stride_scaler(self, tile_ratio, stride_scaler):
43 | return np.interp(tile_ratio, [0.05, 0.15], [1.0, stride_scaler])
44 |
45 | def load(self):
46 | # don't create anything, PatchEmbeddingDataloader will create itself
47 | cache_info_path = self.cache_path.with_suffix(".info")
48 |
49 | # check if cache exists
50 | if not cache_info_path.exists():
51 | raise FileNotFoundError
52 |
53 | # if config is different, remove all cached content
54 | with open(cache_info_path, "r") as f:
55 | cfg = json.loads(f.read())
56 | if cfg != self.cfg:
57 | for f in os.listdir(self.cache_path):
58 | os.remove(os.path.join(self.cache_path, f))
59 | raise ValueError("Config mismatch")
60 |
61 | raise FileNotFoundError # trigger create
62 |
63 | def create(self, image_list):
64 | os.makedirs(self.cache_path, exist_ok=True)
65 | for i, tr in enumerate(tqdm(self.tile_sizes, desc="Scales")):
66 | stride_scaler = self.strider_scaler_list[i]
67 | self.data_dict[i] = PatchEmbeddingDataloader(
68 | cfg={
69 | "tile_ratio": tr.item(),
70 | "stride_ratio": stride_scaler,
71 | "image_shape": self.cfg["image_shape"],
72 | "model_name": self.cfg["model_name"],
73 | },
74 | device=self.device,
75 | model=self.model,
76 | image_list=image_list,
77 | cache_path=Path(f"{self.cache_path}/level_{i}.npy"),
78 | )
79 | print(image_list.shape)
80 |
81 | def save(self):
82 | cache_info_path = self.cache_path.with_suffix(".info")
83 | with open(cache_info_path, "w") as f:
84 | f.write(json.dumps(self.cfg))
85 | # don't save anything, PatchEmbeddingDataloader will save itself
86 | pass
87 |
88 | def _random_scales(self, img_points):
89 | # img_points: (B, 3) # (img_ind, x, y)
90 | # return: (B, 512), some random scale (between 0, 1)
91 | img_points = img_points.to(self.device)
92 | random_scale_bin = torch.randint(self.tile_sizes.shape[0] - 1, size=(img_points.shape[0],), device=self.device)
93 | random_scale_weight = torch.rand(img_points.shape[0], dtype=torch.float16, device=self.device)
94 |
95 | stepsize = (self.tile_sizes[1] - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0])
96 |
97 | bottom_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device)
98 | top_interp = torch.zeros((img_points.shape[0], self.embed_size), dtype=torch.float16, device=self.device)
99 |
100 | for i in range(len(self.tile_sizes) - 1):
101 | ids = img_points[random_scale_bin == i]
102 | bottom_interp[random_scale_bin == i] = self.data_dict[i](ids)
103 | top_interp[random_scale_bin == i] = self.data_dict[i + 1](ids)
104 |
105 | return (
106 | torch.lerp(bottom_interp, top_interp, random_scale_weight[..., None]),
107 | (random_scale_bin * stepsize + random_scale_weight * stepsize)[..., None],
108 | )
109 |
110 | def _uniform_scales(self, img_points, scale):
111 | # img_points: (B, 3) # (img_ind, x, y)
112 | scale_bin = torch.floor(
113 | (scale - self.tile_sizes[0]) / (self.tile_sizes[-1] - self.tile_sizes[0]) * (self.tile_sizes.shape[0] - 1)
114 | ).to(torch.int64)
115 | scale_weight = (scale - self.tile_sizes[scale_bin]) / (
116 | self.tile_sizes[scale_bin + 1] - self.tile_sizes[scale_bin]
117 | )
118 | interp_lst = torch.stack([interp(img_points) for interp in self.data_dict.values()])
119 | point_inds = torch.arange(img_points.shape[0])
120 | interp = torch.lerp(
121 | interp_lst[scale_bin, point_inds],
122 | interp_lst[scale_bin + 1, point_inds],
123 | torch.Tensor([scale_weight]).half().to(self.device)[..., None],
124 | )
125 | return interp / interp.norm(dim=-1, keepdim=True), scale
--------------------------------------------------------------------------------
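PyramidEmbeddingDataloader stacks one PatchEmbeddingDataloader per tile size and interpolates between adjacent scales. A configuration sketch (not part of the repository), reusing the encoder and images from the previous sketch; the cache directory is hypothetical:

# Hypothetical PyramidEmbeddingDataloader configuration (CUDA + open_clip assumed).
from pathlib import Path
import torch
from pogs.data.utils.pyramid_embedding_dataloader import PyramidEmbeddingDataloader
from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig

device = torch.device("cuda:0")
encoder = OpenCLIPNetworkConfig(device="cuda:0").setup()
images = torch.rand(2, 3, 480, 640, device=device)

pyramid = PyramidEmbeddingDataloader(
    cfg={
        "tile_size_range": [0.05, 0.25],  # smallest and largest tile ratio
        "tile_size_res": 3,               # number of pyramid levels
        "stride_scaler": 0.5,
        "image_shape": [480, 640],
        "model_name": encoder.name,
    },
    device=device,
    model=encoder,
    image_list=images,
    cache_path=Path("outputs/clip_pyramid_cache"),  # one level_{i}.npy per scale inside this directory
)
points = torch.tensor([[0, 100, 200], [1, 240, 320]], device=device)
embeds, scales = pyramid(points)               # random scale per point
embeds_fixed, _ = pyramid(points, scale=0.15)  # one fixed scale for all points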
/pogs/encoders/image_encoder.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod, abstractproperty
2 | from dataclasses import dataclass, field
3 | from typing import Type
4 |
5 | import torch
6 | from torch import nn
7 | from pogs.configs import base_config as cfg
8 |
9 | @dataclass
10 | class BaseImageEncoderConfig(cfg.InstantiateConfig):
11 | _target: Type = field(default_factory=lambda: BaseImageEncoder)
12 |
13 |
14 | class BaseImageEncoder(nn.Module):
15 | @abstractproperty
16 | def name(self) -> str:
17 | """
18 | returns the name of the encoder
19 | """
20 |
21 | @abstractproperty
22 | def embedding_dim(self) -> int:
23 | """
24 | returns the dimension of the embeddings
25 | """
26 |
27 | @abstractmethod
28 | def encode_image(self, input: torch.Tensor) -> torch.Tensor:
29 | """
30 | Given a batch of input images, return their encodings
31 | """
32 |
33 | @abstractmethod
34 | def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
35 | """
36 | Given a batch of embeddings, return the relevancy to the given positive id
37 | """
--------------------------------------------------------------------------------
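BaseImageEncoder is the interface the CLIP dataloaders above program against. A toy stand-in that satisfies it (purely illustrative, not part of POGS):

# Toy BaseImageEncoder implementation illustrating the abstract interface.
import torch
from pogs.encoders.image_encoder import BaseImageEncoder

class RandomProjectionEncoder(BaseImageEncoder):
    def __init__(self, dim: int = 512):
        super().__init__()
        self.dim = dim
        self.proj = torch.nn.Linear(3, dim)

    @property
    def name(self) -> str:
        return "random_projection"

    @property
    def embedding_dim(self) -> int:
        return self.dim

    def encode_image(self, input: torch.Tensor) -> torch.Tensor:
        # (B, 3, H, W) -> (B, dim): project the per-image mean color.
        return self.proj(input.mean(dim=(2, 3)))

    def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
        # Toy relevancy: uniform scores shaped (B, 2), mirroring OpenCLIPNetwork's output.
        return torch.full((embed.shape[0], 2), 0.5)

encoder = RandomProjectionEncoder()
feats = encoder.encode_image(torch.rand(4, 3, 64, 64))  # (4, 512)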
/pogs/encoders/openclip_encoder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Tuple, Type
3 |
4 | import torch
5 | import torchvision
6 |
7 | try:
8 | import open_clip
9 | except ImportError:
10 | assert False, "open_clip is not installed, install it with `pip install open-clip-torch`"
11 |
12 | from pogs.encoders.image_encoder import (BaseImageEncoder,
13 | BaseImageEncoderConfig)
14 | from nerfstudio.viewer.viewer_elements import ViewerText
15 |
16 |
17 | @dataclass
18 | class OpenCLIPNetworkConfig(BaseImageEncoderConfig):
19 | _target: Type = field(default_factory=lambda: OpenCLIPNetwork)
20 | clip_model_type: str = "ViT-B-16"
21 | clip_model_pretrained: str = "laion2b_s34b_b88k"
22 | clip_n_dims: int = 512
23 | negatives: Tuple[str] = ("object", "things", "stuff", "texture")
24 | device: str = 'cuda:0'
25 |
26 | @property
27 | def name(self) -> str:
28 | return "openclip_{}_{}".format(self.clip_model_type, self.clip_model_pretrained)
29 |
30 |
31 | class OpenCLIPNetwork(BaseImageEncoder):
32 | def __init__(self, config: OpenCLIPNetworkConfig):
33 | super().__init__()
34 | self.config = config
35 | self.process = torchvision.transforms.Compose(
36 | [
37 | torchvision.transforms.Resize((224, 224)),
38 | torchvision.transforms.Normalize(
39 | mean=[0.48145466, 0.4578275, 0.40821073],
40 | std=[0.26862954, 0.26130258, 0.27577711],
41 | ),
42 | ]
43 | )
44 | model, _, _ = open_clip.create_model_and_transforms(
45 | self.config.clip_model_type, # e.g., ViT-B-16
46 | pretrained=self.config.clip_model_pretrained, # e.g., laion2b_s34b_b88k
47 | precision="fp16",
48 | device=self.config.device,
49 | # device='cuda:1',
50 | )
51 | # model.to('cuda:1')
52 | model.eval()
53 | self.tokenizer = open_clip.get_tokenizer(self.config.clip_model_type)
54 | self.model = model.to(self.config.device)
55 | self.clip_n_dims = self.config.clip_n_dims
56 |
57 | self.positive_input = ViewerText("Positives", "", cb_hook=self.gui_cb)
58 |
59 | self.positives = self.positive_input.value.split(";")
60 | self.negatives = self.config.negatives
61 | with torch.no_grad():
62 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.config.device)
63 | self.pos_embeds = model.encode_text(tok_phrases)
64 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(self.config.device)
65 | self.neg_embeds = model.encode_text(tok_phrases)
66 | self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
67 | self.neg_embeds /= self.neg_embeds.norm(dim=-1, keepdim=True)
68 |
69 | assert (
70 | self.pos_embeds.shape[1] == self.neg_embeds.shape[1]
71 | ), "Positive and negative embeddings must have the same dimensionality"
72 | assert (
73 | self.pos_embeds.shape[1] == self.clip_n_dims
74 | ), "Embedding dimensionality must match the model dimensionality"
75 |
76 | @property
77 | def name(self) -> str:
78 | return "openclip_{}_{}".format(self.config.clip_model_type, self.config.clip_model_pretrained)
79 |
80 | @property
81 | def embedding_dim(self) -> int:
82 | return self.config.clip_n_dims
83 |
84 | def gui_cb(self,element):
85 | # element = element.to(self.config.device)
86 | self.set_positives(element.value.split(";"))
87 |
88 | def set_positives(self, text_list):
89 | self.positives = text_list
90 | with torch.no_grad():
91 | tok_phrases = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.config.device)
92 | self.pos_embeds = self.model.encode_text(tok_phrases)
93 | self.pos_embeds /= self.pos_embeds.norm(dim=-1, keepdim=True)
94 | self.pos_embeds = self.pos_embeds.to(self.config.device)
95 |
96 | def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
97 | phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0).to(self.config.device)
98 | p = phrases_embeds.to(embed.dtype) # phrases x 512
99 | embed = embed.to(p.device)
100 | output = torch.mm(embed, p.T) # rays x phrases
101 | positive_vals = output[..., positive_id : positive_id + 1] # rays x 1
102 | negative_vals = output[..., len(self.positives) :] # rays x N_phrase
103 | repeated_pos = positive_vals.repeat(1, len(self.negatives)) # rays x N_phrase
104 |
105 | sims = torch.stack((repeated_pos, negative_vals), dim=-1) # rays x N-phrase x 2
106 | softmax = torch.softmax(10 * sims, dim=-1) # rays x n-phrase x 2
107 |         best_id = softmax[..., 0].argmin(dim=1)  # (rays,) index of the hardest negative per ray
108 | return torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.negatives), 2))[
109 | :, 0, :
110 | ]
111 |
112 | def encode_image(self, input):
113 | processed_input = self.process(input).half()
114 | return self.model.encode_image(processed_input)
--------------------------------------------------------------------------------
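For reference, a minimal relevancy-query sketch with OpenCLIPNetwork (not part of the repository), assuming CUDA and open_clip are available; the query words are arbitrary:

# Hypothetical relevancy query with OpenCLIPNetwork (CUDA + open_clip assumed).
import torch
from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig, OpenCLIPNetwork

encoder = OpenCLIPNetworkConfig(device="cuda:0").setup()
assert isinstance(encoder, OpenCLIPNetwork)

encoder.set_positives(["drill", "scissors"])  # positive text queries
image_embeds = encoder.encode_image(torch.rand(8, 3, 224, 224, device="cuda:0"))
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

# relevancy[:, 0] is the softmax score of "drill" against its hardest canonical negative.
relevancy = encoder.get_relevancy(image_embeds.float(), positive_id=0)  # (8, 2)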
/pogs/field_components/gaussian_fieldheadnames.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | # from nerfstudio.field_components.field_heads import FieldHeadNames
3 |
4 | class GaussianFieldHeadNames(Enum):
5 | """Possible field outputs"""
6 | HASHGRID = "hashgrid"
7 | CLIP = "clip"
8 | # DINO = "dino"
9 | INSTANCE = "instance"
--------------------------------------------------------------------------------
/pogs/fields/gaussian_field.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Field for compound nerf model, adds scene contraction and image embeddings to instant ngp
17 | """
18 |
19 |
20 | from typing import Dict, Literal, Optional, Tuple
21 | import numpy as np
22 |
23 | import torch
24 | from torch import Tensor, nn
25 |
26 | from nerfstudio.cameras.rays import RaySamples
27 | from nerfstudio.data.scene_box import SceneBox
28 | from nerfstudio.field_components.activations import trunc_exp
29 | from nerfstudio.field_components.embedding import Embedding
30 | from nerfstudio.field_components.encodings import HashEncoding, NeRFEncoding, SHEncoding
31 | from pogs.field_components.gaussian_fieldheadnames import GaussianFieldHeadNames
32 |
33 | from nerfstudio.field_components.mlp import MLP
34 | from nerfstudio.field_components.spatial_distortions import SpatialDistortion, SceneContraction
35 | from nerfstudio.fields.base_field import Field, get_normalized_directions
36 |
37 | try:
38 | import tinycudann as tcnn
39 | except ImportError:
40 | pass
41 |
42 |
43 | class GaussianField(Field):
44 |     """Hashgrid-backed field that predicts CLIP and instance embeddings for the Gaussians (uses TCNN).
45 | 
46 |     Args:
47 |         num_levels: number of hashgrid levels (stored as a buffer)
48 |         base_res: base hashgrid resolution (unused; resolutions come from grid_resolutions)
49 |         max_res: maximum hashgrid resolution (stored as a buffer)
50 |         log2_hashmap_size: log2 size of the hashmap (stored as a buffer)
51 |         features_per_level: features per hashgrid level (unused; see n_features_level)
52 |         implementation: encoding implementation flag (currently unused; tcnn is always used)
53 |         grid_layers: number of levels for each hash encoding in the pyramid
54 |         grid_sizes: log2 hashmap size for each hash encoding
55 |         grid_resolutions: (base, max) resolution pair for each hash encoding
56 |         n_features_level: features per level for each hash encoding
57 |         clip_n_dims: dimensionality of the CLIP output head
58 |         feature_dims: auxiliary feature dimensionality (currently unused)
59 |     """
70 |
71 | def __init__(
72 | self,
73 | num_levels: int = 16,
74 | base_res: int = 16,
75 | max_res: int = 2048,
76 | log2_hashmap_size: int = 19,
77 | features_per_level: int = 2,
78 | implementation: Literal["tcnn", "torch"] = "tcnn",
79 |         grid_layers: Tuple[int, ...] = (12, 12),
80 |         grid_sizes: Tuple[int, ...] = (19, 19),
81 |         grid_resolutions: Tuple[Tuple[int, int], ...] = ((16, 128), (128, 512)),
82 | n_features_level: int = 4,
83 | clip_n_dims: int = 512,
84 |
85 | feature_dims: int = 64,
86 | ) -> None:
87 | super().__init__()
88 |
89 | self.spatial_distortion: SceneContraction = SceneContraction()
90 |
91 | self.register_buffer("max_res", torch.tensor(max_res))
92 | self.register_buffer("num_levels", torch.tensor(num_levels))
93 | self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size))
94 | self.clip_encs = torch.nn.ModuleList(
95 | [
96 | GaussianField._get_encoding(
97 | grid_resolutions[i][0], grid_resolutions[i][1], grid_layers[i], indim=3, hash_size=grid_sizes[i], features_per_level=n_features_level,
98 | ) for i in range(len(grid_layers))
99 | ]
100 | )
101 | tot_out_dims = sum([e.n_output_dims for e in self.clip_encs])
102 | instance_n_dims = 128
103 | print("Total output dims: ", tot_out_dims)
104 |
105 | self.clip_net = tcnn.Network(
106 | n_input_dims=tot_out_dims+1,
107 | n_output_dims=clip_n_dims,
108 | network_config={
109 | "otype": "CutlassMLP",
110 | "activation": "ReLU",
111 | "output_activation": "None",
112 | "n_neurons": 256,
113 | "n_hidden_layers": 3,
114 | },
115 | )
116 |
117 | self.instance_net = tcnn.Network(
118 | n_input_dims=tot_out_dims,
119 | n_output_dims=instance_n_dims,
120 | network_config={
121 | "otype": "CutlassMLP",
122 | "activation": "ReLU",
123 | "output_activation": "None",
124 | "n_neurons": 256,
125 | "n_hidden_layers": 4,
126 | },
127 | )
128 |
129 |
130 | @staticmethod
131 | def _get_encoding(start_res, end_res, levels, indim=3, hash_size=19, features_per_level=8):
132 | growth = np.exp((np.log(end_res) - np.log(start_res)) / (levels - 1))
133 | enc = tcnn.Encoding(
134 | n_input_dims=indim,
135 | encoding_config={
136 | "otype": "HashGrid",
137 | "n_levels": levels,
138 | "n_features_per_level": features_per_level,
139 | "log2_hashmap_size": hash_size,
140 | "base_resolution": start_res,
141 | "per_level_scale": growth,
142 | },
143 | )
144 | return enc
145 |
146 | def get_outputs(self, positions, clip_scales) -> Dict[GaussianFieldHeadNames, Tensor]:
147 | # random scales, one scale
148 | positions = self.spatial_distortion(positions)
149 | positions = (positions + 2.0) / 4.0
150 |
151 | outputs = {}
152 | xs = [e(positions.view(-1, 3)) for e in self.clip_encs]
153 | x = torch.concat(xs, dim=-1)
154 |
155 | outputs[GaussianFieldHeadNames.HASHGRID] = x.view(positions.shape[0], -1)
156 |
157 | clip_pass = self.clip_net(torch.cat([x, clip_scales.view(-1, 1)], dim=-1)).view(positions.shape[0], -1)
158 |
159 | outputs[GaussianFieldHeadNames.CLIP] = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32)
160 |
161 | epsilon = 1e-5
162 | instance_pass = self.instance_net(x).view(positions.shape[0], -1)
163 | outputs[GaussianFieldHeadNames.INSTANCE] = instance_pass / (instance_pass.norm(dim=-1, keepdim=True) + epsilon)
164 |
165 | return outputs
166 |
167 | def get_hash(self, positions) -> Tensor:
168 | positions = self.spatial_distortion(positions)
169 | positions = (positions + 2.0) / 4.0
170 |
171 | encodings = [e(positions.view(-1, 3)) for e in self.clip_encs]
172 | encoding = torch.concat(encodings, dim=-1)
173 |
174 | return encoding.to(torch.float32)
175 |
176 | def get_outputs_from_feature(self, features, clip_scale, random_pixels = None) -> Dict[GaussianFieldHeadNames, Tensor]:
177 | outputs = {}
178 |
179 | #clip_features is Nx32, and clip scale is a number, I want to cat clip scale to the end of clip_features where clip scale is an int
180 | if random_pixels is not None:
181 | clip_features = features[random_pixels]
182 | else:
183 | clip_features = features
184 | clip_pass = self.clip_net(torch.cat([clip_features, clip_scale.view(-1, 1)], dim=-1))
185 |
186 | outputs[GaussianFieldHeadNames.CLIP] = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32)
187 |
188 | epsilon = 1e-5
189 | instance_pass = self.instance_net(features)
190 | outputs[GaussianFieldHeadNames.INSTANCE] = instance_pass / (instance_pass.norm(dim=-1, keepdim=True) + epsilon)
191 |
192 | return outputs
193 |
194 | def get_instance_outputs_from_feature(self, features) -> Tensor:
195 | epsilon = 1e-5
196 | instance_pass = self.instance_net(features)
197 | outputs = instance_pass / (instance_pass.norm(dim=-1, keepdim=True) + epsilon)
198 |
199 | return outputs
200 |
201 | def get_clip_outputs_from_feature(self, features, clip_scale) -> Tensor:
202 | clip_pass = self.clip_net(torch.cat([features, clip_scale.view(-1, 1)], dim=-1))
203 | outputs = (clip_pass / clip_pass.norm(dim=-1, keepdim=True)).to(torch.float32)
204 | return outputs
--------------------------------------------------------------------------------
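For reference, a minimal sketch of querying GaussianField outside the training pipeline (not part of the repository); it assumes tinycudann and a CUDA device, and goes through get_hash plus get_outputs_from_feature so the inputs stay float32:

# Hypothetical GaussianField query (tinycudann + CUDA assumed).
import torch
from pogs.fields.gaussian_field import GaussianField
from pogs.field_components.gaussian_fieldheadnames import GaussianFieldHeadNames

field = GaussianField().to("cuda")
positions = torch.rand(1024, 3, device="cuda") * 2.0 - 1.0  # arbitrary scene points
clip_scales = torch.rand(1024, 1, device="cuda")            # per-point CLIP scale conditioning

# Contracted positions -> concatenated multi-resolution hash features (96-d with the default config).
hash_feats = field.get_hash(positions)

# Decode hash features into the unit-norm CLIP and instance heads.
outputs = field.get_outputs_from_feature(hash_feats, clip_scales)
clip_embeds = outputs[GaussianFieldHeadNames.CLIP]          # (1024, 512)
instance_embeds = outputs[GaussianFieldHeadNames.INSTANCE]  # (1024, 128)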
/pogs/grasping/generate_grasps_ply.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import json
4 | dir_path = os.path.dirname(os.path.realpath(__file__))
5 | contact_graspnet_path = os.path.join(dir_path,'../contact_graspnet_wrapper')
6 |
7 | sys.path.append(contact_graspnet_path)
8 | from prime_inference import modified_inference
9 | import argparse
10 | from prime_config_utils import load_config
11 | import numpy as np
12 | from prime_visualization_utils import visualize_grasps
13 | from autolab_core import RigidTransform
14 | import open3d as o3d
15 | tool_to_wrist = RigidTransform()
16 | # 0.1651 was the old measured distance from the suction tool tip to the wrist;
17 | # 0.1857375 for the parallel-jaw gripper
18 | tool_to_wrist.translation = np.array([0, 0, 0])
19 | tool_to_wrist.from_frame = "tool"
20 | tool_to_wrist.to_frame = "wrist"
21 |
22 | segmented_ply_filepath = "/home/lifelong/sms/sms/data/utils/Detic/outputs/2024_07_22_green_tape_bowl/prime_seg_gaussians.ply"
23 | full_ply_filepath = "/home/lifelong/sms/sms/data/utils/Detic/outputs/2024_07_22_green_tape_bowl/prime_full_gaussians.ply"
24 | bounding_box_filepath = "/home/lifelong/sms/sms/data/utils/Detic/2024_07_22_green_tape_bowl/table_bounding_cube.json"
25 |
26 |
27 | def filter_noise(points, colors=None):
28 | from sklearn.cluster import DBSCAN
29 | eps = 0.005
30 | min_samples = 10
31 | dbscan = DBSCAN(eps=eps, min_samples=min_samples)
32 | labels = dbscan.fit_predict(points)
33 | filtered_pointcloud = points[labels != -1]
34 | if colors is not None:
35 | filtered_colors = colors[labels != -1]
36 | else:
37 | filtered_colors = None
38 | return filtered_pointcloud, filtered_colors
39 |
40 | def generate_grasps(seg_np_path, full_np_path, pc_bounding_box_path, ckpt_dir, z_range, K, local_regions, filter_grasps, skip_border_objects, forward_passes, segmap_id, arg_configs, save_dir):
41 |
42 | global_config = load_config(ckpt_dir, batch_size=forward_passes, arg_configs=arg_configs)
43 |
44 | print(str(global_config))
45 | print('pid: %s'%(str(os.getpid())))
46 |
47 | pred_grasps_world, scores, contact_pts, points_world, pc_colors = modified_inference(global_config, ckpt_dir, seg_np_path, full_np_path,pc_bounding_box_path, z_range=z_range,
48 | K=K, local_regions=local_regions, filter_grasps=filter_grasps, segmap_id=segmap_id,
49 | forward_passes=forward_passes, skip_border_objects=skip_border_objects,debug=False)
50 | print("GENERATED GRASPS")
51 | sorted_idxs = np.argsort(scores[0])[::-1]
52 | best_scores = {0:scores[0][sorted_idxs][:1]}
53 | best_grasps = {0:pred_grasps_world[0][sorted_idxs][:1]}
54 | best_contact_pts = {0:contact_pts[0][sorted_idxs][:1]}
55 |
56 | point_cloud_world = o3d.geometry.PointCloud()
57 | point_cloud_world.points = o3d.utility.Vector3dVector(points_world)
58 | point_cloud_world.colors = o3d.utility.Vector3dVector(pc_colors)
59 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
60 |
61 | final_grasp_world_frame = best_grasps[0][0]
62 | grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
63 | grasp_point_world.transform(final_grasp_world_frame)
64 | pre_grasp_tf = np.array([[1,0,0,0],
65 | [0,1,0,0],
66 | [0,0,1,-0.1],
67 | [0,0,0,1]])
68 |
69 | pre_grasp_world_frame = final_grasp_world_frame @ pre_grasp_tf
70 |
71 | pre_grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
72 |
73 | pre_grasp_point_world.transform(pre_grasp_world_frame)
74 |
75 | #o3d.visualization.draw_geometries([point_cloud_world,coordinate_frame,grasp_point_world,pre_grasp_point_world])
76 | np.save(f'{FLAGS.save_dir}/pred_grasps_world.npy', pred_grasps_world[0])
77 |
78 | np.save(f'{FLAGS.save_dir}/scores.npy', scores[0])
79 |
80 | np.save(f'{FLAGS.save_dir}/contact_pts.npy', contact_pts[0])
81 |
82 | np.save(f'{FLAGS.save_dir}/point_cloud_world.npy', points_world)
83 | np.save(f'{FLAGS.save_dir}/rgb_cloud_world.npy', pc_colors)
84 | np.save(f'{FLAGS.save_dir}/grasp_point_world.npy', final_grasp_world_frame)
85 | np.save(f'{FLAGS.save_dir}/pre_grasp_point_world.npy', pre_grasp_world_frame)
86 |
87 | return pred_grasps_world, scores[0]
88 |
89 | # NOTE: the block below, down to the second `return`, appears to be leftover dead code from an
90 | # earlier camera-frame version (it references pc_full and pred_grasps_cam, which are never defined here).
91 | # import pdb; pdb.set_trace()
92 | world_to_cam_tf = np.array([[0,-1,0,0],
93 | [-1,0,0,0],
94 | [0,0,-1,0],
95 | [0,0,0,1]])
96 |
97 |
98 | #visualize_grasps(pc_full, best_grasps, best_scores, plot_opencv_cam=True, pc_colors=pc_colors)
99 | # Create an Open3D point cloud object
100 | point_cloud_cam = o3d.geometry.PointCloud()
101 |
102 | # Set the points and colors
103 | point_cloud_cam.points = o3d.utility.Vector3dVector(pc_full)
104 | point_cloud_cam.colors = o3d.utility.Vector3dVector(pc_colors)
105 |
106 | # Step 2: Visualize the point cloud
107 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
108 | grasp_point = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
109 | grasp_point.transform(best_grasps[0][0])
110 | # o3d.visualization.draw_geometries([point_cloud_cam,coordinate_frame,grasp_point])
111 |
112 | ones = np.ones((pc_full.shape[0],1))
113 | world_to_cam_tf = RigidTransform.load('/home/lifelong/sms/sms/ur5_interface/ur5_interface/calibration_outputs/world_to_extrinsic_zed_for_grasping.tf').matrix
114 | homogenous_points_cam = np.hstack((pc_full,ones))
115 | homogenous_points_world = world_to_cam_tf @ homogenous_points_cam.T
116 | points_world = homogenous_points_world[:3,:] / homogenous_points_world[3,:][np.newaxis,:]
117 | points_world = points_world.T
118 |
119 | point_cloud_world = o3d.geometry.PointCloud()
120 |
121 | # Set the points and colors
122 | point_cloud_world.points = o3d.utility.Vector3dVector(points_world)
123 | point_cloud_world.colors = o3d.utility.Vector3dVector(pc_colors)
124 | coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
125 | panda_grasp_point_to_robotiq_grasp_point = np.array([[1,0,0,0],[0,1,0,0],[0,0,1,-0.03],[0,0,0,1]]) # -0.06
126 | final_grasp_world_frame = world_to_cam_tf @ best_grasps[0][0] @ panda_grasp_point_to_robotiq_grasp_point
127 | grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
128 | grasp_point_world.transform(final_grasp_world_frame)
129 | pre_grasp_tf = np.array([[1,0,0,0],
130 | [0,1,0,0],
131 | [0,0,1,-0.1],
132 | [0,0,0,1]])
133 | pre_grasp_world_frame = final_grasp_world_frame @ pre_grasp_tf
134 | pre_grasp_point_world = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
135 | pre_grasp_point_world.transform(pre_grasp_world_frame)
136 | # o3d.visualization.draw_geometries([point_cloud_world,coordinate_frame,grasp_point_world,pre_grasp_point_world])
137 | pred_grasps_world = []
138 | for i in range(len(pred_grasps_cam[0])):
139 | grasp = world_to_cam_tf @ pred_grasps_cam[0][i] @ panda_grasp_point_to_robotiq_grasp_point
140 | pred_grasps_world.append(grasp)
141 | pred_grasps_world = np.array(pred_grasps_world)
142 | np.save(f'{FLAGS.save_dir}/pred_grasps_world.npy', pred_grasps_world)
143 | np.save(f'{FLAGS.save_dir}/scores.npy', scores[0])
144 | np.save(f'{FLAGS.save_dir}/contact_pts.npy', contact_pts[0])
145 |
146 | return pred_grasps_world, scores[0]
147 |
148 | if __name__ == "__main__":
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument('--seg_np_path', default='',required=True)
151 | parser.add_argument('--full_np_path', default='',required=True)
152 | parser.add_argument('--save_dir', default='',required=True)
153 | checkpoint_dir = os.path.join(dir_path,'../dependencies/contact_graspnet/checkpoints/scene_test_2048_bs3_hor_sigma_001')
154 | parser.add_argument('--ckpt_dir', default=checkpoint_dir, help='Log dir [default: checkpoints/scene_test_2048_bs3_hor_sigma_001]')
155 | parser.add_argument('--pc_bounding_box_path', default='', help='Input data: npz/npy file with keys either "depth" & camera matrix "K" or just point cloud "pc" in meters. Optionally, a 2D "segmap"',required=True)
156 | parser.add_argument('--K', default=None, help='Flat Camera Matrix, pass as "[fx, 0, cx, 0, fy, cy, 0, 0 ,1]"')
157 | parser.add_argument('--z_range', default=None, help='Z value threshold to crop the input point cloud')
158 | parser.add_argument('--local_regions', action='store_true', default=False, help='Crop 3D local regions around given segments.')
159 | parser.add_argument('--filter_grasps', action='store_true', default=True, help='Filter grasp contacts according to segmap.')
160 | parser.add_argument('--skip_border_objects', action='store_true', default=False, help='When extracting local_regions, ignore segments at depth map boundary.')
161 |     parser.add_argument('--forward_passes', type=int, default=1, help='Run multiple parallel forward passes to generate more potential contact points.')
162 | parser.add_argument('--segmap_id', type=int, default=0, help='Only return grasps of the given object id')
163 | parser.add_argument('--arg_configs', nargs="*", type=str, default=[], help='overwrite config parameters')
164 | FLAGS = parser.parse_args()
165 | generate_grasps(FLAGS.seg_np_path, FLAGS.full_np_path, FLAGS.pc_bounding_box_path, FLAGS.ckpt_dir, FLAGS.z_range, FLAGS.K, FLAGS.local_regions,
166 | FLAGS.filter_grasps, FLAGS.skip_border_objects, FLAGS.forward_passes, FLAGS.segmap_id, FLAGS.arg_configs, FLAGS.save_dir)
--------------------------------------------------------------------------------
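The filter_noise helper above drops DBSCAN outliers (label -1) from a point cloud before grasp generation. A standalone sketch of that filtering on synthetic data (not part of the repository):

# Standalone sketch of the DBSCAN outlier filtering used by filter_noise (synthetic data).
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(0)
cluster = rng.normal(loc=0.0, scale=0.002, size=(500, 3))  # a dense 3D blob, roughly object-sized in meters
outliers = rng.uniform(low=-0.5, high=0.5, size=(20, 3))   # sparse noise points
points = np.vstack([cluster, outliers])

labels = DBSCAN(eps=0.005, min_samples=10).fit_predict(points)  # same eps / min_samples as filter_noise
filtered = points[labels != -1]                                 # keep only clustered points
print(points.shape, "->", filtered.shape)  # the 20 sparse outliers are dropped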
/pogs/grasping/results/predictions_global.ply.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/grasping/results/predictions_global.ply.npz
--------------------------------------------------------------------------------
/pogs/pogs_config.py:
--------------------------------------------------------------------------------
1 | """
2 | pogs configuration file.
3 | """
4 |
5 | from nerfstudio.configs.base_config import ViewerConfig
6 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
7 | from nerfstudio.engine.optimizers import AdamOptimizerConfig
8 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
9 | from nerfstudio.plugins.types import MethodSpecification
10 | from nerfstudio.engine.trainer import TrainerConfig as TrainerConfigBase
11 | from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig
12 | from pogs.pogs import POGSModelConfig
13 |
14 |
15 | from pogs.pogs_pipeline import POGSPipelineConfig
16 | from pogs.data.full_images_datamanager import FullImageDatamanagerConfig, FullImageDatamanager
17 | from pogs.data.depth_dataset import DepthDataset
18 |
19 | pogs_method = MethodSpecification(
20 | config = TrainerConfigBase(
21 | method_name="pogs",
22 | steps_per_eval_image=100,
23 | steps_per_eval_batch=100,
24 | steps_per_save=1000,
25 | max_num_iterations=4000,
26 | mixed_precision=False,
27 | gradient_accumulation_steps = {'camera_opt': 100,'color':10,'shs':10, 'lerf': 3},
28 | pipeline=POGSPipelineConfig(
29 | datamanager=FullImageDatamanagerConfig(
30 | _target=FullImageDatamanager[DepthDataset], # Comment out the [DepthDataset] part to use RGB only datasets (e.g. polycam datasets)
31 | dataparser=NerfstudioDataParserConfig(load_3D_points=True, orientation_method='none', center_method='none', auto_scale_poses=False, depth_unit_scale_factor=1.0),
32 | network=OpenCLIPNetworkConfig(
33 | clip_model_type="ViT-B-16", clip_model_pretrained="laion2b_s34b_b88k", clip_n_dims=512, device='cuda:0'
34 | ),
35 | ),
36 | model=POGSModelConfig(),
37 | ),
38 | optimizers={
39 | "means": {
40 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
41 | "scheduler": ExponentialDecaySchedulerConfig(
42 | lr_final=1.6e-6,
43 | max_steps=30000,
44 | ),
45 | },
46 | "features_dc": {
47 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15),
48 | "scheduler": None,
49 | },
50 | "features_rest": {
51 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15),
52 | "scheduler": None,
53 | },
54 | "opacities": {
55 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
56 | "scheduler": None,
57 | },
58 | "scales": {
59 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
60 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=30000)
61 | },
62 | "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
63 | "camera_opt": {
64 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
65 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
66 | },
67 | "lerf": {
68 | "optimizer": AdamOptimizerConfig(lr=2.5e-3, eps=1e-15),
69 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=15000),
70 | },
71 | "dino_feats": {
72 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
73 | "scheduler": ExponentialDecaySchedulerConfig(
74 | lr_final=1e-3,
75 | max_steps=6000,
76 | ),
77 | },
78 | "nn_projection": {
79 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
80 | "scheduler": ExponentialDecaySchedulerConfig(
81 | lr_final=1e-3,
82 | max_steps=6000,
83 | ),
84 | },
85 | },
86 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
87 | vis="viewer",
88 | ),
89 | description="Persistent Object Gaussian Splatting",
90 | )
--------------------------------------------------------------------------------
/pogs/scripts/track_main_demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import viser
3 | import viser.transforms as vtf
4 | import time
5 | import numpy as np
6 | import tyro
7 | from pathlib import Path
8 | from autolab_core import RigidTransform
9 | from pogs.tracking.zed import Zed
10 | from pogs.tracking.optim import Optimizer
11 | import warp as wp
12 | from pogs.encoders.openclip_encoder import OpenCLIPNetworkConfig, OpenCLIPNetwork
13 |
14 | import os
15 | import cv2
16 | dir_path = os.path.dirname(os.path.realpath(__file__))
17 | print(dir_path)
18 | WORLD_TO_ZED2 = RigidTransform.load(dir_path+"/../tracking/data/calibration_outputs/world_to_extrinsic_zed.tf")
19 | DEVICE = 'cuda:0'
20 | set_initial_frame = False
21 |
22 | def main(
23 | config_path: str = "/home/yujustin/pogs/outputs/drill/pogs/2025-02-07_004727/config.yml",
24 | offline_folder: str = '/home/yujustin/pogs/data/demonstrations/drill'
25 | ):
26 | """Quick interactive demo for object tracking.
27 |
28 | Args:
29 | config_path: Path to the nerfstudio POGS model config file.
30 | offline_folder: Path to the offline folder with images and depth
31 | """
32 | image_folder = os.path.join(offline_folder,"left")
33 | depth_folder = os.path.join(offline_folder,"depth")
34 | image_paths = sorted(os.listdir(image_folder))
35 | depth_paths = sorted(os.listdir(depth_folder))
36 |
37 | server = viser.ViserServer()
38 | wp.init()
39 | # Set up the camera.
40 |
41 | clip_encoder = OpenCLIPNetworkConfig(
42 | clip_model_type="ViT-B-16",
43 | clip_model_pretrained="laion2b_s34b_b88k",
44 | clip_n_dims=512,
45 | device=DEVICE
46 | ).setup() # OpenCLIP encoder for language querying utils
47 | assert isinstance(clip_encoder, OpenCLIPNetwork)
48 |
49 | camera_tf = WORLD_TO_ZED2
50 |
51 | # Visualize the camera.
52 | camera_frame = server.add_frame(
53 | "camera",
54 | position=camera_tf.translation, # rough alignment.
55 | wxyz=camera_tf.quaternion,
56 | show_axes=True,
57 | axes_length=0.1,
58 | axes_radius=0.005,
59 | )
60 |
61 | initial_image_path = os.path.join(image_folder,image_paths[0])
62 | initial_depth_path = os.path.join(depth_folder,depth_paths[0])
63 | img_numpy = cv2.imread(initial_image_path)
64 | depth_numpy = np.load(initial_depth_path)
65 | l = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_BGR2RGB)).to(DEVICE)
66 | depth = torch.from_numpy(depth_numpy).to(DEVICE)
67 |
68 | zedK = np.array([[1.05576221e+03, 0.00000000e+00, 9.62041199e+02],
69 | [0.00000000e+00, 1.05576221e+03, 5.61746765e+02],
70 | [0.00000000e+00, 0.00000000e+00, 1.00000000e+00]])
71 | toad_opt = Optimizer( # Initialize the optimizer
72 | Path(config_path),
73 | zedK,
74 | l.shape[1],
75 | l.shape[0],
76 | init_cam_pose=torch.from_numpy(
77 | vtf.SE3(
78 | wxyz_xyz=np.array([*camera_frame.wxyz, *camera_frame.position])
79 | ).as_matrix()[None, :3, :]
80 | ).float(),
81 | )
82 | real_frames = []
83 | rendered_rgb_frames = []
84 | part_deltas = []
85 | save_videos = True
86 | obj_label_list = [None for _ in range(toad_opt.num_groups)]
87 | initial_image_path = os.path.join(image_folder,image_paths[0])
88 | initial_depth_path = os.path.join(depth_folder,depth_paths[0])
89 | img_numpy = cv2.imread(initial_image_path)
90 | depth_numpy = np.load(initial_depth_path)
91 | l = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_BGR2RGB)).to(DEVICE)
92 | depth = torch.from_numpy(depth_numpy).to(DEVICE)
93 | toad_opt.set_frame(l,toad_opt.cam2world_ns_ds,depth)
94 |
95 | toad_opt.init_obj_pose()
96 | print("Starting main tracking loop")
97 |
98 | assert isinstance(toad_opt, Optimizer)
99 | while not toad_opt.initialized:
100 | time.sleep(0.1)
101 | if toad_opt.initialized:
102 | # start_time3 = time.time()
103 | for(img_path,depth_path) in zip(image_paths,depth_paths):
104 | full_image_path = os.path.join(image_folder,img_path)
105 | full_depth_path = os.path.join(depth_folder,depth_path)
106 | img_numpy = cv2.imread(full_image_path)
107 | depth_numpy = np.load(full_depth_path)
108 | left = torch.from_numpy(cv2.cvtColor(img_numpy,cv2.COLOR_BGR2RGB)).to(DEVICE)
109 | depth = torch.from_numpy(depth_numpy).to(DEVICE)
110 | # import pdb; pdb.set_trace
111 | start_time3 = time.time()
112 | toad_opt.set_observation(left,toad_opt.cam2world_ns,depth)
113 | print("Set observation in ", time.time()-start_time3)
114 | start_time5 = time.time()
115 | n_opt_iters = 15
116 | # with zed.raft_lock:
117 | outputs = toad_opt.step_opt(niter=n_opt_iters)
118 | print(f"{n_opt_iters} opt steps in ", time.time()-start_time5)
119 |
120 | # Add ZED img and GS render to viser
121 | rgb_img = left.cpu().numpy()
122 | for i in range(len(toad_opt.group_masks)):
123 | frame = toad_opt.optimizer.frame.roi_frames[i]
124 | xmin = frame.xmin
125 | xmax = frame.xmax
126 | ymin = frame.ymin
127 | ymax = frame.ymax
128 | rgb_img = cv2.rectangle(rgb_img, (xmin, ymin), (xmax, ymax),(255,0,0), 2)
129 |
130 | server.scene.add_image(
131 | "cam/zed_left",
132 | rgb_img,
133 | render_width=rgb_img.shape[1]/2500,
134 | render_height=rgb_img.shape[0]/2500,
135 | position = (-0.5, -0.5, 0.5),
136 | wxyz=(0, -1, 0, 0),
137 | visible=True
138 | )
139 |
140 | server.scene.add_image(
141 | "cam/gs_render",
142 | outputs["rgb"].cpu().detach().numpy(),
143 | render_width=outputs["rgb"].shape[1]/2500,
144 | render_height=outputs["rgb"].shape[0]/2500,
145 | position = (0.5, -0.5, 0.5),
146 | wxyz=(0, -1, 0, 0),
147 | visible=True
148 | )
149 |
150 | if save_videos:
151 | real_frames.append(rgb_img)
152 | rendered_rgb_frames.append(outputs["rgb"].cpu().detach().numpy())
153 |
154 | tf_list = toad_opt.get_parts2world()
155 | part_deltas.append(tf_list)
156 | for idx, tf in enumerate(tf_list):
157 | server.add_frame(
158 | f"object/group_{idx}",
159 | position=tf.translation(),
160 | wxyz=tf.rotation().wxyz,
161 | show_axes=True,
162 | axes_length=0.05,
163 | axes_radius=.001
164 | )
165 | mesh = toad_opt.toad_object.meshes[idx]
166 | server.add_mesh_trimesh(
167 | f"object/group_{idx}/mesh",
168 | mesh=mesh,
169 | )
170 | if idx == toad_opt.max_relevancy_label:
171 | obj_label_list[idx] = server.add_label(
172 | f"object/group_{idx}/label",
173 | text=toad_opt.max_relevancy_text,
174 | position = (0,0,0.05),
175 | )
176 | else:
177 | if obj_label_list[idx] is not None:
178 | obj_label_list[idx].remove()
179 |
180 | # Visualize pointcloud.
181 | start_time4 = time.time()
182 | K = torch.from_numpy(zedK).float().cuda()
183 | assert isinstance(left, torch.Tensor) and isinstance(depth, torch.Tensor)
184 | points, colors = Zed.project_depth(left, depth, K, depth_threshold=1.0, subsample=6)
185 | server.add_point_cloud(
186 | "camera/points",
187 | points=points,
188 | colors=colors,
189 | point_size=0.001,
190 | )
191 |
192 |
193 | # except KeyboardInterrupt:
194 | # # Generate videos from the frames if the user interrupts the loop with ctrl+c
195 | # frames_dict = {"real_frames": real_frames,
196 | # "rendered_rgb": rendered_rgb_frames}
197 | # timestr = generate_videos(frames_dict, fps = 5, config_path=config_path.parent)
198 |
199 | # # Save part deltas to npy file
200 | # path = config_path.parent.joinpath(f"{timestr}")
201 | # np.save(path.joinpath("part_deltas_traj.npy"), np.array(part_deltas))
202 | # exit()
203 | # except Exception as e:
204 | # print("An exception occured: ", e)
205 | # exit()
206 |
207 |
208 | if __name__ == "__main__":
209 | tyro.cli(main)
210 |
--------------------------------------------------------------------------------
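The demo above reads frames offline from offline_folder, expecting a left/ directory of RGB images and a depth/ directory of .npy depth maps whose sorted filenames pair up one-to-one. A minimal sketch of writing frames in that layout (the zero-padded filename pattern is only an assumption that keeps the sort order stable):

import os
import cv2
import numpy as np

def save_demo_frame(offline_folder: str, idx: int, bgr_image: np.ndarray, depth_m: np.ndarray) -> None:
    # Write one (image, depth) pair in the layout track_main_demo.py reads back.
    os.makedirs(os.path.join(offline_folder, "left"), exist_ok=True)
    os.makedirs(os.path.join(offline_folder, "depth"), exist_ok=True)
    cv2.imwrite(os.path.join(offline_folder, "left", f"{idx:06d}.png"), bgr_image)  # read back with cv2.imread
    np.save(os.path.join(offline_folder, "depth", f"{idx:06d}.npy"), depth_m)       # read back with np.load

# The script itself is a tyro CLI, e.g.:
#   python pogs/scripts/track_main_demo.py --config-path <path/to/config.yml> --offline-folder <path/to/frames>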
/pogs/tracking/atap_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List
4 | from pogs.pogs import POGSModel
5 | import warp as wp
6 |
7 | #https://openaccess.thecvf.com/content_CVPR_2019/papers/Barron_A_General_and_Adaptive_Robust_Loss_Function_CVPR_2019_paper.pdf
8 | @wp.func
9 | def jon_loss(x: float,alpha:float, c:float):
10 | pow_part = (((x/c)**2.0)/wp.abs(alpha-2.0) + 1.0)
11 | return (wp.abs(alpha-2.0)/alpha) * (wp.pow(pow_part,alpha/2.0) - 1.0)
12 |
13 | @wp.kernel
14 | def atap_loss(cur_means: wp.array(dtype = wp.vec3), dists: wp.array(dtype = float), ids: wp.array(dtype = int),
15 | match_ids: wp.array(dtype = int), group_ids1: wp.array(dtype = int), group_ids2: wp.array(dtype=int),
16 | connectivity_weights: wp.array(dtype = float,ndim = 2), loss: wp.array(dtype = float), alpha: float):
17 | tid = wp.tid()
18 | id1 = ids[tid]
19 | id2 = match_ids[tid]
20 | gid1 = group_ids1[tid]
21 | gid2 = group_ids2[tid]
22 | con_weight = connectivity_weights[gid1,gid2]
23 | curdist = wp.length(cur_means[id1] - cur_means[id2])
24 | loss[tid] = jon_loss(curdist - dists[tid], alpha, 0.001) * con_weight * .001
25 |
26 |
27 | class ATAPLoss:
28 | touch_radius: float = .0015
29 | N: int = 500
30 | loss_mult: float = .2
31 | loss_alpha: float = 1.0 #rule: for jointed, use 1.0 alpha, for non-jointed use 0.1 ish
32 | def __init__(self, sms_model: POGSModel, group_masks: List[torch.Tensor], group_labels: torch.Tensor, dataset_scale: float = 1.0):
33 | """
34 |         Initializes the data structures for computing the loss between touching groups.
35 | """
36 | self.touch_radius = self.touch_radius * dataset_scale
37 | print(f"Touch radius is {self.touch_radius}")
38 | self.sms_model = sms_model
39 | self.group_masks = group_masks
40 | self.group_labels = group_labels
41 | self.nn_info = []
42 | for grp in self.group_masks:
43 | with torch.no_grad():
44 | dists, ids, match_ids, group_ids1, group_ids = self._radius_nn(grp, self.touch_radius)
45 | self.nn_info.append((dists, ids, match_ids, group_ids1, group_ids))
46 | print(f"Group {len(self.nn_info)} has {len(ids)} neighbors")
47 | self.dists = torch.cat([x[0] for x in self.nn_info]).cuda()
48 | self.ids = torch.cat([x[1] for x in self.nn_info]).cuda().int()
49 | self.match_ids = torch.cat([x[2] for x in self.nn_info]).cuda().int()
50 | self.group_ids1 = torch.cat([x[3] for x in self.nn_info]).cuda().int()
51 | self.group_ids2 = torch.cat([x[4] for x in self.nn_info]).cuda().int()
52 | self.num_pairs = torch.cat([torch.tensor(len(x[1])).repeat(len(x[1])) for x in self.nn_info]).cuda().float()
53 |
54 |
55 | def __call__(self, connectivity_weights: torch.Tensor):
56 | """
57 |         Computes the loss between touching groups.
58 |         connectivity_weights: a tensor of shape (num_groups, num_groups) giving the weight between each pair of groups
59 |
60 | returns: a differentiable loss
61 | """
62 | if len(self.group_masks) == 1:
63 | return torch.tensor(0.0,device='cuda')
64 | if self.dists.shape[0] == 0:
65 | return torch.tensor(0.0,device='cuda')
66 | assert connectivity_weights.shape == (len(self.group_masks),len(self.group_masks)), "connectivity weights must be a square matrix of size num_groups"
67 | loss = wp.empty(self.dists.shape[0], dtype=wp.float32, requires_grad=True, device='cuda')
68 | wp.launch(
69 | dim = self.dists.shape[0],
70 | kernel = atap_loss,
71 | inputs = [wp.from_torch(self.sms_model.gauss_params['means'],dtype=wp.vec3),wp.from_torch(self.dists),
72 | wp.from_torch(self.ids),wp.from_torch(self.match_ids),wp.from_torch(self.group_ids1),
73 | wp.from_torch(self.group_ids2),wp.from_torch(connectivity_weights),loss, self.loss_alpha]
74 | )
75 | return (wp.to_torch(loss)/self.num_pairs).sum()*self.loss_mult
76 |
77 |
78 | def _radius_nn(self, group_mask: torch.Tensor, r: float):
79 | """
80 | returns the nearest neighbors to gaussians in a group within a certain radius (and outside that group)
81 | returns -1 indices for neighbors outside the radius or within the same group
82 | """
83 | global_group_ids = torch.zeros(self.sms_model.num_points,dtype=torch.long,device='cuda')
84 | for i,grp in enumerate(self.group_masks):
85 | global_group_ids[grp] = i
86 | from cuml.neighbors import NearestNeighbors
87 | model = NearestNeighbors(n_neighbors=self.N)
88 | means = self.sms_model.means.detach().cpu().numpy()
89 | model.fit(means)
90 | dists, match_ids = model.kneighbors(means)
91 | dists, match_ids = torch.tensor(dists,dtype=torch.float32,device='cuda'),torch.tensor(match_ids,dtype=torch.long,device='cuda')
92 | dists, match_ids = dists[group_mask], match_ids[group_mask]
93 | # filter matches outside the radius
94 | match_ids[dists>r] = -1
95 | # filter out ones within same group mask
96 | match_ids[group_mask[match_ids]] = -1
97 | ids = torch.arange(self.sms_model.num_points,dtype=torch.long,device='cuda')[group_mask].unsqueeze(-1).repeat(1,self.N)
98 | #flatten all the ids/dists/match_ids
99 | ids = ids[match_ids!=-1].flatten()
100 | dists = dists[match_ids!=-1].flatten()
101 | match_ids = match_ids[match_ids!=-1].flatten()
102 | return dists, ids, match_ids, global_group_ids[ids], global_group_ids[match_ids]
--------------------------------------------------------------------------------
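For orientation, a minimal usage sketch of ATAPLoss, assuming a trained POGSModel (model) and per-group boolean masks (mask_a, mask_b) are already available; the connectivity weights form a square matrix of pairwise group weights:

import torch
from pogs.tracking.atap_loss import ATAPLoss

group_masks = [mask_a, mask_b]                                  # assumed: boolean masks over model.num_points
group_labels = torch.zeros(model.num_points, dtype=torch.long)  # assumed: one group label per gaussian
atap = ATAPLoss(model, group_masks, group_labels, dataset_scale=1.0)

# Pairwise connectivity weights between groups (num_groups x num_groups).
weights = torch.ones(len(group_masks), len(group_masks), device="cuda")
loss = atap(weights)  # differentiable scalar, added to the tracking objective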
/pogs/tracking/data/ZED2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/tracking/data/ZED2.stl
--------------------------------------------------------------------------------
/pogs/tracking/data/ZEDM.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/pogs/tracking/data/ZEDM.stl
--------------------------------------------------------------------------------
/pogs/tracking/motion.py:
--------------------------------------------------------------------------------
1 | from ur5py.ur5 import UR5Robot
2 | import numpy as np
3 | from autolab_core import RigidTransform
4 |
5 | WRIST_TO_CAM = RigidTransform.load("/home/lifelong/sms/sms/ur5_interface/ur5_interface/calibration_outputs/wrist_to_cam.tf")
6 |
7 | class Motion:
8 | def __init__(self, robot: UR5Robot):
9 | self.robot = robot
10 | # robot = UR5Robot(gripper=1)
11 | self.clear_tcp()
12 |
13 | home_joints = np.array([0.30947089195251465, -1.2793572584735315, -2.035713497792379, -1.388848606740133, 1.5713528394699097, 0.34230729937553406])
14 | robot.move_joint(home_joints,vel=1.0,acc=0.1)
15 | world_to_wrist = robot.get_pose()
16 | world_to_wrist.from_frame = "wrist"
17 | world_to_cam = world_to_wrist * WRIST_TO_CAM
18 | proper_world_to_cam_rotation = np.array([[0,1,0],[1,0,0],[0,0,-1]])
19 | self.home_world_to_cam = RigidTransform(rotation=proper_world_to_cam_rotation,translation=world_to_cam.translation,from_frame='cam',to_frame='world')
20 | self.home_world_to_wrist = self.home_world_to_cam * WRIST_TO_CAM.inverse()
21 |
22 | self.robot.move_pose(self.home_world_to_wrist,vel=1.0,acc=0.1)
23 |
24 |
25 | def clear_tcp(self):
26 | tool_to_wrist = RigidTransform()
27 | tool_to_wrist.translation = np.array([0, 0, 0])
28 | tool_to_wrist.from_frame = "tool"
29 | tool_to_wrist.to_frame = "wrist"
30 | self.robot.set_tcp(tool_to_wrist)
31 |
32 |
--------------------------------------------------------------------------------
/pogs/tracking/observation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Optional, Callable, Generic, TypeVar
3 | from nerfstudio.cameras.cameras import Cameras
4 | from torchvision.transforms.functional import resize
5 | from pogs.tracking.utils import *
6 | from pogs.tracking.utils2 import *
7 | from copy import deepcopy
8 |
9 | T = TypeVar('T')
10 | class Future(Generic[T]):
11 | """
12 | A simple wrapper for deferred execution of a callable until retrieved
13 | """
14 | def __init__(self,callable):
15 | self.callable = callable
16 | self.executed = False
17 |
18 | def retrieve(self):
19 | if not self.executed:
20 | self.result = self.callable()
21 | self.executed = True
22 | return self.result
23 |
24 | class Frame:
25 | rasterize_resolution: int = 500
26 | camera: Cameras
27 | rgb: torch.Tensor
28 | metric_depth: bool
29 | _depth: Future[torch.Tensor]
30 | _dino_feats: Future[torch.Tensor]
31 | # _hand_mask: Future[torch.Tensor]
32 |
33 | @property
34 | def depth(self):
35 | return self._depth.retrieve()
36 |
37 | @property
38 | def dino_feats(self):
39 | return self._dino_feats.retrieve()
40 |
41 | # @property
42 | # def hand_mask(self):
43 | # return self._hand_mask.retrieve()
44 |
45 | @property
46 | def mask(self):
47 | return self._mask.retrieve()
48 |
49 | def __init__(self, rgb: torch.Tensor, camera: Cameras, dino_fn: Callable, metric_depth_img: Optional[torch.Tensor],
50 | xmin: Optional[float] = None, xmax: Optional[float] = None, ymin: Optional[float] = None, ymax: Optional[float] = None):
51 |
52 | self.camera = deepcopy(camera.to('cuda'))
53 |
54 | self._dino_fn = dino_fn
55 | self.rgb = resize(
56 | rgb.permute(2, 0, 1),
57 | (camera.height, camera.width),
58 | antialias=True,
59 | ).permute(1, 2, 0)
60 | self.metric_depth = metric_depth_img is not None
61 | self.obj_mask = None
62 |
63 |
64 | @torch.no_grad()
65 | def _get_depth():
66 | if metric_depth_img is not None:
67 | depth = metric_depth_img
68 | else:
69 | raise FileNotFoundError
70 | depth = resize(
71 | depth.unsqueeze(0),
72 | (camera.height, camera.width),
73 | antialias=True,
74 | ).squeeze().unsqueeze(-1)
75 | return depth
76 | self._depth = Future(_get_depth)
77 | @torch.no_grad()
78 | def _get_dino():
79 | dino_feats = dino_fn(
80 | rgb.permute(2, 0, 1).unsqueeze(0)
81 | ).squeeze()
82 | dino_feats = resize(
83 | dino_feats.permute(2, 0, 1),
84 | (camera.height, camera.width),
85 | antialias=True,
86 | ).permute(1, 2, 0)
87 | return dino_feats
88 | self._dino_feats = Future(_get_dino)
89 | # @torch.no_grad()
90 | # def _get_hand_mask():
91 | # hand_mask = get_hand_mask((self.rgb * 255).to(torch.uint8))
92 | # hand_mask = (
93 | # torch.nn.functional.max_pool2d(
94 | # hand_mask[None, None], 3, padding=1, stride=1
95 | # ).squeeze()
96 | # == 0.0
97 | # )
98 | # return hand_mask
99 | # self._hand_mask = Future(_get_hand_mask)
100 | @torch.no_grad()
101 | def _get_mask():
102 | obj_mask = resize(
103 | self.obj_mask.unsqueeze(0),
104 | (camera.height, camera.width),
105 | antialias=True,
106 | ).squeeze(0)
107 | return obj_mask
108 | self._mask = Future(_get_mask)
109 | self.xmin, self.xmax, self.ymin, self.ymax = xmin, xmax, ymin, ymax
110 |
111 |
112 |
113 | class PosedObservation:
114 | """
115 | Class for computing relevant data products for a frame and storing them
116 | """
117 | max_roi_resolution: int = 490
118 | _frame: Frame
119 | _raw_rgb: torch.Tensor
120 | _original_camera: Cameras
121 | _original_depth: Optional[torch.Tensor] = None
122 | _roi_frames: Optional[List[Frame]] = None
123 |
124 | def __init__(self, rgb: torch.Tensor, camera: Cameras, dino_fn: Callable, metric_depth_img: Optional[torch.Tensor] = None):
125 | """
126 | Initialize the frame
127 |
128 | rgb: HxWx3 tensor of the rgb image, normalized to [0,1]
129 |         camera: Cameras object for the camera intrinsics and extrinsics to render the frame at
130 | dino_fn: callable taking in 3HxW RGB image and outputting dino features CxHxW
131 | metric_depth_img: HxWx1 tensor of metric depth, if desired
132 | """
133 | assert rgb.shape[0] == camera.height and rgb.shape[1] == camera.width, f"Input image should be the same size as the camera, got {rgb.shape} and {camera.height}x{camera.width}"
134 | self._dino_fn = dino_fn
135 | assert rgb.shape[-1] == 3, rgb.shape
136 | self._raw_rgb = rgb
137 | if metric_depth_img is not None:
138 | self._original_depth = metric_depth_img
139 | self._original_camera = deepcopy(camera.to('cuda'))
140 | cam = deepcopy(camera.to('cuda'))
141 |
142 | self._frame = Frame(rgb, cam, dino_fn, metric_depth_img)
143 | self._roi_frames = []
144 | self._obj_masks = None
145 |
146 |
147 | @property
148 | def frame(self):
149 | return self._frame
150 |
151 | @property
152 | def roi_frames(self):
153 | if len(self._roi_frames) == 0:
154 | raise RuntimeError("ROIs not set")
155 | return self._roi_frames
156 |
157 | def add_roi(self, xmin, xmax, ymin, ymax):
158 | assert xmin < xmax and ymin < ymax
159 | assert xmin >= 0 and ymin >= 0
160 |         assert xmax <= 1.0 and ymax <= 1.0, "ROI bounds should be normalized to [0, 1]"
161 | # convert normalized to pixels in original image
162 | xmin,xmax,ymin,ymax = int(xmin*(self._original_camera.width-1)), int(xmax*(self._original_camera.width-1)),\
163 | int(ymin*(self._original_camera.height-1)), int(ymax*(self._original_camera.height-1))
164 |         # adjust these values to be multiples of 14, the DINO patch size
165 | xlen = ((xmax - xmin)//14) * 14
166 | ylen = ((ymax - ymin)//14) * 14
167 | xmax = xmin + xlen
168 | ymax = ymin + ylen
169 | rgb = self._raw_rgb[ymin:ymax, xmin:xmax].clone()
170 | camera = crop_camera(self._original_camera, xmin, xmax, ymin, ymax)
171 | if max(camera.width.item(),camera.height.item()) > self.max_roi_resolution:
172 | camera.rescale_output_resolution(self.max_roi_resolution/max(camera.width.item(),camera.height.item()))
173 | depth = self._original_depth[ymin:ymax, xmin:xmax].clone().squeeze(-1)
174 | self._roi_frames.append(Frame(rgb, camera, self._dino_fn, depth, xmin, xmax, ymin, ymax))
175 |
176 | def update_roi(self, idx, xmin, xmax, ymin, ymax):
177 | assert len(self._roi_frames) > idx
178 | assert xmin < xmax and ymin < ymax
179 | assert xmin >= 0 and ymin >= 0
180 |         assert xmax <= 1.0 and ymax <= 1.0, "ROI bounds should be normalized to [0, 1]"
181 | # convert normalized to pixels in original image
182 | xmin,xmax,ymin,ymax = int(xmin*(self._original_camera.width-1)), int(xmax*(self._original_camera.width-1)),\
183 | int(ymin*(self._original_camera.height-1)), int(ymax*(self._original_camera.height-1))
184 |         # adjust these values to be multiples of 14, the DINO patch size
185 | xlen = ((xmax - xmin)//14) * 14
186 | ylen = ((ymax - ymin)//14) * 14
187 | xmax = xmin + xlen
188 | ymax = ymin + ylen
189 | rgb = self._raw_rgb[ymin:ymax, xmin:xmax].clone()
190 | camera = crop_camera(self._original_camera, xmin, xmax, ymin, ymax)
191 | if max(camera.width.item(),camera.height.item()) > self.max_roi_resolution:
192 | camera.rescale_output_resolution(self.max_roi_resolution/max(camera.width.item(),camera.height.item()))
193 | depth = self._original_depth[ymin:ymax, xmin:xmax].clone().squeeze(-1)
194 |
195 | self._roi_frames[idx] = Frame(rgb, camera, self._dino_fn, depth, xmin, xmax, ymin, ymax)
196 |         if self._obj_masks is not None and len(self._obj_masks) > 0:
197 |
198 | self._roi_frames[idx].obj_mask = self._obj_masks[idx].squeeze(0)[ymin:ymax, xmin:xmax].clone()
--------------------------------------------------------------------------------
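A short sketch of the intended call pattern for PosedObservation, assuming an rgb tensor in [0, 1], a metric depth map, a nerfstudio Cameras object, and a DINO feature callable are already on hand:

from pogs.tracking.observation import PosedObservation

obs = PosedObservation(rgb, camera, dino_fn, metric_depth_img=depth)

# ROIs are specified in normalized [0, 1] image coordinates; internally they are
# converted to pixels and snapped to multiples of the 14-pixel DINO patch size.
obs.add_roi(xmin=0.25, xmax=0.75, ymin=0.30, ymax=0.80)

roi = obs.roi_frames[0]
roi_depth = roi.depth        # computed lazily on first access via Future
roi_feats = roi.dino_feats   # DINO features resized to the (possibly downscaled) ROI camera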
/pogs/tracking/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | """Rigid transforms implemented in PyTorch, ported from jaxlie."""
2 |
3 | from . import utils as utils
4 | from ._base import MatrixLieGroup as MatrixLieGroup
5 | from ._base import SEBase as SEBase
6 | from ._base import SOBase as SOBase
7 | from ._se3 import SE3 as SE3
8 | from ._so3 import SO3 as SO3
--------------------------------------------------------------------------------
/pogs/tracking/transforms/_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
3 |
4 | import numpy as onp
5 | import torch
6 | from torch import Tensor
7 | from typing_extensions import Self, final, override
8 |
9 | GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
10 | SEGroupType = TypeVar("SEGroupType", bound="SEBase")
11 |
12 |
13 | class MatrixLieGroup(abc.ABC):
14 | """Interface definition for matrix Lie groups."""
15 |
16 | # Class properties.
17 | # > These will be set in `_utils.register_lie_group()`.
18 |
19 | matrix_dim: ClassVar[int]
20 | """Dimension of square matrix output from `.as_matrix()`."""
21 |
22 | parameters_dim: ClassVar[int]
23 | """Dimension of underlying parameters, `.parameters()`."""
24 |
25 | tangent_dim: ClassVar[int]
26 | """Dimension of tangent space."""
27 |
28 | space_dim: ClassVar[int]
29 | """Dimension of coordinates that can be transformed."""
30 |
31 | def __init__(
32 | # Notes:
33 | # - For the constructor signature to be consistent with subclasses, `parameters`
34 | # should be marked as positional-only. But this isn't possible in Python 3.7.
35 |         # - This method is implicitly overridden by the dataclass decorator and
36 | # should _not_ be marked abstract.
37 | self,
38 | parameters: Tensor,
39 | ):
40 | """Construct a group object from its underlying parameters."""
41 | raise NotImplementedError()
42 |
43 | # Shared implementations.
44 |
45 | @overload
46 | def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...
47 |
48 | @overload
49 | def __matmul__(self, other: Tensor) -> Tensor: ...
50 |
51 | def __matmul__(
52 | self: GroupType, other: Union[GroupType, Tensor]
53 | ) -> Union[GroupType, Tensor]:
54 | """Overload for the `@` operator.
55 |
56 | Switches between the group action (`.apply()`) and multiplication
57 | (`.multiply()`) based on the type of `other`.
58 | """
59 | if isinstance(other, (onp.ndarray, Tensor)):
60 | return self.apply(target=other)
61 | elif isinstance(other, MatrixLieGroup):
62 | assert self.space_dim == other.space_dim
63 | return self.multiply(other=other)
64 | else:
65 | assert False, f"Invalid argument type for `@` operator: {type(other)}"
66 |
67 | # Factory.
68 |
69 | @classmethod
70 | @abc.abstractmethod
71 | def identity(
72 | cls: Type[GroupType], device: Union[torch.device, str], dtype: torch.dtype
73 | ) -> GroupType:
74 | """Returns identity element.
75 |
76 | Returns:
77 | Identity element.
78 | """
79 |
80 | @classmethod
81 | @abc.abstractmethod
82 | def from_matrix(cls: Type[GroupType], matrix: Tensor) -> GroupType:
83 | """Get group member from matrix representation.
84 |
85 | Args:
86 |             matrix: Matrix representation.
87 |
88 | Returns:
89 | Group member.
90 | """
91 |
92 | # Accessors.
93 |
94 | @abc.abstractmethod
95 | def as_matrix(self) -> Tensor:
96 | """Get transformation as a matrix. Homogeneous for SE groups."""
97 |
98 | @abc.abstractmethod
99 | def parameters(self) -> Tensor:
100 | """Get underlying representation."""
101 |
102 | # Operations.
103 |
104 | @abc.abstractmethod
105 | def apply(self, target: Tensor) -> Tensor:
106 | """Applies group action to a point.
107 |
108 | Args:
109 | target: Point to transform.
110 |
111 | Returns:
112 | Transformed point.
113 | """
114 |
115 | @abc.abstractmethod
116 | def multiply(self: Self, other: Self) -> Self:
117 | """Composes this transformation with another.
118 |
119 | Returns:
120 | self @ other
121 | """
122 |
123 | @classmethod
124 | @abc.abstractmethod
125 | def exp(cls: Type[GroupType], tangent: Tensor) -> GroupType:
126 | """Computes `expm(wedge(tangent))`.
127 |
128 | Args:
129 | tangent: Tangent vector to take the exponential of.
130 |
131 | Returns:
132 | Output.
133 | """
134 |
135 | @abc.abstractmethod
136 | def log(self) -> Tensor:
137 | """Computes `vee(logm(transformation matrix))`.
138 |
139 | Returns:
140 | Output. Shape should be `(tangent_dim,)`.
141 | """
142 |
143 | @abc.abstractmethod
144 | def adjoint(self) -> Tensor:
145 | """Computes the adjoint, which transforms tangent vectors between tangent
146 | spaces.
147 |
148 | More precisely, for a transform `GroupType`:
149 | ```
150 | GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType
151 | ```
152 |
153 | In robotics, typically used for transforming twists, wrenches, and Jacobians
154 | across different reference frames.
155 |
156 | Returns:
157 | Output. Shape should be `(tangent_dim, tangent_dim)`.
158 | """
159 |
160 | @abc.abstractmethod
161 | def inverse(self: GroupType) -> GroupType:
162 | """Computes the inverse of our transform.
163 |
164 | Returns:
165 | Output.
166 | """
167 |
168 | @abc.abstractmethod
169 | def normalize(self: GroupType) -> GroupType:
170 | """Normalize/projects values and returns.
171 |
172 | Returns:
173 | GroupType: Normalized group member.
174 | """
175 |
176 | # @classmethod
177 | # @abc.abstractmethod
178 | # def sample_uniform(cls: Type[GroupType], key: Tensor) -> GroupType:
179 | # """Draw a uniform sample from the group. Translations (if applicable) are in the
180 | # range [-1, 1].
181 | #
182 | # Args:
183 | # key: PRNG key, as returned by `jax.random.PRNGKey()`.
184 | #
185 | # Returns:
186 | # Sampled group member.
187 | # """
188 |
189 | def get_batch_axes(self) -> Tuple[int, ...]:
190 | """Return any leading batch axes in contained parameters. If an array of shape
191 | `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
192 | return `(100,)`."""
193 | return self.parameters().shape[:-1]
194 |
195 |
196 | class SOBase(MatrixLieGroup):
197 | """Base class for special orthogonal groups."""
198 |
199 |
200 | ContainedSOType = TypeVar("ContainedSOType", bound=SOBase)
201 |
202 |
203 | class SEBase(Generic[ContainedSOType], MatrixLieGroup):
204 | """Base class for special Euclidean groups.
205 |
206 | Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional
207 | translation vector.
208 | """
209 |
210 | # SE-specific interface.
211 |
212 | @classmethod
213 | @abc.abstractmethod
214 | def from_rotation_and_translation(
215 | cls: Type[SEGroupType],
216 | rotation: ContainedSOType,
217 | translation: Tensor,
218 | ) -> SEGroupType:
219 | """Construct a rigid transform from a rotation and a translation.
220 |
221 | Args:
222 | rotation: Rotation term.
223 | translation: Translation term.
224 |
225 | Returns:
226 | Constructed transformation.
227 | """
228 |
229 | @final
230 | @classmethod
231 | def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:
232 | return cls.from_rotation_and_translation(
233 | rotation=rotation,
234 | translation=rotation.parameters().new_zeros(
235 | (*rotation.parameters().shape[:-1], cls.space_dim),
236 | dtype=rotation.parameters().dtype,
237 | ),
238 | )
239 |
240 | @abc.abstractmethod
241 | def rotation(self) -> ContainedSOType:
242 | """Returns a transform's rotation term."""
243 |
244 | @abc.abstractmethod
245 | def translation(self) -> Tensor:
246 | """Returns a transform's translation term."""
247 |
248 | # Overrides.
249 |
250 | @final
251 | @override
252 | def apply(self, target: Tensor) -> Tensor:
253 | return self.rotation() @ target + self.translation() # type: ignore
254 |
255 | @final
256 | @override
257 | def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType:
258 | return type(self).from_rotation_and_translation(
259 | rotation=self.rotation() @ other.rotation(),
260 | translation=(self.rotation() @ other.translation()) + self.translation(),
261 | )
262 |
263 | @final
264 | @override
265 | def inverse(self: SEGroupType) -> SEGroupType:
266 | R_inv = self.rotation().inverse()
267 | return type(self).from_rotation_and_translation(
268 | rotation=R_inv,
269 | translation=-(R_inv @ self.translation()),
270 | )
271 |
272 | @final
273 | @override
274 | def normalize(self: SEGroupType) -> SEGroupType:
275 | return type(self).from_rotation_and_translation(
276 | rotation=self.rotation().normalize(),
277 | translation=self.translation(),
278 | )
--------------------------------------------------------------------------------
/pogs/tracking/transforms/_se3.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import cast
5 |
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | from typing_extensions import Union, override
10 |
11 | from . import _base
12 | from ._so3 import SO3
13 | from .utils import get_epsilon, register_lie_group
14 |
15 |
16 | def _skew(omega: Tensor) -> Tensor:
17 | """
18 | Returns the skew-symmetric form of a length-3 vector.
19 | :param omega (*, 3)
20 | :returns (*, 3, 3)
21 | """
22 |
23 | wx, wy, wz = omega.unbind(dim=-1)
24 | o = torch.zeros_like(wx)
25 | return torch.stack(
26 | [o, -wz, wy, wz, o, -wx, -wy, wx, o],
27 | dim=-1,
28 | ).reshape(*wx.shape, 3, 3)
29 |
30 |
31 | @register_lie_group(
32 | matrix_dim=4,
33 | parameters_dim=7,
34 | tangent_dim=6,
35 | space_dim=3,
36 | )
37 | @dataclass(frozen=True)
38 | class SE3(_base.SEBase[SO3]):
39 | """Special Euclidean group for proper rigid transforms in 3D.
40 |
41 | Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization
42 | is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
43 | """
44 |
45 | # SE3-specific.
46 |
47 | wxyz_xyz: Tensor
48 | """Internal parameters. wxyz quaternion followed by xyz translation."""
49 |
50 | @override
51 | def __repr__(self) -> str:
52 | quat = np.round(self.wxyz_xyz[..., :4].numpy(force=True), 5)
53 | trans = np.round(self.wxyz_xyz[..., 4:].numpy(force=True), 5)
54 | return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})"
55 |
56 | # SE-specific.
57 |
58 | @classmethod
59 | @override
60 | def from_rotation_and_translation(
61 | cls,
62 | rotation: SO3,
63 | translation: Tensor,
64 | ) -> SE3:
65 | assert translation.shape[-1] == 3
66 | return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1))
67 |
68 | @override
69 | def rotation(self) -> SO3:
70 | return SO3(wxyz=self.wxyz_xyz[..., :4])
71 |
72 | @override
73 | def translation(self) -> Tensor:
74 | return self.wxyz_xyz[..., 4:]
75 |
76 | # Factory.
77 |
78 | @classmethod
79 | @override
80 | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3:
81 | return SE3(
82 | wxyz_xyz=torch.tensor(
83 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device=device, dtype=dtype
84 | )
85 | )
86 |
87 | @classmethod
88 | @override
89 | def from_matrix(cls, matrix: Tensor) -> SE3:
90 | assert matrix.shape[-2:] == (4, 4)
91 | # Currently assumes bottom row is [0, 0, 0, 1].
92 | return SE3.from_rotation_and_translation(
93 | rotation=SO3.from_matrix(matrix[..., :3, :3]),
94 | translation=matrix[..., :3, 3],
95 | )
96 |
97 | # Accessors.
98 |
99 | @override
100 | def as_matrix(self) -> Tensor:
101 | R = self.rotation().as_matrix() # (*, 3, 3)
102 | t = self.translation().unsqueeze(-1) # (*, 3, 1)
103 | dims = R.shape[:-2]
104 | bottom = (
105 | torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
106 | .reshape(*(1,) * len(dims), 1, 4)
107 | .repeat(*dims, 1, 1)
108 | )
109 | return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2)
110 |
111 | @override
112 | def parameters(self) -> Tensor:
113 | return self.wxyz_xyz
114 |
115 | # Operations.
116 |
117 | @classmethod
118 | @override
119 | def exp(cls, tangent: Tensor) -> SE3:
120 | # Reference:
121 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761
122 |
123 | # (x, y, z, omega_x, omega_y, omega_z)
124 | assert tangent.shape[-1] == 6
125 |
126 | rotation = SO3.exp(tangent[..., 3:])
127 |
128 | theta_squared = torch.square(tangent[..., 3:]).sum(dim=-1) # (*)
129 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype)
130 |
131 | theta_squared_safe = cast(
132 | Tensor,
133 | torch.where(
134 | use_taylor,
135 | 1.0, # Any non-zero value should do here.
136 | theta_squared,
137 | ),
138 | )
139 | del theta_squared
140 | theta_safe = torch.sqrt(theta_squared_safe)
141 |
142 | skew_omega = _skew(tangent[..., 3:])
143 | dtype = skew_omega.dtype
144 | device = skew_omega.device
145 | V = torch.where(
146 | use_taylor[..., None, None],
147 | rotation.as_matrix(),
148 | (
149 | torch.eye(3, device=device, dtype=dtype)
150 | + ((1.0 - torch.cos(theta_safe)) / (theta_squared_safe))[
151 | ..., None, None
152 | ]
153 | * skew_omega
154 | + (
155 | (theta_safe - torch.sin(theta_safe))
156 | / (theta_squared_safe * theta_safe)
157 | )[..., None, None]
158 | * torch.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
159 | ),
160 | )
161 |
162 | return SE3.from_rotation_and_translation(
163 | rotation=rotation,
164 | translation=torch.einsum("...ij,...j->...i", V, tangent[..., :3]),
165 | )
166 |
167 |
168 | @override
169 | def log(self) -> Tensor:
170 | # Reference:
171 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
172 | omega = self.rotation().log()
173 | theta_squared = torch.square(omega).sum(dim=-1) # (*)
174 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype)
175 |
176 | skew_omega = _skew(omega)
177 |
178 | # Shim to avoid NaNs in jnp.where branches, which cause failures for
179 | # reverse-mode AD.
180 | theta_squared_safe = torch.where(
181 | use_taylor,
182 | 1.0, # Any non-zero value should do here.
183 | theta_squared,
184 | )
185 | del theta_squared
186 | theta_safe = torch.sqrt(theta_squared_safe)
187 | half_theta_safe = theta_safe / 2.0
188 |
189 | dtype = omega.dtype
190 | device = omega.device
191 | V_inv = torch.where(
192 | use_taylor[..., None, None],
193 | torch.eye(3, device=device, dtype=dtype)
194 | - 0.5 * skew_omega
195 | + torch.matmul(skew_omega, skew_omega) / 12.0,
196 | (
197 | torch.eye(3, device=device, dtype=dtype)
198 | - 0.5 * skew_omega
199 | + (
200 | 1.0
201 | - theta_safe
202 | * torch.cos(half_theta_safe)
203 | / (2.0 * torch.sin(half_theta_safe))
204 | )[..., None, None]
205 | / theta_squared_safe[..., None, None]
206 | * torch.matmul(skew_omega, skew_omega)
207 | ),
208 | )
209 | return torch.cat(
210 | [torch.einsum("...ij,...j->...i", V_inv, self.translation()), omega], dim=-1
211 | )
212 |
213 | @override
214 | def adjoint(self) -> Tensor:
215 | R = self.rotation().as_matrix()
216 | dims = R.shape[:-2]
217 | # (*, 6, 6)
218 | return torch.cat(
219 | [
220 | torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1),
221 |                 torch.cat([torch.zeros((*dims, 3, 3), dtype=R.dtype, device=R.device), R], dim=-1),
222 | ],
223 | dim=-2,
224 | )
--------------------------------------------------------------------------------
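Since SE3 mirrors the jaxlie interface in PyTorch, a quick self-contained check of the exp/log round trip and the overloaded @ operator (composition with another SE3 vs. action on a point) looks like this:

import torch
from pogs.tracking.transforms import SE3, SO3

R = SO3.from_z_radians(torch.tensor(0.3))
T = SE3.from_rotation_and_translation(R, torch.tensor([0.1, 0.0, 0.2]))

tangent = T.log()           # (6,) tangent vector: (vx, vy, vz, omega_x, omega_y, omega_z)
T_back = SE3.exp(tangent)   # should match T up to float32 precision
print(torch.allclose(T.as_matrix(), T_back.as_matrix(), atol=1e-5))

p = torch.tensor([0.0, 1.0, 0.0])
print(T @ p)                          # `@` with a tensor applies the transform to a point
print((T @ T.inverse()).as_matrix())  # `@` with another SE3 composes; this is ~identity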
/pogs/tracking/transforms/_so3.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | from typing_extensions import override
10 |
11 | from . import _base
12 | from .utils import get_epsilon, register_lie_group
13 |
14 |
15 | @register_lie_group(
16 | matrix_dim=3,
17 | parameters_dim=4,
18 | tangent_dim=3,
19 | space_dim=3,
20 | )
21 | @dataclass(frozen=True)
22 | class SO3(_base.SOBase):
23 | """Special orthogonal group for 3D rotations.
24 |
25 | Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is
26 | `(omega_x, omega_y, omega_z)`.
27 | """
28 |
29 | # SO3-specific.
30 |
31 | wxyz: Tensor
32 | """Internal parameters. `(w, x, y, z)` quaternion."""
33 |
34 | @override
35 | def __repr__(self) -> str:
36 | wxyz = np.round(self.wxyz.numpy(force=True), 5)
37 | return f"{self.__class__.__name__}(wxyz={wxyz})"
38 |
39 | @staticmethod
40 | def from_x_radians(theta: Tensor) -> SO3:
41 | """Generates a x-axis rotation.
42 |
43 | Args:
44 |             theta: X rotation, in radians.
45 |
46 | Returns:
47 | Output.
48 | """
49 | zeros = torch.zeros_like(theta)
50 | return SO3.exp(torch.stack([theta, zeros, zeros], dim=-1))
51 |
52 | @staticmethod
53 | def from_y_radians(theta: Tensor) -> SO3:
54 | """Generates a y-axis rotation.
55 |
56 | Args:
57 |             theta: Y rotation, in radians.
58 |
59 | Returns:
60 | Output.
61 | """
62 | zeros = torch.zeros_like(theta)
63 | return SO3.exp(torch.stack([zeros, theta, zeros], dim=-1))
64 |
65 | @staticmethod
66 | def from_z_radians(theta: Tensor) -> SO3:
67 | """Generates a z-axis rotation.
68 |
69 | Args:
70 |             theta: Z rotation, in radians.
71 |
72 | Returns:
73 | Output.
74 | """
75 | zeros = torch.zeros_like(theta)
76 | return SO3.exp(torch.stack([zeros, zeros, theta], dim=-1))
77 |
78 | @staticmethod
79 | def from_rpy_radians(
80 | roll: Tensor,
81 | pitch: Tensor,
82 | yaw: Tensor,
83 | ) -> SO3:
84 | """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot
85 | convention.
86 |
87 | Args:
88 | roll: X rotation, in radians. Applied first.
89 | pitch: Y rotation, in radians. Applied second.
90 | yaw: Z rotation, in radians. Applied last.
91 |
92 | Returns:
93 | Output.
94 | """
95 | return (
96 | SO3.from_z_radians(yaw)
97 | @ SO3.from_y_radians(pitch)
98 | @ SO3.from_x_radians(roll)
99 | )
100 |
101 | @staticmethod
102 | def from_quaternion_xyzw(xyzw: Tensor) -> SO3:
103 | """Construct a rotation from an `xyzw` quaternion.
104 |
105 | Note that `wxyz` quaternions can be constructed using the default dataclass
106 | constructor.
107 |
108 | Args:
109 | xyzw: xyzw quaternion. Shape should be (4,).
110 |
111 | Returns:
112 | Output.
113 | """
114 | assert xyzw.shape == (4,)
115 | return SO3(torch.roll(xyzw, shifts=1, dims=-1))
116 |
117 | def as_quaternion_xyzw(self) -> Tensor:
118 | """Grab parameters as xyzw quaternion."""
119 | return torch.roll(self.wxyz, shifts=-1, dims=-1)
120 |
121 | # Factory.
122 |
123 | @classmethod
124 | @override
125 | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SO3:
126 | return SO3(wxyz=torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype))
127 |
128 | @classmethod
129 | @override
130 | def from_matrix(cls, matrix: Tensor) -> SO3:
131 | assert matrix.shape[-2:] == (3, 3)
132 |
133 | # Modified from:
134 | # > "Converting a Rotation Matrix to a Quaternion" from Mike Day
135 | # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf
136 |
137 | def case0(m):
138 | t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2]
139 | q = torch.stack(
140 | [
141 | m[..., 2, 1] - m[..., 1, 2],
142 | t,
143 | m[..., 1, 0] + m[..., 0, 1],
144 | m[..., 0, 2] + m[..., 2, 0],
145 | ],
146 | dim=-1,
147 | )
148 | return t, q
149 |
150 | def case1(m):
151 | t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2]
152 | q = torch.stack(
153 | [
154 | m[..., 0, 2] - m[..., 2, 0],
155 | m[..., 1, 0] + m[..., 0, 1],
156 | t,
157 | m[..., 2, 1] + m[..., 1, 2],
158 | ],
159 | dim=-1,
160 | )
161 | return t, q
162 |
163 | def case2(m):
164 | t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2]
165 | q = torch.stack(
166 | [
167 | m[..., 1, 0] - m[..., 0, 1],
168 | m[..., 0, 2] + m[..., 2, 0],
169 | m[..., 2, 1] + m[..., 1, 2],
170 | t,
171 | ],
172 | dim=-1,
173 | )
174 | return t, q
175 |
176 | def case3(m):
177 | t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2]
178 | q = torch.stack(
179 | [
180 | t,
181 | m[..., 2, 1] - m[..., 1, 2],
182 | m[..., 0, 2] - m[..., 2, 0],
183 | m[..., 1, 0] - m[..., 0, 1],
184 | ],
185 | dim=-1,
186 | )
187 | return t, q
188 |
189 | # Compute four cases, then pick the most precise one.
190 | # Probably worth revisiting this!
191 | case0_t, case0_q = case0(matrix)
192 | case1_t, case1_q = case1(matrix)
193 | case2_t, case2_q = case2(matrix)
194 | case3_t, case3_q = case3(matrix)
195 |
196 | cond0 = matrix[..., 2, 2] < 0
197 | cond1 = matrix[..., 0, 0] > matrix[..., 1, 1]
198 | cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1]
199 |
200 | t = torch.where(
201 | cond0,
202 | torch.where(cond1, case0_t, case1_t),
203 | torch.where(cond2, case2_t, case3_t),
204 | )
205 | q = torch.where(
206 | cond0[..., None],
207 | torch.where(cond1[..., None], case0_q, case1_q),
208 | torch.where(cond2[..., None], case2_q, case3_q),
209 | )
210 | return SO3(wxyz=q * 0.5 / torch.sqrt(t[..., None]))
211 |
212 | # Accessors.
213 |
214 | @override
215 | def as_matrix(self) -> Tensor:
216 | norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True)
217 | qvec = self.wxyz * torch.sqrt(2.0 / norm_sq) # (*, 4)
218 | Q = torch.einsum("...i,...j->...ij", qvec, qvec) # (*, 4, 4)
219 | return torch.stack(
220 | [
221 | 1.0 - Q[..., 2, 2] - Q[..., 3, 3],
222 | Q[..., 1, 2] - Q[..., 3, 0],
223 | Q[..., 1, 3] + Q[..., 2, 0],
224 | Q[..., 1, 2] + Q[..., 3, 0],
225 | 1.0 - Q[..., 1, 1] - Q[..., 3, 3],
226 | Q[..., 2, 3] - Q[..., 1, 0],
227 | Q[..., 1, 3] - Q[..., 2, 0],
228 | Q[..., 2, 3] + Q[..., 1, 0],
229 | 1.0 - Q[..., 1, 1] - Q[..., 2, 2],
230 | ],
231 | dim=-1,
232 | ).reshape(*qvec.shape[:-1], 3, 3)
233 |
234 | @override
235 | def parameters(self) -> Tensor:
236 | return self.wxyz
237 |
238 | # Operations.
239 |
240 | @override
241 | def apply(self, target: Tensor) -> Tensor:
242 | assert target.shape[-1] == 3
243 |
244 |         # Compute using quaternion multiplication.
245 | padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1)
246 | out = self.multiply(SO3(wxyz=padded_target).multiply(self.inverse()))
247 | return out.wxyz[..., 1:]
248 |
249 | @override
250 | def multiply(self, other: SO3) -> SO3:
251 | w0, x0, y0, z0 = self.wxyz.unbind(dim=-1)
252 | w1, x1, y1, z1 = other.wxyz.unbind(dim=-1)
253 | wxyz2 = torch.stack(
254 | [
255 | -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,
256 | x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,
257 | -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,
258 | x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,
259 | ],
260 | dim=-1,
261 | )
262 |
263 | return SO3(wxyz=wxyz2)
264 |
265 | @classmethod
266 | @override
267 | def exp(cls, tangent: Tensor) -> SO3:
268 | # Reference:
269 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583
270 |
271 | assert tangent.shape[-1] == 3
272 |
273 | theta_squared = torch.square(tangent).sum(dim=-1) # (*)
274 | theta_pow_4 = theta_squared * theta_squared
275 | use_taylor = theta_squared < get_epsilon(tangent.dtype)
276 |
277 | safe_theta = torch.sqrt(
278 | torch.where(
279 | use_taylor,
280 | torch.ones_like(theta_squared), # Any constant value should do here.
281 | theta_squared,
282 | )
283 | )
284 | safe_half_theta = 0.5 * safe_theta
285 |
286 | real_factor = torch.where(
287 | use_taylor,
288 | 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,
289 | torch.cos(safe_half_theta),
290 | )
291 |
292 | imaginary_factor = torch.where(
293 | use_taylor,
294 | 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,
295 | torch.sin(safe_half_theta) / safe_theta,
296 | )
297 |
298 | return SO3(
299 | wxyz=torch.cat(
300 | [
301 | real_factor[..., None],
302 | imaginary_factor[..., None] * tangent,
303 | ],
304 | dim=-1,
305 | )
306 | )
307 |
308 | @override
309 | def log(self) -> Tensor:
310 | # Reference:
311 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247
312 |
313 | w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) # (*, 1), (*, 3)
314 | norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True) # (*, 1)
315 | use_taylor = norm_sq < get_epsilon(norm_sq.dtype)
316 |
317 | norm_safe = torch.sqrt(
318 | torch.where(
319 | use_taylor,
320 | torch.ones_like(norm_sq), # Any non-zero value should do here.
321 | norm_sq,
322 | )
323 | )
324 | w_safe = torch.where(use_taylor, w, torch.ones_like(w))
325 | atan_n_over_w = torch.atan2(
326 | torch.where(w < 0, -norm_safe, norm_safe),
327 | torch.abs(w),
328 | )
329 | atan_factor = torch.where(
330 | use_taylor,
331 | 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3,
332 | torch.where(
333 | torch.abs(w) < get_epsilon(w.dtype),
334 | torch.where(w > 0, 1.0, -1.0) * torch.pi / norm_safe,
335 | 2.0 * atan_n_over_w / norm_safe,
336 | ),
337 | )
338 |
339 | return atan_factor * xyz
340 |
341 | @override
342 | def adjoint(self) -> Tensor:
343 | return self.as_matrix()
344 |
345 | @override
346 | def inverse(self) -> SO3:
347 | # Negate complex terms.
348 | w, xyz = torch.split(self.wxyz, [1, 3], dim=-1)
349 | return SO3(wxyz=torch.cat([w, -xyz], dim=-1))
350 |
351 | @override
352 | def normalize(self) -> SO3:
353 | return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True))
--------------------------------------------------------------------------------
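Similarly, SO3 supports Euler-angle construction and quaternion/matrix conversions; a small sketch exercising those paths:

import torch
from pogs.tracking.transforms import SO3

# ZYX (yaw-pitch-roll) Euler angles -> rotation, applied to a point.
R = SO3.from_rpy_radians(torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3))
p = torch.tensor([1.0, 0.0, 0.0])
print(R @ p)                  # rotated point
print(R.inverse() @ (R @ p))  # recovers ~p

# Quaternion and matrix round trips.
q_xyzw = R.as_quaternion_xyzw()
print(torch.allclose(SO3.from_quaternion_xyzw(q_xyzw).wxyz, R.wxyz))
print(torch.allclose(SO3.from_matrix(R.as_matrix()).wxyz, R.wxyz, atol=1e-5))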
/pogs/tracking/transforms/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._utils import get_epsilon, register_lie_group
2 |
3 | __all__ = ["get_epsilon", "register_lie_group"]
--------------------------------------------------------------------------------
/pogs/tracking/transforms/utils/_utils.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Callable, Type, TypeVar
2 |
3 | import torch
4 |
5 | if TYPE_CHECKING:
6 | from .._base import MatrixLieGroup
7 |
8 |
9 | T = TypeVar("T", bound="MatrixLieGroup")
10 |
11 |
12 | def get_epsilon(dtype: torch.dtype) -> float:
13 | """Helper for grabbing type-specific precision constants.
14 |
15 | Args:
16 | dtype: Datatype.
17 |
18 | Returns:
19 | Output float.
20 | """
21 | return {
22 | torch.float32: 1e-5,
23 | torch.float64: 1e-10,
24 | }[dtype]
25 |
26 |
27 | def register_lie_group(
28 | *,
29 | matrix_dim: int,
30 | parameters_dim: int,
31 | tangent_dim: int,
32 | space_dim: int,
33 | ) -> Callable[[Type[T]], Type[T]]:
34 | """Decorator for registering Lie group dataclasses.
35 |
36 | Sets dimensionality class variables.
37 | """
38 |
39 | def _wrap(cls: Type[T]) -> Type[T]:
40 | # Register dimensions as class attributes.
41 | cls.matrix_dim = matrix_dim
42 | cls.parameters_dim = parameters_dim
43 | cls.tangent_dim = tangent_dim
44 | cls.space_dim = space_dim
45 |
46 | return cls
47 |
48 | return _wrap
--------------------------------------------------------------------------------
/pogs/tracking/tri_zed.py:
--------------------------------------------------------------------------------
1 | import pyzed.sl as sl
2 | from typing import Optional, Tuple
3 | import torch
4 | import numpy as np
5 | from threading import Lock
6 | import plotly
7 | from plotly import express as px
8 | import trimesh
9 | from pathlib import Path
10 | # from raftstereo.raft_stereo import *
11 | from autolab_core import RigidTransform
12 | import torch
13 | import torch.nn.functional as tfn
14 | import os
15 | import numpy as np
16 | import pathlib
17 | from matplotlib import pyplot as plt
18 | torch._C._jit_set_profiling_executor(False)
19 |
20 | def is_tensor(data):
21 | return type(data) == torch.Tensor
22 |
23 | def is_tuple(data):
24 | return isinstance(data, tuple)
25 |
26 | def is_list(data):
27 | return isinstance(data, list) or isinstance(data, torch.nn.ModuleList)
28 |
29 | def is_dict(data):
30 | return isinstance(data, dict) or isinstance(data, torch.nn.ModuleDict)
31 |
32 | def is_seq(data):
33 | return is_tuple(data) or is_list(data)
34 |
35 | def iterate1(func):
36 | """Decorator to iterate over a list (first argument)"""
37 | def inner(var, *args, **kwargs):
38 | if is_seq(var):
39 | return [func(v, *args, **kwargs) for v in var]
40 | elif is_dict(var):
41 | return {key: func(val, *args, **kwargs) for key, val in var.items()}
42 | else:
43 | return func(var, *args, **kwargs)
44 | return inner
45 |
46 |
47 | @iterate1
48 | def interpolate(tensor, size, scale_factor, mode):
49 | if size is None and scale_factor is None:
50 | return tensor
51 | if is_tensor(size):
52 | size = size.shape[-2:]
53 | return tfn.interpolate(
54 | tensor, size=size, scale_factor=scale_factor,
55 | recompute_scale_factor=False, mode=mode,
56 | align_corners=None,
57 | )
58 |
59 |
60 | def resize_input(
61 | rgb: torch.Tensor,
62 | intrinsics: torch.Tensor = None,
63 | resize: tuple = None
64 | ):
65 | """Resizes input data
66 |
67 | Args:
68 | rgb (torch.Tensor): input image (B,3,H,W)
69 | intrinsics (torch.Tensor): camera intrinsics (B,3,3)
70 | resize (tuple, optional): resize shape. Defaults to None.
71 |
72 | Returns:
73 | rgb: resized image (B,3,h,w)
74 | intrinsics: resized intrinsics (B,3,3)
75 | """
76 |
77 | # Don't resize if not requested
78 | if resize is None:
79 | if intrinsics is None:
80 | return rgb
81 | else:
82 | return rgb, intrinsics
83 | # Resize rgb
84 | orig_shape = [float(v) for v in rgb.shape[-2:]]
85 | rgb = interpolate(rgb, mode="bilinear", scale_factor=None, size=resize)
86 | # Return only rgb if there are no intrinsics
87 | if intrinsics is None:
88 | return rgb
89 | # Resize intrinsics
90 | shape = [float(v) for v in rgb.shape[-2:]]
91 | intrinsics = intrinsics.clone()
92 | intrinsics[:, 0] *= shape[1] / orig_shape[1]
93 | intrinsics[:, 1] *= shape[0] / orig_shape[0]
94 | # return resized input
95 | return rgb, intrinsics
96 |
97 | def format_image(rgb):
98 | return torch.tensor(rgb.transpose(2,0,1)[None]).to(torch.float32).cuda() / 255.0
99 | class StereoModel(torch.nn.Module):
100 | """Learned Stereo model.
101 |
102 | Takes as input two images plus intrinsics and outputs a metrically scaled depth map.
103 |
104 | Taken from: https://github.com/ToyotaResearchInstitute/mmt_stereo_inference
105 | Paper here: https://arxiv.org/pdf/2109.11644.pdf
106 | Authors: Krishna Shankar, Mark Tjersland, Jeremy Ma, Kevin Stone, Max Bajracharya
107 |
108 | Pre-trained checkpoint here: s3://tri-ml-models/efm/depth/stereo.pt
109 |
110 | Args:
111 | cfg (Config): configuration file to initialize the model
112 | ckpt (str, optional): checkpoint path to load a pre-trained model. Defaults to None.
113 | baseline (float): Camera baseline. Defaults to 0.12 (ZED baseline)
114 | """
115 |
116 | def __init__(self, ckpt: str = None):
117 | super().__init__()
118 | # Initialize model
119 | self.model = torch.jit.load(ckpt).cuda()
120 | self.model.eval()
121 |
122 | def inference(
123 | self,
124 | baseline: float,
125 | rgb_left: torch.Tensor,
126 | rgb_right: torch.Tensor,
127 | intrinsics: torch.Tensor,
128 | resize: tuple = None,
129 | ):
130 | """Performs inference on input data
131 |
132 | Args:
133 | rgb_left (torch.Tensor): input float32 image (B,3,H,W)
134 | rgb_right (torch.Tensor): input float32 image (B,3,H,W)
135 | intrinsics (torch.Tensor): camera intrinsics (B,3,3)
136 | resize (tuple, optional): resize shape. Defaults to None.
137 |
138 | Returns:
139 | depth: output depth map (B,1,H,W)
140 | """
141 |
142 | rgb_left, intrinsics = resize_input(
143 | rgb=rgb_left, intrinsics=intrinsics, resize=resize
144 | )
145 | rgb_right = resize_input(rgb=rgb_right, resize=resize)
146 |
147 | with torch.no_grad():
148 | output, _ = self.model(rgb_left, rgb_right)
149 |
150 | disparity_sparse = output["disparity_sparse"]
151 | mask = disparity_sparse != 0
152 | depth = torch.zeros_like(disparity_sparse)
153 | depth[mask] = baseline * intrinsics[0, 0, 0] / disparity_sparse[mask]
154 | # depth = baseline * intrinsics[0, 0, 0] / output["disparity"]
155 | rgb = (rgb_left.squeeze(0).permute(1,2,0).cpu().detach().numpy()*255).astype(np.uint8)
156 | return depth, output["disparity"], disparity_sparse,rgb
157 |
158 |
159 | class Zed():
160 | width: int
161 | """Width of the rgb/depth images."""
162 | height: int
163 | """Height of the rgb/depth images."""
164 | raft_lock: Lock
165 | """Lock for the camera, for raft-stereo depth!"""
166 |
167 | zed_mesh: trimesh.Trimesh
168 | """Trimesh of the ZED camera."""
169 | cam_to_zed: RigidTransform
170 | """Transform from left camera to ZED camera base."""
171 |
172 | def __init__(self, flip_mode, resolution, fps, cam_id=None, recording_file=None, start_time=0.0):
173 | init = sl.InitParameters()
174 | if cam_id is not None:
175 | init.set_from_serial_number(cam_id)
176 | self.cam_id = cam_id
177 | self.width = None
178 | self.debug_ = False
179 | self.height = None
180 | self.init_res = None
181 | # Set camera flip mode
182 | if flip_mode:
183 | init.camera_image_flip = sl.FLIP_MODE.ON
184 | else:
185 | init.camera_image_flip = sl.FLIP_MODE.OFF
186 |
187 | if recording_file is not None:
188 | init.set_from_svo_file(recording_file)
189 |
190 | # Configure camera resolution
191 | if resolution == '720p':
192 | init.camera_resolution = sl.RESOLUTION.HD720
193 | self.height = 720
194 | self.width = 1280
195 | self.init_res = 1280
196 | elif resolution == '1080p':
197 | init.camera_resolution = sl.RESOLUTION.HD1080
198 | self.height = 1080
199 | self.width = 1920
200 | self.init_res = 1920
201 | elif resolution == '2k':
202 |             init.camera_resolution = sl.RESOLUTION.HD2K
203 | self.height = 1242
204 | self.width = 2208
205 | self.init_res = 2208
206 | else:
207 | print("Only 720p, 1080p, and 2k supported by Zed")
208 | exit()
209 | # Disable native ZED depth computation (we'll use RAFT-Stereo instead)
210 | init.depth_mode = sl.DEPTH_MODE.NONE
211 | init.sdk_verbose = 1
212 | init.camera_fps = fps
213 | self.cam = sl.Camera()
214 | init.camera_disable_self_calib = True
215 | status = self.cam.open(init)
216 | if recording_file is not None:
217 | fps = self.cam.get_camera_information().camera_configuration.fps
218 | self.cam.set_svo_position(int(start_time*fps))
219 |         if status != sl.ERROR_CODE.SUCCESS:  # Ensure the camera has opened successfully
220 | print("Camera Open : "+repr(status)+". Exit program.")
221 | exit()
222 | else:
223 | print("Opened camera")
224 | res = sl.Resolution()
225 | res.width = self.width
226 | res.height = self.height
227 | left_cx = self.get_K(cam="left")[0, 2]
228 | right_cx = self.get_K(cam="right")[0, 2]
229 | self.cx_diff = right_cx - left_cx # /1920
230 | self.f_ = self.get_K(cam="left")[0,0]
231 | self.cx_ = left_cx
232 | self.cy_ = self.get_K(cam="left")[1,2]
233 | self.baseline_ = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.get_translation().get()[0] / 1000.0
234 | self.intrinsics_ = torch.tensor([[
235 | [self.f_,0,self.cx_],
236 | [0,self.f_,self.cy_],
237 | [0,0,1]
238 | ]]).to(torch.float32).cuda()
239 | # Create lock for raft -- gpu threading messes up CUDA memory state, with curobo...
240 | self.raft_lock = Lock()
241 | self.dir_path = pathlib.Path(__file__).parent.resolve()
242 | self.stereo_ckpt = os.path.join(self.dir_path,'models/stereo_20230724.pt') #We use stereo model from this paper: https://arxiv.org/abs/2109.11644. However, you can sub this in for any realtime stereo model (including the default Zed model).
243 | with self.raft_lock:
244 | self.model = StereoModel(self.stereo_ckpt)
245 |
246 | # left_cx = self.get_K(cam='left')[0,2]
247 | # right_cx = self.get_K(cam='right')[0,2]
248 | # self.cx_diff = (right_cx-left_cx)
249 |
250 |
251 |
252 | # For visualiation.
253 |
254 | zedM_path = Path(__file__).parent / Path("data/ZEDM.stl")
255 | zed2_path = Path(__file__).parent / Path("data/ZED2.stl")
256 | self.zedM_mesh = trimesh.load(str(zedM_path))
257 | self.zed2_mesh = trimesh.load(str(zed2_path))
258 | # assert isinstance(zed_mesh, trimesh.Trimesh)
259 | self.zed_mesh = self.zed2_mesh
260 | self.cam_to_zed = RigidTransform(
261 | rotation=RigidTransform.quaternion_from_axis_angle(
262 | np.array([1, 0, 0]) * (np.pi / 2)
263 | ),
264 | translation=np.array([0.06, 0.042, -0.035]),
265 | )
266 |
267 | def prime_get_frame(self):
268 | res = sl.Resolution()
269 | res.width = self.width
270 | res.height = self.height
271 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS:
272 | left_rgb = sl.Mat()
273 | right_rgb = sl.Mat()
274 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT)
275 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT)
276 |             self.height = self.height - self.height % 32  # crop to a multiple of 32 for the stereo network
277 |             self.width = self.width - self.width % 32
278 | left_cropped = np.flip(left_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy()
279 | right_cropped = np.flip(right_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy()
280 | with self.raft_lock:
281 | tridepth, disparity, disparity_sparse,cropped_rgb = self.model.inference(rgb_left=format_image(left_cropped),rgb_right=format_image(right_cropped),intrinsics=self.intrinsics_,baseline=self.baseline_)
282 | return left_cropped,right_cropped
283 | else:
284 | print("Couldn't grab frame")
285 |
286 | def get_frame(
287 | self, depth=True
288 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
289 | res = sl.Resolution()
290 | res.width = self.width
291 | res.height = self.height
292 | r = self.width/self.init_res
293 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS:
294 | left_rgb = sl.Mat()
295 | right_rgb = sl.Mat()
296 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res)
297 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res)
298 | self.height = self.height - self.height % 32
299 | self.width = self.width - self.width % 32
300 | left_cropped = np.flip(left_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy()
301 | right_cropped = np.flip(right_rgb.get_data()[:self.height,:self.width,:3], axis=2).copy()
302 |
303 | with self.raft_lock:
304 | tridepth, disparity, disparity_sparse,cropped_rgb = self.model.inference(rgb_left=format_image(left_cropped),rgb_right=format_image(right_cropped),intrinsics=self.intrinsics_,baseline=self.baseline_)
305 | if(self.debug_):
306 | import pdb
307 | pdb.set_trace()
308 | plt.imshow(tridepth.detach().cpu().numpy()[0,0],cmap='jet')
309 | plt.savefig('/home/lifelong/prime_raft.png')
310 | import pdb
311 | pdb.set_trace()
312 | return torch.from_numpy(left_cropped).cuda(), torch.from_numpy(right_cropped).cuda(), tridepth[0,0]
313 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
314 | print("End of recording file")
315 | return None,None,None
316 | else:
317 | raise RuntimeError("Could not grab frame")
318 |
319 | def get_K(self,cam='left') -> np.ndarray:
320 | calib = self.cam.get_camera_information().camera_configuration.calibration_parameters
321 | if cam=='left':
322 | intrinsics = calib.left_cam
323 | else:
324 | intrinsics = calib.right_cam
325 | r = self.width/self.init_res
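        # Rescale the native-resolution calibration by r, since frames are retrieved at
        # self.width rather than the sensor's native width (self.init_res).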
326 | K = np.array([[intrinsics.fx*r, 0, intrinsics.cx*r],
327 | [0, intrinsics.fy*r, intrinsics.cy*r],
328 | [0, 0, 1]])
329 | return K
330 |
331 | def get_stereo_transform(self):
332 | transform = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m
333 |         transform[:3,3] /= 1000  # convert to meters
334 | return transform
335 |
336 | def start_record(self, out_path):
337 | recordingParameters = sl.RecordingParameters()
338 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264
339 | recordingParameters.video_filename = out_path
340 | err = self.cam.enable_recording(recordingParameters)
341 |
342 | def stop_record(self):
343 | self.cam.disable_recording()
344 |
345 | @staticmethod
346 | def plotly_render(frame) -> plotly.graph_objs.Figure:
347 | fig = px.imshow(frame)
348 | fig.update_layout(
349 | margin=dict(l=0, r=0, t=0, b=0),
350 | showlegend=False,
351 | yaxis_visible=False,
352 | yaxis_showticklabels=False,
353 | xaxis_visible=False,
354 | xaxis_showticklabels=False,
355 | )
356 | return fig
357 |
358 | @staticmethod
359 | def project_depth(
360 | rgb: torch.Tensor,
361 | depth: torch.Tensor,
362 | K: torch.Tensor,
363 | depth_threshold: float = 1.0,
364 | subsample: int = 4,
365 | ) -> Tuple[np.ndarray, np.ndarray]:
366 | """Deproject RGBD image to point cloud, using provided intrinsics.
367 | Also threshold/subsample pointcloud for visualization speed."""
368 |
369 | img_wh = rgb.shape[:2][::-1]
370 |
371 | grid = (
372 | torch.stack(
373 | torch.meshgrid(
374 | torch.arange(img_wh[0], device="cuda"),
375 | torch.arange(img_wh[1], device="cuda"),
376 | indexing="xy",
377 | ),
378 | 2,
379 | )
380 | + 0.5
381 | )
382 |
383 | homo_grid = torch.concat(
384 | [grid, torch.ones((grid.shape[0], grid.shape[1], 1), device="cuda")],
385 | dim=2
386 | ).reshape(-1, 3)
387 | local_dirs = torch.matmul(torch.linalg.inv(K),homo_grid.T).T
388 | points = (local_dirs * depth.reshape(-1,1)).float()
389 | points = points.reshape(-1,3)
390 |
391 | mask = depth.reshape(-1, 1) <= depth_threshold
392 | points = points.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy()
393 | colors = rgb.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy()
394 |
395 | return (points, colors)
396 |
397 |
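# Hypothetical usage sketch (not part of the original file): one way the static
# project_depth helper above might be driven from a live camera to build a point cloud
# for visualization. The helper name and flow here are illustrative assumptions only.
def _example_project_depth(zed: "Zed"):
    left, _, depth = zed.get_frame()
    if left is None:
        return None
    K = torch.from_numpy(zed.get_K()).float().cuda()
    # Keep points within 1 m and take every 4th point, matching the defaults above.
    return Zed.project_depth(left, depth, K, depth_threshold=1.0, subsample=4)
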
398 | import tyro
399 | def main(name: str) -> None:
400 | # def main() -> None:
401 |
402 | import torch
403 | from viser import ViserServer
404 | # zed = Zed(recording_file="exps/eyeglasses/2024-06-06_014947/traj.svo2")
405 | #jzed = Zed(recording_file="test.svo2")
406 | # import pdb; pdb.set_trace()
407 | zed = Zed(recording_file = "exps/scissors/2024-06-06_155342/traj.svo2")
408 | # zed.start_record(f"/home/chungmin/Documents/please2/toad/motion_vids/{name}.svo2")
409 | import os
410 | # os.makedirs(out_dir,exist_ok=True)
411 | i = 0
412 |
413 |
414 |     # Code for visualizing the point cloud
415 | import viser
416 | from matplotlib import pyplot as plt
417 | import viser.transforms as tf
418 | v = ViserServer()
419 | gui_reset_up = v.add_gui_button(
420 | "Reset up direction",
421 | hint="Set the camera control 'up' direction to the current camera's 'up'.",
422 | )
423 |
424 | @gui_reset_up.on_click
425 | def _(event: viser.GuiEvent) -> None:
426 | client = event.client
427 | assert client is not None
428 | client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
429 | [0.0, -1.0, 0.0]
430 | )
431 | while True:
432 | left,right,depth = zed.get_frame()
433 | left = left.cpu().numpy()
434 | depth = depth.cpu().numpy()
435 | # import matplotlib.pyplot as plt
436 | # plt.imshow(left)
437 | # plt.show()
438 | K = zed.get_K()
439 | T_world_camera = np.eye(4)
440 |
441 | img_wh = left.shape[:2][::-1]
442 |
443 | grid = (
444 | np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5
445 | )
446 |
447 | homo_grid = np.concatenate([grid,np.ones((grid.shape[0],grid.shape[1],1))],axis=2).reshape(-1,3)
448 | local_dirs = np.matmul(np.linalg.inv(K),homo_grid.T).T
449 | points = (local_dirs * depth.reshape(-1,1)).astype(np.float32)
450 | points = points.reshape(-1,3)
451 | v.add_point_cloud("points", points = points.reshape(-1,3), colors=left.reshape(-1,3),point_size=.001)
452 |
453 | if __name__ == "__main__":
454 | tyro.cli(main)
--------------------------------------------------------------------------------
/pogs/tracking/utils.py:
--------------------------------------------------------------------------------
1 | import warp as wp
2 | import torch
3 | from nerfstudio.cameras.cameras import Cameras
4 | from pogs.tracking.transforms import SE3, SO3
5 |
6 | def extrapolate_poses(p1_7v, p2_7v, lam, rot_scale_factor=3.0, thresh=True):
7 | ext_7v = []
8 | for i in range(len(p2_7v)):
9 | r1 = SO3(p1_7v[i,3:])
10 | t1 = SE3.from_rotation_and_translation(r1, p1_7v[i,:3])
11 | r2 = SO3(p2_7v[i,3:])
12 | t2 = SE3.from_rotation_and_translation(r2, p2_7v[i,:3])
13 | t_2_1 = t1.inverse() @ t2
14 | delta_pos = t_2_1.translation()*lam
15 | delta_rot = SO3.exp((t_2_1.rotation().log() * lam * rot_scale_factor))
16 | if thresh and delta_pos.norm().item() < 0.05: # Threshold for small deltas to avoid oscillations
17 | new_t = t2
18 | else:
19 | new_t = (t2 @ SE3.from_rotation_and_translation(delta_rot, delta_pos))
20 | ext_7v.append(new_t.wxyz_xyz.roll(3,dims=-1))
21 | return torch.stack(ext_7v)
22 |
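# Hypothetical sketch (not part of the original module): extrapolate_poses consumes two
# Nx7 pose batches laid out as [x, y, z, qw, qx, qy, qz] and continues each pose along
# the motion from p1 to p2, scaled by lam. The toy poses below are illustrative only.
def _example_extrapolate_poses():
    p1 = torch.tensor([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]])
    p2 = torch.tensor([[0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]])
    # thresh=False disables the 5 cm small-motion cutoff so the step is always applied.
    return extrapolate_poses(p1, p2, lam=0.5, thresh=False)
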
23 | def zero_optim_state(optimizer:torch.optim.Adam):
24 | # import pdb; pdb.set_trace()
25 | param = optimizer.param_groups[0]["params"][0]
26 | param_state = optimizer.state[param]
27 | if "max_exp_avg_sq" in param_state:
28 | # for amsgrad
29 | param_state["max_exp_avg_sq"] = torch.zeros(param_state["max_exp_avg_sq"].shape, device=param_state["max_exp_avg_sq"].device)
30 | if "exp_avg" in param_state:
31 | param_state["exp_avg"] = torch.zeros(param_state["exp_avg"].shape, device=param_state["exp_avg"].device)
32 | param_state["exp_avg_sq"] = torch.zeros(param_state["exp_avg_sq"].shape, device=param_state["exp_avg_sq"].device)
33 |
34 | def replace_in_optim(optimizer:torch.optim.Adam, new_params):
35 | """replaces the parameters in the optimizer"""
36 | param = optimizer.param_groups[0]["params"][0]
37 | param_state = optimizer.state[param]
38 |
39 | del optimizer.state[param]
40 | optimizer.state[new_params[0]] = param_state
41 | optimizer.param_groups[0]["params"] = new_params
42 | del param
43 |
44 | @wp.func
45 | def poses_7vec_to_transform(poses: wp.array(dtype=float, ndim=2), i: int):
46 | """
47 | Kernel helper for converting x y z qw qx qy qz to a wp.Transformation
48 | """
49 | position = wp.vector(poses[i,0], poses[i,1], poses[i,2])
50 | quaternion = wp.quaternion(poses[i,4], poses[i,5], poses[i,6], poses[i,3])
51 | return wp.transformation(position, quaternion)
52 |
53 | @wp.kernel
54 | def apply_to_model(
55 | # init_o2w: wp.array(dtype=float, ndim=2),
56 | init_p2ws: wp.array(dtype=float, ndim=2),
57 | # o_delta: wp.array(dtype=float, ndim=2),
58 | p_deltas: wp.array(dtype=float, ndim=2),
59 | group_labels: wp.array(dtype=int),
60 | means: wp.array(dtype=wp.vec3),
61 | quats: wp.array(dtype=float, ndim=2),
62 | #outputs
63 | means_out: wp.array(dtype=wp.vec3),
64 | quats_out: wp.array(dtype=float, ndim=2),
65 | ):
66 | """
67 | Kernel for applying the transforms to a gaussian splat
68 |
69 | [removed] init_o2w: 1x7 tensor of initial object to world poses
70 | init_p2ws: Nx7 tensor of initial pose to object poses
71 | [removed] o_delta: Nx7 tensor of object pose deltas represented as objnew_to_objoriginal
72 | p_deltas: Nx7 tensor of pose deltas represented as partnew_to_partoriginal
73 | group_labels: N, tensor of group labels (0->K-1) for K groups
74 | means: Nx3 tensor of means
75 | quats: Nx4 tensor of quaternions (wxyz)
76 | means_out: Nx3 tensor of output means
77 | quats_out: Nx4 tensor of output quaternions (wxyz)
78 | """
79 | tid = wp.tid()
80 | group_id = group_labels[tid]
81 | # o2w_T = poses_7vec_to_transform(init_o2w,0)
82 | p2w_T = poses_7vec_to_transform(init_p2ws,group_id)
83 | # odelta_T = poses_7vec_to_transform(o_delta,0)
84 | pdelta_T = poses_7vec_to_transform(p_deltas,group_id)
85 | g2w_T = wp.transformation(means[tid], wp.quaternion(quats[tid, 1], quats[tid, 2], quats[tid, 3], quats[tid, 0]))
86 | g2p_T = wp.transform_inverse(p2w_T) * g2w_T
87 | new_g2w_T = p2w_T * pdelta_T * g2p_T
88 | means_out[tid] = wp.transform_get_translation(new_g2w_T)
89 | new_quat = wp.transform_get_rotation(new_g2w_T)
90 | quats_out[tid, 0] = new_quat[3] #w
91 | quats_out[tid, 1] = new_quat[0] #x
92 | quats_out[tid, 2] = new_quat[1] #y
93 | quats_out[tid, 3] = new_quat[2] #z
94 |
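# Hypothetical launch sketch (not part of the original module), assuming warp's standard
# torch interop (wp.from_torch / wp.launch); in POGS the real launch is issued from the
# rigid-group optimizer. Tensor names and shapes here are illustrative only.
def _example_apply_to_model(init_p2ws, p_deltas, group_labels, means, quats):
    # init_p2ws/p_deltas: Kx7 per-group poses; group_labels: N ints; means: Nx3; quats: Nx4 (wxyz).
    means_out = torch.empty_like(means)
    quats_out = torch.empty_like(quats)
    wp.launch(
        kernel=apply_to_model,
        dim=means.shape[0],
        inputs=[
            wp.from_torch(init_p2ws),
            wp.from_torch(p_deltas),
            wp.from_torch(group_labels.int()),
            wp.from_torch(means, dtype=wp.vec3),
            wp.from_torch(quats),
            wp.from_torch(means_out, dtype=wp.vec3),
            wp.from_torch(quats_out),
        ],
        device="cuda",
    )
    return means_out, quats_out
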
95 | def identity_7vec(device='cuda'):
96 | """
97 |     Returns a 1x7 identity pose tensor in [x, y, z, qw, qx, qy, qz] order
98 | """
99 | return torch.tensor([[0, 0, 0, 1, 0, 0, 0]], dtype=torch.float32, device=device)
100 |
101 | def normalized_quat_to_rotmat(quat):
102 | """
103 | Converts a quaternion to a 3x3 rotation matrix
104 | """
105 | assert quat.shape[-1] == 4, quat.shape
106 | w, x, y, z = torch.unbind(quat, dim=-1)
107 | mat = torch.stack(
108 | [
109 | 1 - 2 * (y**2 + z**2),
110 | 2 * (x * y - w * z),
111 | 2 * (x * z + w * y),
112 | 2 * (x * y + w * z),
113 | 1 - 2 * (x**2 + z**2),
114 | 2 * (y * z - w * x),
115 | 2 * (x * z - w * y),
116 | 2 * (y * z + w * x),
117 | 1 - 2 * (x**2 + y**2),
118 | ],
119 | dim=-1,
120 | )
121 | return mat.reshape(quat.shape[:-1] + (3, 3))
122 |
123 | def torch_posevec_to_mat(posevecs):
124 | """
125 | Converts a Nx7-tensor to Nx4x4 matrix
126 |
127 | posevecs: Nx7 tensor of pose vectors
128 | returns: Nx4x4 tensor of transformation matrices
129 | """
130 | assert posevecs.shape[-1] == 7, posevecs.shape
131 | assert len(posevecs.shape) == 2, posevecs.shape
132 |     out = torch.eye(4, device=posevecs.device).unsqueeze(0).repeat(posevecs.shape[0], 1, 1)  # repeat (not expand) so the in-place writes below don't alias shared memory
133 | out[:, :3, 3] = posevecs[:, :3]
134 | out[:, :3, :3] = normalized_quat_to_rotmat(posevecs[:, 3:])
135 | return out
136 |
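# Hypothetical sanity-check sketch (not part of the original module): the identity pose
# vector from identity_7vec() should map to a 4x4 identity matrix.
def _example_posevec_roundtrip():
    mat = torch_posevec_to_mat(identity_7vec(device="cpu"))
    assert torch.allclose(mat[0], torch.eye(4))
    return mat
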
137 | def mnn_matcher(feat_a, feat_b):
138 | """
139 | Returns mutual nearest neighbors between two sets of features
140 |
141 | feat_a: NxD
142 | feat_b: MxD
143 | return: K, K (indices in feat_a and feat_b)
144 | """
145 | device = feat_a.device
146 | sim = feat_a.mm(feat_b.t())
147 | nn12 = torch.max(sim, dim=1)[1]
148 | nn21 = torch.max(sim, dim=0)[1]
149 | ids1 = torch.arange(0, sim.shape[0], device=device)
150 | mask = ids1 == nn21[nn12]
151 | return ids1[mask], nn12[mask]
152 |
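# Hypothetical usage sketch (not part of the original module): mnn_matcher scores pairs
# by dot product, so features should be L2-normalized when cosine similarity is intended.
# The random descriptors below are illustrative only.
def _example_mnn_matcher():
    feat_a = torch.nn.functional.normalize(torch.randn(128, 64), dim=-1)
    feat_b = torch.nn.functional.normalize(torch.randn(96, 64), dim=-1)
    ids_a, ids_b = mnn_matcher(feat_a, feat_b)
    # ids_a[k] and ids_b[k] index a mutual-nearest-neighbor pair in feat_a / feat_b.
    return ids_a, ids_b
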
153 | def crop_camera(camera: Cameras, xmin, xmax, ymin, ymax):
154 | height = torch.tensor(ymax - ymin,device='cuda').view(1,1).int()
155 | width = torch.tensor(xmax - xmin,device='cuda').view(1,1).int()
156 | cx = torch.tensor(camera.cx.clone() - xmin,device='cuda').view(1,1)
157 | cy = torch.tensor(camera.cy.clone() - ymin,device='cuda').view(1,1)
158 | fx = camera.fx.clone()
159 | fy = camera.fy.clone()
160 | return Cameras(camera.camera_to_worlds.clone(), fx, fy, cx, cy, width, height)
--------------------------------------------------------------------------------
/pogs/tracking/utils2.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
6 | import moviepy as mpy
7 | import wandb
8 | from pogs.tracking.observation import PosedObservation, Frame
9 | from torchvision.transforms.functional import to_pil_image
10 | from pogs.pogs import POGSModel
11 | import time
12 | import cv2
13 |
14 |
15 | def generate_videos(frames_dict, fps=30, config_path=None):
16 | import datetime
17 | timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
18 | for key in frames_dict.keys():
19 | frames = frames_dict[key]
20 | if len(frames)>1:
21 |             if frames[0].max() > 1:  # frames already in [0, 255]
22 |                 frames = [f for f in frames]
23 |             else:  # scale [0, 1] frames to [0, 255]
24 |                 frames = [f*255 for f in frames]
25 | clip = mpy.ImageSequenceClip(frames, fps=fps)
26 | if config_path is None:
27 | clip.write_videofile(f"{timestr}/{key}.mp4", codec="libx264")
28 | else:
29 | path = config_path.joinpath(f"{timestr}")
30 | if not path.exists():
31 | path.mkdir(parents=True)
32 | clip.write_videofile(str(path.joinpath(f"{key}.mp4")), codec="libx264")
33 | try:
34 | wandb.log({f"{key}": wandb.Video(str(path.joinpath(f"{key}.mp4")))})
35 |                 except Exception:  # wandb may not be initialized; skip logging
36 | pass
37 | return timestr
38 |
39 |
40 | def overlay(image, mask, color, alpha, resize=None):
41 | """Combines image and its segmentation mask into a single image.
42 | https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay
43 |
44 | Params:
45 | image: Training image. np.ndarray,
46 | mask: Segmentation mask. np.ndarray,
47 | color: Color for segmentation mask rendering. tuple[int, int, int] = (255, 0, 0)
48 | alpha: Segmentation mask's transparency. float = 0.5,
49 | resize: If provided, both image and its mask are resized before blending them together.
50 | tuple[int, int] = (1024, 1024))
51 |
52 | Returns:
53 | image_combined: The combined image. np.ndarray
54 |
55 | """
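    # Reverse the color tuple, presumably to convert between RGB and OpenCV's BGR ordering.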
56 | color = color[::-1]
57 | colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
58 | colored_mask = np.moveaxis(colored_mask, 0, -1)
59 | masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
60 | image_overlay = masked.filled()
61 |
62 | if resize is not None:
63 | image = cv2.resize(image.transpose(1, 2, 0), resize)
64 | image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
65 |
66 | image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
67 |
68 | return image_combined
--------------------------------------------------------------------------------
/pogs/tracking/zed.py:
--------------------------------------------------------------------------------
1 | import pyzed.sl as sl
2 | from typing import Optional, Tuple
3 | import torch
4 | import numpy as np
5 | from threading import Lock
6 | import plotly
7 | from plotly import express as px
8 | import trimesh
9 | from pathlib import Path
10 | from raftstereo.raft_stereo import *  # provides create_raft() and raft_inference() used below
11 | from autolab_core import RigidTransform
12 | from matplotlib import pyplot as plt
13 | class Zed():
14 | width: int
15 | """Width of the rgb/depth images."""
16 | height: int
17 | """Height of the rgb/depth images."""
18 | raft_lock: Lock
19 | """Lock for the camera, for raft-stereo depth!"""
20 |
21 | zed_mesh: trimesh.Trimesh
22 | """Trimesh of the ZED camera."""
23 | cam_to_zed: RigidTransform
24 | """Transform from left camera to ZED camera base."""
25 |
26 | def __init__(self, recording_file = None, start_time = 0.0):
27 | init = sl.InitParameters()
28 | if recording_file is not None:
29 | init.set_from_svo_file(recording_file)
30 | # disable depth
31 | init.camera_image_flip = sl.FLIP_MODE.ON
32 | init.depth_mode=sl.DEPTH_MODE.NONE
33 | init.camera_resolution = sl.RESOLUTION.HD1080
34 | init.sdk_verbose = 1
35 | init.camera_fps = 30
36 | else:
37 | init.camera_resolution = sl.RESOLUTION.HD720
38 | init.sdk_verbose = 1
39 | init.camera_fps = 30
40 | # flip camera
41 | # init.camera_image_flip = sl.FLIP_MODE.ON
42 | init.depth_mode=sl.DEPTH_MODE.NONE
43 |             init.depth_minimum_distance = 100  # millimeters
44 | self.init_res = 1920 if init.camera_resolution == sl.RESOLUTION.HD1080 else 1280
45 | print("INIT RES",self.init_res)
46 | self.debug_ = False
47 | self.width = 1280
48 | self.height = 720
49 | self.cam = sl.Camera()
50 | status = self.cam.open(init)
51 | if recording_file is not None:
52 | fps = self.cam.get_camera_information().camera_configuration.fps
53 | self.cam.set_svo_position(int(start_time*fps))
54 |         if status != sl.ERROR_CODE.SUCCESS:  # Ensure the camera has opened successfully
55 | print("Camera Open : "+repr(status)+". Exit program.")
56 | exit()
57 | else:
58 | print("Opened camera")
59 |
60 | left_cx = self.get_K(cam="left")[0, 2]
61 | right_cx = self.get_K(cam="right")[0, 2]
62 | self.cx_diff = right_cx - left_cx # /1920
63 | self.f_ = self.get_K(cam="left")[0,0]
64 | self.cx_ = left_cx
65 | self.cy_ = self.get_K(cam="left")[1,2]
66 |
67 | # Create lock for raft -- gpu threading messes up CUDA memory state, with curobo...
68 | self.raft_lock = Lock()
69 | with self.raft_lock:
70 | self.model = create_raft()
71 |
72 | # left_cx = self.get_K(cam='left')[0,2]
73 | # right_cx = self.get_K(cam='right')[0,2]
74 | # self.cx_diff = (right_cx-left_cx)
75 |
76 |         # For visualization.
77 | zed_path = Path(__file__).parent / Path("data/ZEDM.stl")
78 | zed_mesh = trimesh.load(str(zed_path))
79 | assert isinstance(zed_mesh, trimesh.Trimesh)
80 | self.zed_mesh = zed_mesh
81 | self.cam_to_zed = RigidTransform(
82 | rotation=RigidTransform.quaternion_from_axis_angle(
83 | np.array([1, 0, 0]) * (np.pi / 2)
84 | ),
85 | translation=np.array([0.06, 0.042, -0.035]), # Numbers are for ZED 2
86 | )
87 |
88 | def get_frame(
89 | self, depth=True
90 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
91 | res = sl.Resolution()
92 | res.width = self.width
93 | res.height = self.height
94 | r = self.width/self.init_res
95 | if self.cam.grab() == sl.ERROR_CODE.SUCCESS:
96 | left_rgb = sl.Mat()
97 | right_rgb = sl.Mat()
98 | self.cam.retrieve_image(left_rgb, sl.VIEW.LEFT, sl.MEM.CPU, res)
99 | self.cam.retrieve_image(right_rgb, sl.VIEW.RIGHT, sl.MEM.CPU, res)
100 | left,right = torch.from_numpy(np.flip(left_rgb.get_data()[...,:3],axis=2).copy()).cuda(), torch.from_numpy(np.flip(right_rgb.get_data()[...,:3],axis=2).copy()).cuda()
101 | if depth:
102 | left_torch,right_torch = left.permute(2,0,1),right.permute(2,0,1)
103 | with self.raft_lock:
104 | flow = raft_inference(left_torch,right_torch,self.model)
105 | fx = self.get_K()[0,0]
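                # Standard stereo relation: depth = f * baseline / disparity, where the baseline
                # comes from get_stereo_transform()[0, 3] (meters) and cx_diff corrects for the
                # offset between the rectified left/right principal points.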
106 | depth = fx*self.get_stereo_transform()[0,3]/(flow.abs()+self.cx_diff)
107 | if(self.debug_):
108 | plt.imshow(depth.detach().cpu().numpy(),cmap='jet')
109 | plt.savefig('/home/lifelong/justin_raft.png')
110 | import pdb
111 | pdb.set_trace()
112 | else:
113 | depth = None
114 | return left, right, depth
115 | elif self.cam.grab() == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
116 | print("End of recording file")
117 | return None,None,None
118 | else:
119 | raise RuntimeError("Could not grab frame")
120 |
121 | def get_K(self,cam='left') -> np.ndarray:
122 | calib = self.cam.get_camera_information().camera_configuration.calibration_parameters
123 | if cam=='left':
124 | intrinsics = calib.left_cam
125 | else:
126 | intrinsics = calib.right_cam
127 | r = self.width/self.init_res
128 | K = np.array([[intrinsics.fx*r, 0, intrinsics.cx*r], [0, intrinsics.fy*r, intrinsics.cy*r], [0, 0, 1]])
129 | return K
130 |
131 | def get_stereo_transform(self):
132 | transform = self.cam.get_camera_information().camera_configuration.calibration_parameters.stereo_transform.m
133 |         transform[:3,3] /= 1000  # convert to meters
134 | return transform
135 |
136 | def start_record(self, out_path):
137 | recordingParameters = sl.RecordingParameters()
138 | recordingParameters.compression_mode = sl.SVO_COMPRESSION_MODE.H264
139 | recordingParameters.video_filename = out_path
140 | err = self.cam.enable_recording(recordingParameters)
141 |
142 | def stop_record(self):
143 | self.cam.disable_recording()
144 |
145 | @staticmethod
146 | def plotly_render(frame) -> plotly.graph_objs.Figure:
147 | fig = px.imshow(frame)
148 | fig.update_layout(
149 | margin=dict(l=0, r=0, t=0, b=0),
150 | showlegend=False,
151 | yaxis_visible=False,
152 | yaxis_showticklabels=False,
153 | xaxis_visible=False,
154 | xaxis_showticklabels=False,
155 | )
156 | return fig
157 |
158 | @staticmethod
159 | def project_depth(
160 | rgb: torch.Tensor,
161 | depth: torch.Tensor,
162 | K: torch.Tensor,
163 | depth_threshold: float = 1.0,
164 | subsample: int = 4,
165 | ) -> Tuple[np.ndarray, np.ndarray]:
166 | """Deproject RGBD image to point cloud, using provided intrinsics.
167 | Also threshold/subsample pointcloud for visualization speed."""
168 |
169 | img_wh = rgb.shape[:2][::-1]
170 |
171 | grid = (
172 | torch.stack(
173 | torch.meshgrid(
174 | torch.arange(img_wh[0], device="cuda"),
175 | torch.arange(img_wh[1], device="cuda"),
176 | indexing="xy",
177 | ),
178 | 2,
179 | )
180 | + 0.5
181 | )
182 |
183 | homo_grid = torch.concat(
184 | [grid, torch.ones((grid.shape[0], grid.shape[1], 1), device="cuda")],
185 | dim=2
186 | ).reshape(-1, 3)
187 | local_dirs = torch.matmul(torch.linalg.inv(K),homo_grid.T).T
188 | points = (local_dirs * depth.reshape(-1,1)).float()
189 | points = points.reshape(-1,3)
190 |
191 | mask = depth.reshape(-1, 1) <= depth_threshold
192 | points = points.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy()
193 | colors = rgb.reshape(-1, 3)[mask.flatten()][::subsample].cpu().numpy()
194 |
195 | return (points, colors)
196 |
197 |
198 | import tyro
199 | def main(name: str) -> None:
200 | # def main() -> None:
201 |
202 | import torch
203 | from viser import ViserServer
204 | # zed = Zed(recording_file="exps/eyeglasses/2024-06-06_014947/traj.svo2")
205 | #jzed = Zed(recording_file="test.svo2")
206 | # import pdb; pdb.set_trace()
207 | zed = Zed(recording_file = "exps/scissors/2024-06-06_155342/traj.svo2")
208 | # zed.start_record(f"/home/chungmin/Documents/please2/toad/motion_vids/{name}.svo2")
209 | import os
210 | # os.makedirs(out_dir,exist_ok=True)
211 | i = 0
212 |
213 |
214 |     # Code for visualizing the point cloud
215 | import viser
216 | from matplotlib import pyplot as plt
217 | import viser.transforms as tf
218 | v = ViserServer()
219 | gui_reset_up = v.add_gui_button(
220 | "Reset up direction",
221 | hint="Set the camera control 'up' direction to the current camera's 'up'.",
222 | )
223 |
224 | @gui_reset_up.on_click
225 | def _(event: viser.GuiEvent) -> None:
226 | client = event.client
227 | assert client is not None
228 | client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
229 | [0.0, -1.0, 0.0]
230 | )
231 | while True:
232 | left,right,depth = zed.get_frame()
233 | left = left.cpu().numpy()
234 | depth = depth.cpu().numpy()
235 | # import matplotlib.pyplot as plt
236 | # plt.imshow(left)
237 | # plt.show()
238 | K = zed.get_K()
239 | T_world_camera = np.eye(4)
240 |
241 | img_wh = left.shape[:2][::-1]
242 |
243 | grid = (
244 | np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5
245 | )
246 |
247 | homo_grid = np.concatenate([grid,np.ones((grid.shape[0],grid.shape[1],1))],axis=2).reshape(-1,3)
248 | local_dirs = np.matmul(np.linalg.inv(K),homo_grid.T).T
249 | points = (local_dirs * depth.reshape(-1,1)).astype(np.float32)
250 | points = points.reshape(-1,3)
251 | v.add_point_cloud("points", points = points.reshape(-1,3), colors=left.reshape(-1,3),point_size=.001)
252 |
253 | if __name__ == "__main__":
254 | tyro.cli(main)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "pogs"
3 | version = "0.1.1"
4 | requires-python = ">=3.10"
5 |
6 | dependencies=[
7 | "Pillow",
8 | "jaxtyping",
9 | "rich",
10 | "open-clip-torch",
11 | "numpy==1.26.4",
12 | "torchtyping",
13 | "autolab_core",
14 | "moviepy",
15 | "kornia",
16 | "iopath",
17 | "transformers==4.44.0",
18 | "typeguard>=4.0.0",
19 | "awscli"
20 | ]
21 | [tool.setuptools]
22 | include-package-data = true
23 |
24 | [tool.setuptools.packages.find]
25 | include = ["pogs*"]
26 |
27 | [project.entry-points.'nerfstudio.method_configs']
28 | pogs = 'pogs.pogs_config:pogs_method'
--------------------------------------------------------------------------------
/scripts/dino_pca_visualization.py:
--------------------------------------------------------------------------------
1 | """ Usage example:
2 | python dino_pca_visualization.py --image_path shelf_iron.png
3 | """
4 |
5 | import PIL.Image as Image
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from sklearn.decomposition import PCA
9 | from pogs.data.utils.dino_dataloader import DinoDataloader, get_img_resolution
10 | import tyro
11 | import rich
12 | import torch
13 | from pogs.data.utils.dino_extractor import ViTExtractor
14 | from pogs.data.utils.feature_dataloader import FeatureDataloader
15 | from tqdm import tqdm
16 | from torchvision import transforms
17 | from typing import Tuple
18 |
19 | def main(
20 | image_path: str = "shelf_iron.png",
21 | dino_model_type: str = "dinov2_vitl14",
22 | dino_stride: int = 14,
23 | device: str = "cuda",
24 | keep_cuda: bool = True,
25 | ):
26 | extractor = ViTExtractor(dino_model_type, dino_stride)
27 | image = Image.open(image_path)
28 | image = np.array(image)
29 | if image.dtype == np.uint8:
30 | image = np.array(image) / 255.0
31 |
32 | if image.dtype == np.float64:
33 | image = image.astype(np.float32)
34 | h, w = get_img_resolution(image.shape[0], image.shape[1])
35 | preprocess = transforms.Compose([
36 | transforms.Resize((h,w),antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
37 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
38 | ])
39 |
40 | preproc_image = preprocess(torch.from_numpy(image).permute(2,0,1).unsqueeze(0)).to(device)
41 | dino_embeds = []
42 | for image in tqdm(preproc_image, desc="dino", total=1, leave=False):
43 | with torch.no_grad():
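            # Grab the last intermediate layer as an (H/14) x (W/14) x D patch-feature map;
            # the /10 scaling appears to just keep the feature magnitudes small.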
44 | descriptors = extractor.model.get_intermediate_layers(image.unsqueeze(0),reshape=True)[0].squeeze().permute(1,2,0)/10
45 | if keep_cuda:
46 | dino_embeds.append(descriptors)
47 | else:
48 | dino_embeds.append(descriptors.cpu().detach())
49 | out = dino_embeds[0]
50 |
51 | patch_h = out.shape[0]
52 | patch_w = out.shape[1]
53 | total_features = out.squeeze(0).squeeze(0).reshape(-1, out.shape[-1])
54 | pca = PCA(n_components=3)
55 | pca.fit(total_features.cpu().numpy())
56 | pca_features = pca.transform(total_features.cpu().numpy())
57 |
58 | # visualize PCA components for finding a proper threshold
59 | # 3 histograms for 3 components
60 | plt.subplot(2, 2, 1)
61 | plt.hist(pca_features[:, 0])
62 | plt.subplot(2, 2, 2)
63 | plt.hist(pca_features[:, 1])
64 | plt.subplot(2, 2, 3)
65 | plt.hist(pca_features[:, 2])
66 | plt.savefig(f"{image_path}_pca_hist.png")
67 |
68 | plt.clf()
69 |
70 | # Visualize PCA components
71 | pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / \
72 | (pca_features[:, 0].max() - pca_features[:, 0].min())
73 |
74 | plt.imshow(pca_features[0 : patch_h*patch_w, 0].reshape(patch_h, patch_w), cmap='gist_rainbow')
75 | plt.axis('off')
76 |
77 | plt.savefig(f"{image_path}_pca.png", bbox_inches='tight', pad_inches = 0)
78 | plt.margins(0,0)
79 | plt.show()
80 |
81 | if __name__ == "__main__":
82 | tyro.cli(main)
--------------------------------------------------------------------------------
/scripts/shelf_iron.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png
--------------------------------------------------------------------------------
/scripts/shelf_iron.png_pca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png_pca.png
--------------------------------------------------------------------------------
/scripts/shelf_iron.png_pca_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uynitsuj/POGS/3d29f04febd34be68f4eb35e8baf5252adda011e/scripts/shelf_iron.png_pca_hist.png
--------------------------------------------------------------------------------