├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── object_table.png └── system_diagram.png ├── poetry.lock ├── pyproject.toml ├── run_asset_generation.py ├── run_data_collection.py ├── scalable_real2sim ├── __init__.py ├── data_processing │ ├── __init__.py │ ├── alpha_channel.py │ ├── colmap.py │ ├── frosting.py │ ├── image_subsampling.py │ ├── masks.py │ ├── nerfstudio.py │ └── neuralangelo.py ├── output │ ├── __init__.py │ ├── canonicalize.py │ └── sdformat.py └── segmentation │ ├── __init__.py │ ├── detect_object.py │ └── segment_moving_object_data.py └── scripts ├── compute_geometric_metrics.py ├── segment_moving_obj_data.py └── subtract_masks.py /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .venv_nerfstudio 3 | data 4 | checkpoints 5 | __pycache__ 6 | outputs 7 | .vscode 8 | tests -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "scalable_real2sim/robot_payload_id"] 2 | path = scalable_real2sim/robot_payload_id 3 | url = git@github.com:nepfaff/robot_payload_id.git 4 | [submodule "scalable_real2sim/neuralangelo"] 5 | path = scalable_real2sim/neuralangelo 6 | url = git@github.com:evelyn-fu/neuralangelo.git 7 | branch = object_masking 8 | [submodule "scalable_real2sim/Frosting"] 9 | path = scalable_real2sim/Frosting 10 | url = git@github.com:nepfaff/Frosting.git 11 | [submodule "scalable_real2sim/BundleSDF"] 12 | path = scalable_real2sim/BundleSDF 13 | url = git@github.com:nepfaff/BundleSDF.git 14 | [submodule "scalable_real2sim/pickplace_data_collection"] 15 | path = scalable_real2sim/pickplace_data_collection 16 | url = git@github.com:evelyn-fu/pickplace_data_collection.git 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.9.1 4 | hooks: 5 | - id: black 6 | language_version: python3.10 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | name: isort (python) 13 | 14 | - repo: https://github.com/floatingpurr/sync_with_poetry 15 | rev: 1.1.0 16 | hooks: 17 | - id: sync_with_poetry 18 | args: [ 19 | '--all', # Scan all dependencies in poetry.lock (main and dev) 20 | ] 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Nicholas Pfaff, Evelyn Fu, Jeremy Binagia, Phillip Isola, Russ Tedrake 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scalable Real2Sim: Physics-Aware Asset Generation Via Robotic Pick-and-Place Setups 2 | Code for "Scalable Real2Sim: Physics-Aware Asset Generation Via Robotic Pick-and-Place Setups" 3 | 4 | See our [project website](https://scalable-real2sim.github.io/) for more details. The paper is available on [arXiv](https://arxiv.org/abs/2503.00370). 5 | 6 | An overview of our pipeline: 7 | ![pipeline](assets/system_diagram.png) 8 | 9 | Some of our reconstructed assets: 10 | ![result_teaser](assets/object_table.png) 11 | 12 | ## Installation 13 | 14 | This repo uses Poetry for dependency management. To setup this project, first install 15 | [Poetry](https://python-poetry.org/docs/#installation) and, make sure to have Python3.10 16 | installed on your system. 17 | 18 | Then, configure poetry to setup a virtual environment within the project directory: 19 | ```bash 20 | poetry config virtualenvs.in-project true 21 | ``` 22 | 23 | Fetch the submodules: 24 | ```bash 25 | git submodule init && git submodule update 26 | ``` 27 | 28 | Next, install all the required dependencies to the virtual environment with the 29 | following command: 30 | ```bash 31 | poetry install 32 | ``` 33 | 34 | Install BundleSDF (this may take a while & will create a .venv in the BundleSDF directory): 35 | ```bash 36 | cd scalable_real2sim/BundleSDF/ && bash setup.bash 37 | ``` 38 | 39 | Download pretrained weights of LoFTR outdoor_ds.ckpt, and put it under 40 | `scalable_real2sim/BundleSDF/BundleTrack/LoFTR/weights/outdoor_ds.ckpt`. 41 | 42 | Install Colmap. See [here](https://colmap.github.io/install.html) for instructions. 43 | 44 | Create a separate environment for Nerfstudio and install it: 45 | ```bash 46 | python -m venv .venv_nerfstudio && \ 47 | source .venv_nerfstudio/bin/activate && \ 48 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 && \ 49 | pip install pip==23.0.1 && \ 50 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch && \ 51 | pip install git+https://github.com/nerfstudio-project/nerfstudio.git 52 | ``` 53 | Note that the old pip version is required for `tiny-cuda-nn` to work. 54 | 55 | Install Frosting (this will create a .venv in the Frosting directory): 56 | ```bash 57 | cd scalable_real2sim/Frosting/ && bash setup.bash 58 | ``` 59 | 60 | Install Neuralangelo (this will create a .venv in the Neuralangelo directory): 61 | ```bash 62 | cd scalable_real2sim/neuralangelo/ && bash setup.bash \ 63 | && git submodule init && git submodule update 64 | ``` 65 | 66 | ## Code Organization 67 | 68 | This repo is more of a collection of submodules. Yet, there are two main entrypoints, 69 | one for data collection and one for asset generation. The data collection one is 70 | robot setup dependent and might thus not run out of the box. 
The asset generation should 71 | be runnable as is, provided all dependencies are installed correctly. 72 | 73 | ### Pipeline Entrypoints 74 | 75 | The real2sim pipeline is split into two main scripts: 76 | - `run_data_collection.py` 77 | - `run_asset_generation.py` 78 | 79 | #### run_data_collection.py 80 | 81 | Script for collecting all the data needed for asset generation. Runs in a loop until 82 | the first bin is empty. 83 | 84 | #### run_asset_generation.py 85 | 86 | Script for generating assets from the data collected by `run_data_collection.py`. 87 | 88 | Having this separate from the data collection script allows the robot to collect data 89 | non-stop without having to pause for computationally intensive data processing. 90 | The asset generation could then, for example, happen in the cloud. 91 | 92 | ### Scripts 93 | 94 | Contains additional scripts that might be useful. For example, this includes scripts 95 | for computing evaluation metrics. 96 | 97 | #### compute_geometric_metrics.py 98 | 99 | Script for computing geometric error metrics between a GT and reconstructed visual 100 | mesh. 101 | 102 | #### segment_moving_obj_data.py 103 | 104 | Our segmentation pipeline for obtaining object and gripper masks. You might want to 105 | do human-in-the-loop segmentation by annotating specific frames with positive/ negative 106 | labels for more robust results. We provide a simple GUI for this purpose. The default 107 | automatic annotations using DINO work well in many cases but can struggle with the 108 | gripper masks. All downstream object tracking and reconstruction results are sensitive 109 | to the segmentation quality and thus spending a bit of effort here might be worthwhile. 110 | 111 | ### Submodules 112 | 113 | #### robot_payload_id 114 | 115 | Contains all our system identification code: 116 | - Optimal excitation trajectory design 117 | - Robot identification 118 | - Payload identification 119 | 120 | #### BundleSDF 121 | 122 | A SOTA object tracking method. 123 | Please see the [BundleSDF GitHub]([https://github.com/zju3dv/BundleSDF](https://github.com/NVlabs/BundleSDF)). 124 | 125 | Our fork improves the geometric reconstruction and texture mapping quality. 126 | 127 | #### nerfstudio 128 | 129 | A collection of SOTA NeRF methods. 130 | Please see the [nerfstudio GitHub](https://github.com/nerfstudio-project/nerfstudio). 131 | 132 | All our alpha-transparent training code is already merged into the main branch. 133 | 134 | #### neuralangelo 135 | 136 | A SOTA neural surface reconstruction method. 137 | Please see the [neuralangelo GitHub](https://github.com/NVlabs/neuralangelo). 138 | 139 | Our fork adds masking and alpha-transparent training. 140 | 141 | #### Frosting 142 | 143 | A SOTA surface reconstruction method using Gaussian Splatting. 144 | Please see the [Frosting GitHub](https://github.com/Anttwo/Frosting). 145 | 146 | Our fork adds masking, alpha-transparent training, and depth supervision. 147 | 148 | ## Running the pipeline 149 | 150 | ### 1. Generate an optimal excitation trajectory 151 | 152 | This trajectory needs to be generated once per environment as it considers the robot's 153 | position/ velocity/ acceleration limits, enforces non-penetration constraints with 154 | the environment, and enforces self-collision avoidance. 155 | 156 | The trajectory optimization can be run with 157 | `scalable_real2sim/robot_payload_id/scripts/design_optimal_excitation_trajectories.py`. 
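For reference, a minimal invocation from the repository root might look like the following (illustrative only; the exact arguments are not reproduced here):
```bash
# Assumes the Poetry environment from the installation section is active; list the
# script's options with --help before picking parameters.
poetry run python scalable_real2sim/robot_payload_id/scripts/design_optimal_excitation_trajectories.py --help
```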
158 | See [here](https://github.com/nepfaff/robot_payload_id?tab=readme-ov-file#optimal-experiment-design) 159 | for more details and recommended script parameters. 160 | 161 | ### 2. Run data collection 162 | 163 | The initial robot data can be collected with `scalable_real2sim/pickplace_data_collection/scripts/collect_robot_joint_data_at_multiple_gripper_oppenings.py`. This 164 | needs to be done once per environment. Note that this requires a MOSEK license. 165 | 166 | The object data collection can be run with `scalable_real2sim/run_data_collection.py`. 167 | 168 | Note that this code is written for a particular robot setup, and small adjustments will likely need to be made to work for a different setup. Also note that you may encounter a segfault in one of the dependencies if your numpy version is >= 2.0.0. 169 | 170 | ### 3. Run robot identification 171 | 172 | The robot identification can be run with 173 | `scalable_real2sim/robot_payload_id/scripts/identify_robot_at_multiple_gripper_oppenings.py`. 174 | The only required argument is `--joint_data_path` which points to the robot data 175 | collected in step 2. 176 | Note that this needs to be done once per environment for the robot data from step 2. 177 | 178 | ### 4. Run asset generation 179 | 180 | The asset generation can be run with `scalable_real2sim/run_asset_generation.py`. 181 | 182 | ## Figures 183 | 184 | The paper robot figures were created with the help of 185 | [drake-blender-recorder](https://github.com/nepfaff/drake-blender-recorder). 186 | 187 | ## Citation 188 | 189 | If you find this work useful, please cite our paper: 190 | ```bibtex 191 | @article{pfaff2025_scalable_real2sim, 192 | author = {Pfaff, Nicholas and Fu, Evelyn and Binagia, Jeremy and Isola, Phillip and Tedrake, Russ}, 193 | title = {Scalable Real2Sim: Physics-Aware Asset Generation Via Robotic Pick-and-Place Setups}, 194 | year = {2025}, 195 | eprint = {2503.00370}, 196 | archivePrefix = {arXiv}, 197 | primaryClass = {cs.RO}, 198 | url = {https://arxiv.org/abs/2503.00370}, 199 | } 200 | ``` 201 | -------------------------------------------------------------------------------- /assets/object_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/assets/object_table.png -------------------------------------------------------------------------------- /assets/system_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/assets/system_diagram.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "scalable_real2sim" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Nicholas Pfaff "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10,<3.11" 10 | drake = { version = ">=0.0.20250204 <1.0", source = "drake-nightly" } 11 | numpy = "^2.2.0" 12 | scipy = "^1.15.1" 13 | manipulation = "^2025.2.16" 14 | open3d = "^0.19.0" 15 | sam2 = "^0.4.1" 16 | transformers = {git = "https://github.com/huggingface/transformers.git", rev = "9c02cb6233eddedd8ecf0d48957cb481103f93f3"} 17 | opencv-python = "^4.10.0.84" 18 | accelerate = "^1.2.1" 19 | torch = "2.3.1" 20 | clip = "^0.2.0" 21 | timm = "^1.0.14" 22 | 
coacd = "^1.0.5" 23 | vhacdx = "^0.0.8.post2" 24 | robot-payload-id = {path = "scalable_real2sim/robot_payload_id"} 25 | torchvision = "0.18.1" 26 | trimesh = "^4.6.2" 27 | 28 | [tool.poetry.group.dev.dependencies] 29 | pre-commit = "^3.4.0" 30 | black = "23.9.1" 31 | isort = "5.12.0" 32 | poetry-pre-commit-plugin = "^0.1.2" 33 | 34 | [[tool.poetry.source]] 35 | name = "drake-nightly" 36 | url = "https://drake-packages.csail.mit.edu/whl/nightly/" 37 | priority = "explicit" 38 | 39 | [tool.isort] 40 | profile = 'black' 41 | lines_between_types = 1 42 | combine_as_imports = true 43 | known_first_party = ['dynamic_mesh_distance'] 44 | 45 | [tool.autoflake] 46 | in-place = true 47 | recursive = true 48 | expand-star-imports = true 49 | ignore-init-module-imports = true 50 | remove-all-unused-imports = true 51 | remove-duplicate-keys = true 52 | remove-unused-variables = true 53 | 54 | [build-system] 55 | requires = ["poetry-core"] 56 | build-backend = "poetry.core.masonry.api" 57 | -------------------------------------------------------------------------------- /run_asset_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for generating assets from the data collected by `run_data_collection.py`. 3 | """ 4 | 5 | import argparse 6 | import json 7 | import logging 8 | import math 9 | import os 10 | import shutil 11 | import subprocess 12 | import time 13 | 14 | from datetime import timedelta 15 | 16 | import numpy as np 17 | import torch 18 | 19 | from tqdm import tqdm 20 | 21 | from scalable_real2sim.data_processing.frosting import process_moving_obj_data_for_sugar 22 | from scalable_real2sim.data_processing.image_subsampling import ( 23 | select_and_copy_dissimilar_images, 24 | ) 25 | from scalable_real2sim.data_processing.nerfstudio import preprocess_data_for_nerfstudio 26 | from scalable_real2sim.data_processing.neuralangelo import ( 27 | process_moving_obj_data_for_neuralangelo, 28 | ) 29 | from scalable_real2sim.output.canonicalize import canonicalize_mesh_from_file 30 | from scalable_real2sim.output.sdformat import create_sdf 31 | from scalable_real2sim.robot_payload_id.scripts.identify_grasped_object_payload import ( 32 | identify_grasped_object_payload, 33 | ) 34 | from scalable_real2sim.segmentation.detect_object import detect_object 35 | from scalable_real2sim.segmentation.segment_moving_object_data import ( 36 | segment_moving_obj_data, 37 | ) 38 | 39 | 40 | def downsample_images(data_dir: str, num_images: int) -> None: 41 | start = time.perf_counter() 42 | 43 | # Check if we need to subsample images. 44 | rgb_dir = os.path.join(data_dir, "rgb") 45 | rgb_files = sorted(os.listdir(rgb_dir)) 46 | if len(rgb_files) > num_images: 47 | logging.info(f"Found {len(rgb_files)} images, subsampling to {num_images}...") 48 | 49 | # Move original images to rgb_original. 50 | rgb_original_dir = os.path.join(data_dir, "rgb_original") 51 | os.makedirs(rgb_original_dir, exist_ok=True) 52 | for f in rgb_files: 53 | shutil.move(os.path.join(rgb_dir, f), os.path.join(rgb_original_dir, f)) 54 | 55 | num_uniform_frames = num_images // 2 # Big gaps cause tracking to fail 56 | max_frame_gap = math.ceil(len(rgb_files) / num_uniform_frames) 57 | logging.info( 58 | f"Using a maximum frame gap of {max_frame_gap} for image downsampling." 
59 | ) 60 | select_and_copy_dissimilar_images( 61 | image_dir=rgb_original_dir, 62 | output_dir=rgb_dir, 63 | K=num_images, 64 | N=max_frame_gap, 65 | model_name="dino", 66 | ) 67 | logging.info("Image subsampling complete.") 68 | else: 69 | logging.info("No image subsampling needed.") 70 | 71 | logging.info( 72 | f"Image downsampling took {timedelta(seconds=time.perf_counter() - start)}." 73 | ) 74 | 75 | 76 | def run_segmentation(data_dir: str, output_dir: str) -> None: 77 | start = time.perf_counter() 78 | 79 | # Detect the object of interest. Need to add a dot for the DINO model. 80 | object_of_interest = ( 81 | detect_object( 82 | image_path=os.path.join( 83 | data_dir, "rgb", sorted(os.listdir(os.path.join(data_dir, "rgb")))[0] 84 | ) 85 | ) 86 | + "." 87 | ) 88 | torch.cuda.empty_cache() 89 | logging.info(f"Detected object of interest: {object_of_interest}") 90 | 91 | gripper_txt = ( 92 | "Blue plastic robotic gripper with two symmetrical, curved arms " 93 | "attached to the end of a metallic robotic arm." 94 | ) 95 | 96 | # Generate the object masks. 97 | segment_moving_obj_data( 98 | rgb_dir=os.path.join(data_dir, "rgb"), 99 | output_dir=os.path.join(output_dir, "masks"), 100 | txt_prompt=object_of_interest, 101 | txt_prompt_index=1, 102 | neg_txt_prompt=gripper_txt, 103 | ) 104 | torch.cuda.empty_cache() 105 | 106 | # Generate the gripper masks. 107 | segment_moving_obj_data( 108 | rgb_dir=os.path.join(data_dir, "rgb"), 109 | output_dir=os.path.join(output_dir, "gripper_masks"), 110 | txt_prompt=gripper_txt, 111 | txt_prompt_index=1, 112 | neg_txt_prompt=object_of_interest, 113 | ) 114 | torch.cuda.empty_cache() 115 | 116 | logging.info(f"Segmentation took {timedelta(seconds=time.perf_counter() - start)}.") 117 | 118 | 119 | def run_bundle_sdf_tracking_and_reconstruction( 120 | data_dir: str, output_dir: str, interpolate_missing_vertices: bool = False 121 | ) -> None: 122 | start = time.perf_counter() 123 | 124 | # Run tracking and reconstruction. We run it in a subprocess to be able to use 125 | # independent dependencies for BundleSDF. 126 | current_dir = os.path.dirname(os.path.abspath(__file__)) 127 | bundle_sdf_dir = os.path.join(current_dir, "scalable_real2sim", "BundleSDF") 128 | data_dir = os.path.abspath(data_dir) 129 | output_dir = os.path.abspath(output_dir) 130 | os.chdir(bundle_sdf_dir) 131 | torch.cuda.empty_cache() 132 | subprocess.run( 133 | "source .venv/bin/activate && " 134 | f"python run_custom.py --video_dir {data_dir} " 135 | f"--out_folder {output_dir} --use_gui 1 " 136 | f"--interpolate_missing_vertices {int(interpolate_missing_vertices)}", 137 | cwd=bundle_sdf_dir, 138 | shell=True, 139 | executable="/bin/bash", 140 | ) 141 | os.chdir(current_dir) 142 | 143 | # Extract data. 
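    # BundleSDF writes its results into `output_dir`: per-frame object poses in
    # `ob_in_cam/`, a cleaned geometry-only mesh (`mesh_cleaned.obj`), and a textured
    # mesh (`textured_mesh.obj` with `material.mtl` and `material_0.png`). The moves
    # below relocate the poses into the data directory and gather the meshes under the
    # parent of `output_dir`.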
144 | output_dir_parent = os.path.dirname(output_dir) 145 | shutil.move( 146 | os.path.join(output_dir, "ob_in_cam"), 147 | os.path.join(data_dir, "ob_in_cam"), 148 | ) 149 | shutil.move( 150 | os.path.join(output_dir, "mesh_cleaned.obj"), 151 | os.path.join(output_dir_parent, "bundle_sdf_mesh.obj"), 152 | ) 153 | textured_mesh_dir = os.path.join(output_dir_parent, "bundle_sdf_mesh") 154 | os.makedirs(textured_mesh_dir, exist_ok=True) 155 | shutil.move( 156 | os.path.join(output_dir, "textured_mesh.obj"), 157 | os.path.join(textured_mesh_dir, "textured_mesh.obj"), 158 | ) 159 | shutil.move( 160 | os.path.join(output_dir, "material.mtl"), 161 | os.path.join(textured_mesh_dir, "material.mtl"), 162 | ) 163 | shutil.move( 164 | os.path.join(output_dir, "material_0.png"), 165 | os.path.join(textured_mesh_dir, "material_0.png"), 166 | ) 167 | 168 | logging.info( 169 | "BundleSDF tracking and reconstruction took " 170 | f"{timedelta(seconds=time.perf_counter() - start)}." 171 | ) 172 | 173 | 174 | def run_nerfstudio(data_dir: str, output_dir: str, use_depth: bool = False) -> None: 175 | # Preprocess the data for Nerfstudio. 176 | start = time.perf_counter() 177 | preprocess_data_for_nerfstudio(data_dir, output_dir) 178 | logging.info( 179 | f"Nerfstudio preprocessing took {timedelta(seconds=time.perf_counter() - start)}." 180 | ) 181 | 182 | # Run Nerfstudio training. 183 | logging.info( 184 | "Started Nerfstudio training. Monitor progress at " 185 | "https://viewer.nerf.studio/versions/23-05-15-1/?websocket_url=ws://localhost:7007." 186 | ) 187 | start = time.perf_counter() 188 | torch.cuda.empty_cache() 189 | # NOTE: Use "nerfacto-big" or "nerfacto-huge" for better results but longer 190 | # training times. 191 | method = "depth-nerfacto" if use_depth else "nerfacto" 192 | train_process = subprocess.run( 193 | "source .venv_nerfstudio/bin/activate && " 194 | f"ns-train {method} --output-dir {output_dir} " 195 | # "--max-num-iterations 30000 " # Adjust if needed 196 | "--pipeline.model.background_color 'random' " 197 | f"--pipeline.model.disable-scene-contraction True --data {output_dir} " 198 | # Slower but enables larger datasets 199 | f"{'--pipeline.datamanager.load-from-disk True ' if not use_depth else ''}" 200 | "--vis viewer_legacy --viewer.quit-on-train-completion True", 201 | cwd=os.path.dirname(os.path.abspath(__file__)), 202 | shell=True, 203 | executable="/bin/bash", 204 | capture_output=True, 205 | text=True, 206 | ) 207 | logging.info( 208 | f"Nerfstudio training took {timedelta(seconds=time.perf_counter() - start)}." 209 | ) 210 | 211 | # Extract config file path from output. This is a bit complex because the file name 212 | # might be split across multiple lines. 213 | lines = train_process.stdout.split("\n") 214 | config_lines = [] 215 | capture = False 216 | for line in lines: 217 | if "Saving config to:" in line: 218 | capture = True # Start capturing from the next lines 219 | continue 220 | if capture: 221 | if "Saving checkpoints to:" in line: # Stop capturing 222 | break 223 | config_lines.append(line.strip()) 224 | # Join lines to reconstruct the path. 225 | config_path = "".join(config_lines).strip() 226 | if config_path and config_path.endswith("config.yml"): 227 | logging.info(f"Found config file: {config_path}") 228 | else: 229 | raise RuntimeError( 230 | "Could not find config file path in ns-train output. " 231 | f"Output: {train_process.stdout}" 232 | ) 233 | 234 | # Run mesh extraction. 
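    # Mesh extraction uses Nerfstudio's TSDF exporter: depth/color rendered from the
    # trained model are fused into a TSDF volume (250^3 here) restricted to a +/-0.5
    # axis-aligned bounding box around the origin, and the fused surface is simplified
    # to roughly the requested number of faces.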
235 | start = time.perf_counter() 236 | # NOTE: It is recommended to increase `--resolution` to the highest value that fits 237 | # into your GPU memory. 238 | mesh_output_dir = os.path.join(output_dir, "nerfstudio_mesh") 239 | subprocess.run( 240 | "source .venv_nerfstudio/bin/activate && " 241 | f"ns-export tsdf --load-config {config_path} " 242 | f"--output-dir {mesh_output_dir} --target-num-faces 100000 " 243 | "--downscale-factor 2 --num-pixels-per-side 2048 --resolution 250 250 250 " 244 | "--use-bounding-box True --bounding-box-min -0.5 -0.5 -0.5 " 245 | "--bounding-box-max 0.5 0.5 0.5 --refine-mesh-using-initial-aabb-estimate True", 246 | cwd=os.path.dirname(os.path.abspath(__file__)), 247 | shell=True, 248 | executable="/bin/bash", 249 | ) 250 | logging.info( 251 | f"Nerfstudio mesh extraction took {timedelta(seconds=time.perf_counter() - start)}." 252 | ) 253 | 254 | # Move the mesh to the output directory parent directory. 255 | shutil.move(mesh_output_dir, os.path.dirname(output_dir)) 256 | 257 | 258 | def run_frosting(data_dir: str, output_dir: str, use_depth: bool = False) -> None: 259 | # Preprocess the data for Frosting. 260 | logging.info("Preprocessing data for Frosting...") 261 | start = time.perf_counter() 262 | process_moving_obj_data_for_sugar( 263 | data_dir, output_dir, num_images=1800, use_depth=use_depth 264 | ) 265 | logging.info( 266 | f"Frosting preprocessing took {timedelta(seconds=time.perf_counter() - start)}." 267 | ) 268 | 269 | # Run reconstruction. We run it in a subprocess to be able to use independent 270 | # dependencies for Frosting. 271 | start = time.perf_counter() 272 | current_dir = os.path.dirname(os.path.abspath(__file__)) 273 | frosting_dir = os.path.join(current_dir, "scalable_real2sim", "Frosting") 274 | output_dir = os.path.abspath(output_dir) 275 | os.chdir(frosting_dir) 276 | torch.cuda.empty_cache() 277 | subprocess.run( 278 | "source .venv/bin/activate && " 279 | f"python train_full_pipeline.py -s {output_dir} -r 'dn_consistency' " 280 | "--high_poly True --export_obj True --white_background False " 281 | f"--masks {output_dir}/gripper_masks/ " 282 | + ("--depths depth/" if use_depth else ""), 283 | cwd=frosting_dir, 284 | shell=True, 285 | executable="/bin/bash", 286 | ) 287 | os.chdir(current_dir) 288 | 289 | # Move the mesh to the output directory parent directory. 290 | mesh_path = os.path.join( 291 | frosting_dir, "output", "refined_frosting_base_mesh", "frosting" 292 | ) 293 | mesh_out_dir = os.path.join(os.path.dirname(output_dir), "frosting_mesh") 294 | shutil.move(mesh_path, mesh_out_dir) 295 | 296 | logging.info( 297 | f"Frosting reconstruction took {timedelta(seconds=time.perf_counter() - start)}." 298 | ) 299 | 300 | 301 | def run_neuralangelo(data_dir: str, output_dir: str, use_depth: bool = False) -> None: 302 | # Process the data for Neuralangelo. 303 | logging.info("Processing data for Neuralangelo...") 304 | start = time.perf_counter() 305 | cfg_path = process_moving_obj_data_for_neuralangelo( 306 | data_dir, output_dir, num_images=1800 307 | ) 308 | logging.info( 309 | f"Neuralangelo preprocessing took {timedelta(seconds=time.perf_counter() - start)}." 310 | ) 311 | 312 | # Run Neuralangelo training. 
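    # Training runs in Neuralangelo's own virtual environment via a subprocess so that
    # its dependencies stay isolated from this project's environment. torchrun is
    # launched with a single process per node, and progress is logged to Weights &
    # Biases under the name passed via --wandb_name.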
313 | logging.info("Running Neuralangelo training...") 314 | start = time.perf_counter() 315 | current_dir = os.path.dirname(os.path.abspath(__file__)) 316 | neuralangelo_dir = os.path.join(current_dir, "scalable_real2sim", "neuralangelo") 317 | output_dir = os.path.abspath(output_dir) 318 | os.chdir(neuralangelo_dir) 319 | torch.cuda.empty_cache() 320 | subprocess.run( 321 | "source .venv/bin/activate && " 322 | f"torchrun --nproc_per_node=1 train.py --config {cfg_path} " 323 | f"--logdir {output_dir} --show_pbar " 324 | "--wandb --wandb_name nicholas_neuralangelo", 325 | cwd=neuralangelo_dir, 326 | shell=True, 327 | executable="/bin/bash", 328 | ) 329 | logging.info( 330 | f"Neuralangelo training took {timedelta(seconds=time.perf_counter() - start)}." 331 | ) 332 | 333 | # Obtain checkpoint path. 334 | checkpoint_version_txt = os.path.join(output_dir, "latest_checkpoint.txt") 335 | with open(checkpoint_version_txt, "r") as f: 336 | checkpoint_name = f.read().strip() 337 | checkpoint_path = os.path.join(output_dir, checkpoint_name) 338 | 339 | # Run Neuralangelo mesh extraction. 340 | logging.info("Running Neuralangelo mesh extraction...") 341 | start = time.perf_counter() 342 | output_mesh_dir = os.path.join(os.path.dirname(output_dir), "neuralangelo_mesh") 343 | # Lower resolution to reduce mesh size, lower block_res to reduce GPU memory usage. 344 | subprocess.run( 345 | "source .venv/bin/activate && " 346 | "torchrun --nproc_per_node=1 projects/neuralangelo/scripts/extract_mesh.py " 347 | f"--config {cfg_path} --checkpoint={checkpoint_path} --textured " 348 | f"--resolution=2048 --block_res=160 --output_file={output_mesh_dir}/mesh.obj", 349 | cwd=neuralangelo_dir, 350 | shell=True, 351 | executable="/bin/bash", 352 | ) 353 | logging.info( 354 | f"Neuralangelo mesh extraction took {timedelta(seconds=time.perf_counter() - start)}." 355 | ) 356 | os.chdir(current_dir) 357 | 358 | 359 | def replace_trimesh_mesh_material(mesh_path: str) -> None: 360 | """Replaces the trimesh material file values with a nicer looking one.""" 361 | material_path = os.path.join(os.path.dirname(mesh_path), "material.mtl") 362 | 363 | # Replace the material values with nicer looking ones. 364 | with open(material_path, "w") as mtl_file: 365 | mtl_file.write( 366 | "newmtl material_0\n" 367 | "Ns 50.000000\n" 368 | "Ka 1.000000 1.000000 1.000000\n" 369 | "Kd 1.0 1.0 1.0\n" 370 | "Ks 0.2 0.2 0.2\n" 371 | "Ke 0.000000 0.000000 0.000000\n" 372 | "Ni 1.500000\n" 373 | "d 1.000000\n" 374 | "illum 2\n" 375 | "map_Kd material_0.png\n" 376 | ) 377 | 378 | 379 | def main( 380 | data_dir: str, 381 | robot_id_dir: str, 382 | output_dir: str, 383 | skip_segmentation: bool = False, 384 | bundle_sdf_interpolate_missing_vertices: bool = False, 385 | use_depth: bool = False, 386 | ): 387 | logging.info("Starting asset generation...") 388 | 389 | # Create output dir. 390 | os.makedirs(output_dir, exist_ok=True) 391 | 392 | # Get all subdirectories in the data directory. 393 | object_dirs = [ 394 | os.path.join(data_dir, d) 395 | for d in os.listdir(data_dir) 396 | if os.path.isdir(os.path.join(data_dir, d)) 397 | ] 398 | 399 | for object_dir in tqdm(object_dirs): 400 | # Set up logging for this object. 401 | object_name = os.path.basename(object_dir) 402 | object_output_dir = os.path.join(output_dir, object_name) 403 | os.makedirs(object_output_dir, exist_ok=True) 404 | 405 | # Create a file handler for this object. 
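        # Logs for this object go to `<object_name>_processing.log` inside its output
        # directory. The handler is attached to the root logger here and removed again
        # in the `finally` block below, so every object gets a self-contained log file.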
406 | log_file = os.path.join(object_output_dir, f"{object_name}_processing.log") 407 | file_handler = logging.FileHandler(log_file) 408 | file_handler.setLevel(logging.INFO) 409 | file_handler.setFormatter( 410 | logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 411 | ) 412 | 413 | # Add the file handler to the root logger. 414 | logging.getLogger().addHandler(file_handler) 415 | 416 | logging.info(f"Processing object: {object_dir}") 417 | 418 | try: 419 | # Downsample images to not run out of memory. 420 | downsample_images(object_dir, num_images=1800) 421 | 422 | # Generate object and gripper masks. 423 | if not skip_segmentation: 424 | logging.info("Running segmentation...") 425 | run_segmentation(data_dir=object_dir, output_dir=object_dir) 426 | else: 427 | logging.info("Skipping segmentation...") 428 | if not os.path.exists(os.path.join(object_dir, "masks")): 429 | raise FileNotFoundError( 430 | "Object masks not found. Please run segmentation first." 431 | ) 432 | if not os.path.exists(os.path.join(object_dir, "gripper_masks")): 433 | raise FileNotFoundError( 434 | "Gripper masks not found. Please run segmentation first." 435 | ) 436 | 437 | # Object tracking + BundleSDF reconstruction. 438 | logging.info("Running BundleSDF tracking and reconstruction...") 439 | bundle_sdf_output_dir = os.path.join(object_output_dir, "bundle_sdf") 440 | os.makedirs(bundle_sdf_output_dir, exist_ok=True) 441 | run_bundle_sdf_tracking_and_reconstruction( 442 | data_dir=object_dir, 443 | output_dir=bundle_sdf_output_dir, 444 | interpolate_missing_vertices=bundle_sdf_interpolate_missing_vertices, 445 | ) 446 | 447 | # Canonicalize the BundleSDF mesh. 448 | bundle_sdf_mesh_path = os.path.join( 449 | object_output_dir, "bundle_sdf_mesh/textured_mesh.obj" 450 | ) 451 | canonicalize_mesh_from_file( 452 | mesh_path=bundle_sdf_mesh_path, 453 | output_path=bundle_sdf_mesh_path, 454 | ) 455 | replace_trimesh_mesh_material(bundle_sdf_mesh_path) 456 | logging.info(f"Canonicalized BundleSDF mesh: {bundle_sdf_mesh_path}") 457 | 458 | # Physical property estimation in the BundleSDF frame. 459 | bundle_sdf_inertia_params_path = os.path.join( 460 | object_dir, "bundle_sdf_inertial_params.json" 461 | ) 462 | identify_grasped_object_payload( 463 | robot_joint_data_path=robot_id_dir, 464 | object_joint_data_path=os.path.join(object_dir, "system_id_data"), 465 | object_mesh_path=bundle_sdf_mesh_path, 466 | json_output_path=bundle_sdf_inertia_params_path, 467 | ) 468 | with open(bundle_sdf_inertia_params_path, "r") as f: 469 | bundle_sdf_inertia_params = json.load(f) 470 | 471 | # Output the BundleSDF SDFormat file. 472 | bundle_sdf_sdf_output_dir = os.path.join( 473 | object_output_dir, f"{object_name}_bundle_sdf.sdf" 474 | ) 475 | create_sdf( 476 | model_name=object_name, 477 | mesh_parts_dir_name=f"{object_name}_bundle_sdf_parts", 478 | output_path=bundle_sdf_sdf_output_dir, 479 | visual_mesh_path=bundle_sdf_mesh_path, 480 | collision_mesh_path=bundle_sdf_mesh_path, 481 | mass=bundle_sdf_inertia_params["mass"], 482 | center_of_mass=np.array(bundle_sdf_inertia_params["center_of_mass"]), 483 | moment_of_inertia=np.array(bundle_sdf_inertia_params["inertia_matrix"]), 484 | use_hydroelastic=False, # Enable for more accurate but Drake-specific SDFormat 485 | use_coacd=True, 486 | ) 487 | 488 | # Nerfacto reconstruction + SDFormat output. 
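            # The remaining reconstruction methods (Nerfacto via Nerfstudio, Frosting,
            # and Neuralangelo) are run on the same downsampled and segmented data;
            # each writes its results into its own subdirectory of the object's output
            # directory.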
489 | nerfstudio_output_dir = os.path.join(object_output_dir, "nerfstudio") 490 | run_nerfstudio( 491 | data_dir=object_dir, 492 | output_dir=nerfstudio_output_dir, 493 | use_depth=use_depth, 494 | ) 495 | 496 | # Frosting reconstruction + SDFormat output. 497 | frosting_output_dir = os.path.join(object_output_dir, "frosting") 498 | run_frosting( 499 | data_dir=object_dir, 500 | output_dir=frosting_output_dir, 501 | use_depth=use_depth, 502 | ) 503 | 504 | # Neuralangelo reconstruction + SDFormat output. 505 | neuralangelo_output_dir = os.path.join(object_output_dir, "neuralangelo") 506 | run_neuralangelo(data_dir=object_dir, output_dir=neuralangelo_output_dir) 507 | 508 | # Canonicalize the Neuralangelo mesh. 509 | neuralangelo_mesh_path = os.path.join( 510 | object_output_dir, "neuralangelo_mesh/mesh.obj" 511 | ) 512 | canonicalize_mesh_from_file( 513 | mesh_path=neuralangelo_mesh_path, 514 | output_path=neuralangelo_mesh_path, 515 | ) 516 | logging.info(f"Canonicalized Neuralangelo mesh: {neuralangelo_mesh_path}") 517 | 518 | # Physical property estimation in the Neuralangelo frame. 519 | neuralangelo_inertia_params_path = os.path.join( 520 | object_dir, "neuralangelo_inertial_params.json" 521 | ) 522 | identify_grasped_object_payload( 523 | robot_joint_data_path=robot_id_dir, 524 | object_joint_data_path=os.path.join(object_dir, "system_id_data"), 525 | object_mesh_path=neuralangelo_mesh_path, 526 | json_output_path=neuralangelo_inertia_params_path, 527 | ) 528 | with open(neuralangelo_inertia_params_path, "r") as f: 529 | neuralangelo_inertia_params = json.load(f) 530 | 531 | # Output the Neuralangelo SDFormat file. 532 | neuralangelo_sdf_output_dir = os.path.join( 533 | object_output_dir, f"{object_name}_neuralangelo.sdf" 534 | ) 535 | create_sdf( 536 | model_name=object_name, 537 | mesh_parts_dir_name=f"{object_name}_neuralangelo_parts", 538 | output_path=neuralangelo_sdf_output_dir, 539 | visual_mesh_path=neuralangelo_mesh_path, 540 | collision_mesh_path=neuralangelo_mesh_path, 541 | mass=neuralangelo_inertia_params["mass"], 542 | center_of_mass=np.array(neuralangelo_inertia_params["center_of_mass"]), 543 | moment_of_inertia=np.array( 544 | neuralangelo_inertia_params["inertia_matrix"] 545 | ), 546 | use_hydroelastic=False, # Enable for more accurate but Drake-specific SDFormat 547 | use_coacd=True, 548 | ) 549 | 550 | logging.info(f"Finished processing object: {object_dir}") 551 | torch.cuda.empty_cache() 552 | 553 | finally: 554 | # Remove the file handler after processing this object. 555 | logging.getLogger().removeHandler(file_handler) 556 | file_handler.close() 557 | 558 | 559 | if __name__ == "__main__": 560 | logging.basicConfig(level=logging.INFO) 561 | 562 | parser = argparse.ArgumentParser( 563 | description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter 564 | ) 565 | parser.add_argument( 566 | "--data-dir", 567 | type=str, 568 | required=True, 569 | help="Path to the directory where the collected data is saved. This should be " 570 | "the top-level directory that contains all the object subdirectories.", 571 | ) 572 | parser.add_argument( 573 | "--robot-id-dir", 574 | type=str, 575 | required=True, 576 | help="Path to the directory where the robot ID data is saved. This should be " 577 | "the top-level directory that contains all different gripper opening " 578 | "subdirectories. 
NOTE that the robot parameters should have already been " 579 | "identified.", 580 | ) 581 | parser.add_argument( 582 | "--output-dir", 583 | type=str, 584 | required=True, 585 | help="Path to the directory where the generated assets will be saved. A new " 586 | "subdirectory will be created for each object with the same name as the " 587 | "subdirectory in the data directory.", 588 | ) 589 | parser.add_argument( 590 | "--skip-segmentation", 591 | action="store_true", 592 | help="If specified, skip the segmentation step. This requires segmentation " 593 | "data to be already present in the data directory. It might be useful if you " 594 | "want to run segmentation using the `segment_moving_obj_data.py` script and " 595 | "manually specified positive/ negative annotations which is significantly " 596 | "more robust than automatic annotations from DINO.", 597 | ) 598 | parser.add_argument( 599 | "--bundle-sdf-interpolate-missing-vertices", 600 | action="store_true", 601 | help="If specified, interpolate missing vertices in the BundleSDF texture map. " 602 | "This results in higher quality textures but is extremely slow.", 603 | ) 604 | parser.add_argument( 605 | "--use-depth", 606 | action="store_true", 607 | help="If specified, use depth images for geometric reconstruction when " 608 | "supported by the reconstruction method.", 609 | ) 610 | args = parser.parse_args() 611 | 612 | if not os.path.exists(args.data_dir): 613 | raise FileNotFoundError(f"Data directory {args.data_dir} does not exist.") 614 | 615 | if not os.path.exists(args.output_dir): 616 | os.makedirs(args.output_dir) 617 | 618 | main( 619 | data_dir=args.data_dir, 620 | robot_id_dir=args.robot_id_dir, 621 | output_dir=args.output_dir, 622 | skip_segmentation=args.skip_segmentation, 623 | bundle_sdf_interpolate_missing_vertices=args.bundle_sdf_interpolate_missing_vertices, 624 | use_depth=args.use_depth, 625 | ) 626 | -------------------------------------------------------------------------------- /run_data_collection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for collecting all the data needed for asset generation. Runs in a loop until 3 | the first bin is empty. 
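At a high level, the script builds a Drake diagram around MakeHardwareStation, adds a
grasp-and-display planner (two-grasp or turntable), an ImageSaver, and a
SystemIDDataSaver, and then simulates (or drives the real hardware) to record RGB-D
images and the joint data needed for payload identification for each grasped object.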
4 | """ 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import copy 10 | import pickle 11 | import datetime 12 | import shutil 13 | from scipy.spatial.transform import Rotation as R 14 | from pydrake.geometry import ( 15 | StartMeshcat, 16 | RenderLabel, 17 | Role, 18 | ) 19 | from pydrake.all import ConstantVectorSource, VectorLogSink 20 | from pydrake.systems.analysis import Simulator 21 | from pydrake.systems.framework import DiagramBuilder 22 | from pydrake.systems.sensors import CameraInfo 23 | from pydrake.systems.primitives import ( 24 | PortSwitch, 25 | Multiplexer, 26 | Demultiplexer 27 | ) 28 | from pydrake.perception import ( 29 | DepthImageToPointCloud 30 | ) 31 | 32 | # from manipulation.systems import AddIiwaDifferentialIK 33 | from manipulation.systems import ExtractPose 34 | from manipulation.station import MakeHardwareStation, LoadScenario 35 | 36 | import pydrake.planning as mut 37 | from pydrake.common import RandomGenerator, Parallelism, use_native_cpp_logging 38 | from pydrake.planning import (RobotDiagramBuilder, 39 | SceneGraphCollisionChecker, 40 | CollisionCheckerParams) 41 | from pydrake.math import ( 42 | RigidTransform, 43 | RotationMatrix, 44 | RollPitchYaw, 45 | ) 46 | from pathlib import Path 47 | from planning.two_grasp_display_planner import PlannerState, PickState 48 | from pydrake.solvers import MosekSolver, GurobiSolver 49 | from pydrake.all import LeafSystem, Value, Context, InputPort 50 | from planning.two_grasp_display_planner import TwoGraspPlanner 51 | from planning.turntable_planner import TurntablePlanner 52 | from perception.image_saver import ImageSaver 53 | from perception.camera_in_world import CameraPoseInWorldSource 54 | from planning.trajectory_sources import TrajectoryWithTimingInformationSource, DummyTrajSource 55 | from planning.diffik import AddIiwaDifferentialIK 56 | # from iiwa import IiwaHardwareStationDiagram 57 | 58 | def get_regions_static(scenario_path, dirstr): 59 | print("generating static regions") 60 | use_native_cpp_logging() 61 | params = dict(edge_step_size=0.125) 62 | builder = RobotDiagramBuilder() 63 | builder.parser().AddModels(scenario_path) 64 | iiwa_model_instance_index = builder.plant().GetModelInstanceByName("iiwa") 65 | wsg_model_instance_index = builder.plant().GetModelInstanceByName("wsg") 66 | params["robot_model_instances"] = [iiwa_model_instance_index, wsg_model_instance_index] 67 | params["model"] = builder.Build() 68 | checker = SceneGraphCollisionChecker(**params) 69 | 70 | options = mut.IrisFromCliqueCoverOptions() 71 | options.num_points_per_coverage_check = 5000 72 | options.num_points_per_visibility_round = 1000 73 | options.minimum_clique_size = 16 74 | options.coverage_termination_threshold = 0.7 75 | 76 | generator = RandomGenerator(0) 77 | 78 | if (MosekSolver().available() and MosekSolver().enabled()) or ( 79 | GurobiSolver().available() and GurobiSolver().enabled()): 80 | # We need a MIP solver to be available to run this method. 
81 | sets = mut.IrisInConfigurationSpaceFromCliqueCover( 82 | checker=checker, options=options, generator=generator, 83 | sets=[] 84 | ) 85 | 86 | if len(sets) < 1: 87 | raise("No regions found") 88 | 89 | time_str = datetime.datetime.now().strftime('%d%m%y_%H%M%S') 90 | with open(dirstr+f'/{scenario_path.split("/")[-1]}_{time_str}_regions.pkl', 'wb') as f: 91 | pickle.dump(sets, f) 92 | 93 | return sets 94 | else: 95 | print("No solvers available") 96 | 97 | class SystemIDDataSaver(LeafSystem): 98 | def __init__( 99 | self, output_dir: str 100 | ): 101 | super().__init__() 102 | self.output_dir = Path(output_dir) 103 | self.output_dir.mkdir(exist_ok=True, parents=True) 104 | 105 | # Each array has shape (7,) 106 | self.measured_positions: list[np.ndarray] = [] 107 | self.measured_torques: list[np.ndarray] = [] 108 | self.measured_times: list[float] = [] 109 | self.measured_wsg_positions: list[float] = [] 110 | self.current_start_time = None 111 | self.object_saved = False 112 | 113 | self._planner_state_input_port = self.DeclareAbstractInputPort( 114 | "planner_state", model_value=Value(PlannerState.START)) 115 | self._pick_state_input_port = self.DeclareAbstractInputPort( 116 | "pick_state", model_value=Value(PickState.IDLE)) 117 | 118 | self._iiwa_position_input_port = self.DeclareVectorInputPort( 119 | "iiwa.position_measured", size=7 120 | ) 121 | self._iiwa_torque_input_port = self.DeclareVectorInputPort( 122 | "iiwa.torque_measured", size=7 123 | ) 124 | self._wsg_position_input_port = self.DeclareVectorInputPort( 125 | "wsg.position_measured", size=1 126 | ) 127 | 128 | self.DeclarePeriodicPublishEvent(1e-3, 0.0, self.save_logs) 129 | 130 | 131 | def save_logs(self, context: Context): 132 | mode = self._planner_state_input_port.Eval(context) 133 | pick_mode = self._pick_state_input_port.Eval(context) 134 | if mode == PlannerState.RESET and not self.object_saved: 135 | self.save_to_disk() 136 | self.measured_positions = [] 137 | self.measured_torques = [] 138 | self.measured_times = [] 139 | self.object_saved = True 140 | self.current_start_time = None 141 | elif mode == PlannerState.START: 142 | self.object_saved = False 143 | elif mode == PlannerState.SYS_ID_GRASP and pick_mode == PickState.DISPLAY: 144 | # Store current positions and torques. 145 | positions = self._iiwa_position_input_port.Eval(context) 146 | torques = self._iiwa_torque_input_port.Eval(context) 147 | wsg_position = self._wsg_position_input_port.Eval(context) 148 | 149 | time = context.get_time() 150 | if self.current_start_time is None: 151 | self.current_start_time = time 152 | traj_time = time - self.current_start_time # Want trajectory data to start at zero time 153 | 154 | self.measured_positions.append(positions) 155 | self.measured_torques.append(torques) 156 | self.measured_times.append(traj_time) 157 | self.measured_wsg_positions.append(wsg_position) 158 | 159 | def save_to_disk(self): 160 | if self.current_start_time is None: 161 | print("No system ID data to save...") 162 | return 163 | print("Saving system ID data to disk.") 164 | 165 | # Convert to numpy arrays. 166 | measured_position_data = np.stack(self.measured_positions) # Shape (T, 7) 167 | measured_torque_data = np.stack(self.measured_torques) # Shape (T, 7) 168 | sample_times_s = np.array(self.measured_times) # Shape (T,) 169 | measured_wsg_positions = np.array(self.measured_wsg_positions).squeeze(1) # Shape (T,) 170 | 171 | # Remove duplicated samples. 
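        # np.unique on the sample times returns the sorted unique timestamps together
        # with the index of the first occurrence of each; indexing every logged array
        # with those indices keeps positions, torques, times, and gripper widths
        # consistent with one another.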
172 | _, unique_indices = np.unique(sample_times_s, return_index=True) 173 | if len(unique_indices) < len(sample_times_s): 174 | print(f"{len(unique_indices)} out of {len(sample_times_s)} data points are unique!") 175 | measured_position_data = measured_position_data[unique_indices] 176 | measured_torque_data = measured_torque_data[unique_indices] 177 | sample_times_s = sample_times_s[unique_indices] 178 | measured_wsg_positions = measured_wsg_positions[unique_indices] 179 | 180 | # Save to disk. 181 | np.save(self.output_dir / "joint_positions.npy", measured_position_data) 182 | np.save(self.output_dir / "joint_torques.npy", measured_torque_data) 183 | np.save(self.output_dir / "sample_times_s.npy", sample_times_s) 184 | np.save(self.output_dir / "wsg_positions.npy", measured_wsg_positions) 185 | 186 | print("Saved system id data to", self.output_dir) 187 | 188 | 189 | def start_scenario( 190 | dirstr = "temp", 191 | scenario_path="scenario_data_grasping.yml", 192 | models_path="scenario_data_grasping_no_object.dmd.yaml", 193 | gripper_model_path="", 194 | use_hardware=False, 195 | save_imgs=False, 196 | turntable=False, 197 | time_horizon=10.0, 198 | num_objects=1 199 | ): 200 | 201 | meshcat.ResetRenderMode() 202 | 203 | builder = DiagramBuilder() 204 | 205 | dir_path = os.path.dirname(os.path.realpath(__file__)) 206 | filename = os.path.join(dir_path, "scalable_real2sim", "pickplace_data_collection", "scenario_datas", scenario_path) 207 | scenario = LoadScenario(filename=filename) 208 | # station: IiwaHardwareStationDiagram = builder.AddNamedSystem( 209 | # "station", 210 | # IiwaHardwareStationDiagram( 211 | # scenario=scenario, has_wsg=True, use_hardware=use_hardware 212 | # ), 213 | # ) 214 | models_package = os.path.abspath(os.path.join(dir_path, "scalable_real2sim", "pickplace_data_collection", "models", "package.xml")) 215 | station = builder.AddSystem(MakeHardwareStation(scenario, meshcat, hardware=False, package_xmls=[models_package])) 216 | if use_hardware: 217 | scenario.plant_config.time_step = 5e-3 # Controller frequency 218 | external_station = builder.AddSystem(MakeHardwareStation(scenario, meshcat, hardware=True, package_xmls=[models_package])) 219 | plant = station.GetSubsystemByName("plant") 220 | # plant = station.get_plant() 221 | 222 | # initialize image writer and save directories 223 | if os.path.exists(dirstr): 224 | shutil.rmtree(dirstr) 225 | os.makedirs(dirstr) 226 | if save_imgs: 227 | # save images 228 | if use_hardware: 229 | # this doesn't work because saving images takes too long and bottlenecks the whole system 230 | # please use scripts/realsense.py instead in parallel 231 | img_saver = builder.AddSystem( 232 | ImageSaver( 233 | depth_format="32F", 234 | dirstr=dirstr, 235 | labels=False, 236 | camera_info=True, 237 | use_hardware=True, 238 | num_objects=num_objects 239 | ) 240 | ) 241 | else: 242 | img_saver = builder.AddSystem( 243 | ImageSaver( 244 | depth_format="32F", 245 | dirstr=dirstr, 246 | labels=True, 247 | camera_info=False, 248 | ob_in_cam=True, 249 | object_index=plant.GetBodyByName("base_link_mustard").index(), 250 | camera_index=plant.GetBodyByName("base").index() 251 | ) 252 | ) 253 | builder.Connect(station.GetOutputPort("camera0.label_image"), img_saver.GetInputPort("label_in")) 254 | builder.Connect(station.GetOutputPort("camera0.rgb_image"), img_saver.GetInputPort("rgb_in")) 255 | builder.Connect(station.GetOutputPort("camera0.depth_image"), img_saver.GetInputPort("depth_in")) 256 | 
builder.Connect(station.GetOutputPort("body_poses"), img_saver.GetInputPort("body_poses")) 257 | 258 | sys_id_saver = builder.AddSystem(SystemIDDataSaver(output_dir=os.path.join(dirstr, "system_id_data"))) 259 | 260 | # initialize point cloud output ports and save camera instrinsics 261 | if not use_hardware: 262 | camera0 = station.GetSubsystemByName("rgbd_sensor_camera0") 263 | camera1 = station.GetSubsystemByName("rgbd_sensor_camera1") 264 | camera2 = station.GetSubsystemByName("rgbd_sensor_camera2") 265 | camera_bin = station.GetSubsystemByName("rgbd_sensor_camera_bin") 266 | K = camera0.default_color_render_camera().core().intrinsics().intrinsic_matrix() 267 | if save_imgs: 268 | np.savetxt(dirstr+"/cam_K.txt", K) 269 | 270 | camera0_pcd = builder.AddSystem(DepthImageToPointCloud(camera0.default_depth_render_camera().core().intrinsics())) 271 | camera1_pcd = builder.AddSystem(DepthImageToPointCloud(camera1.default_depth_render_camera().core().intrinsics())) 272 | camera2_pcd = builder.AddSystem(DepthImageToPointCloud(camera2.default_depth_render_camera().core().intrinsics())) 273 | camera_bin_pcd = builder.AddSystem(DepthImageToPointCloud(camera_bin.default_depth_render_camera().core().intrinsics())) 274 | builder.Connect(station.GetOutputPort("camera0.depth_image"), camera0_pcd.GetInputPort("depth_image")) 275 | builder.Connect(station.GetOutputPort("camera1.depth_image"), camera1_pcd.GetInputPort("depth_image")) 276 | builder.Connect(station.GetOutputPort("camera2.depth_image"), camera2_pcd.GetInputPort("depth_image")) 277 | builder.Connect(station.GetOutputPort("camera_bin.depth_image"), camera_bin_pcd.GetInputPort("depth_image")) 278 | 279 | else: 280 | camera0_pcd = builder.AddSystem(DepthImageToPointCloud(CameraInfo(848, 480, 600.165, 600.165, 429.152, 232.822))) 281 | camera1_pcd = builder.AddSystem(DepthImageToPointCloud(CameraInfo(848, 480, 626.633, 626.633, 432.041, 245.465))) 282 | camera2_pcd = builder.AddSystem(DepthImageToPointCloud(CameraInfo(848, 480, 596.492, 596.492, 416.694, 240.225))) 283 | camera_bin_pcd = builder.AddSystem(DepthImageToPointCloud(CameraInfo(640, 480, 385.218, 385.218, 321.295, 244.071))) 284 | 285 | builder.Connect(external_station.GetOutputPort("camera0.depth_image"), camera0_pcd.GetInputPort("depth_image")) 286 | builder.Connect(external_station.GetOutputPort("camera1.depth_image"), camera1_pcd.GetInputPort("depth_image")) 287 | builder.Connect(external_station.GetOutputPort("camera2.depth_image"), camera2_pcd.GetInputPort("depth_image")) 288 | builder.Connect(external_station.GetOutputPort("camera_bin.depth_image"), camera_bin_pcd.GetInputPort("depth_image")) 289 | 290 | 291 | if use_hardware: 292 | # from camera calibation 293 | # Front camera 294 | x_front_rgb = RigidTransform(np.loadtxt("/home/real2sim/calibrations/2_10_calibrations_aligned/front.txt")) 295 | 296 | # Back Right camera 297 | x_back_right_rgb = RigidTransform(np.loadtxt("/home/real2sim/calibrations/2_10_calibrations_aligned/back_right.txt")) 298 | 299 | # Back Left camera 300 | x_back_left_rgb = RigidTransform(np.loadtxt("/home/real2sim/calibrations/2_10_calibrations_aligned/back_left.txt")) 301 | 302 | # Bin camera 303 | x_bin_rgb = RigidTransform(np.loadtxt("/home/real2sim/calibrations/bin_calibration_2_7_daniilidis.txt")) 304 | 305 | # rgb calibration to depth calibration (from realsense specs) 306 | # Front camera 307 | x_depth_rgb_front = RigidTransform([[0.999986, -0.000127587, 0.00531376, 0.015102], 308 | [0.000116105, 0.999998, 0.00216102, 6.44158e-05], 309 | 
[-0.00531402, -0.00216038, 0.999984, -0.000426644], 310 | [0, 0, 0, 1]]) 311 | x_front_camera = x_front_rgb @ x_depth_rgb_front 312 | 313 | # Back Right camera 314 | x_depth_rgb_back_right = RigidTransform([[0.999968, -0.00700185, 0.00399879, 0.015085], 315 | [ 0.00701494, 0.99997, -0.00326805, -2.1265e-05], 316 | [-0.00397579, 0.00329599, 0.999987, -0.000455872], 317 | [ 0, 0, 0, 1]]) 318 | x_back_right_camera = x_back_right_rgb @ x_depth_rgb_back_right 319 | 320 | # Back Left camera 321 | x_depth_rgb_back_left = RigidTransform([[0.999998, -0.000191981, -0.00215977, 0.0150991], 322 | [0.000214442, 0.999946, 0.0104041, 7.71731e-05], 323 | [0.00215765, -0.0104046, 0.999944, -0.000317806], 324 | [ 0, 0, 0, 1]]) 325 | x_back_left_camera = x_back_left_rgb @ x_depth_rgb_back_left 326 | 327 | # Bin camera 328 | x_depth_rgb_bin = RigidTransform([[0.999968, 0.00149319, 0.00783427, 0.0147784], 329 | [-0.00146555, 0.999993, -0.00353279, -4.93721e-05], 330 | [-0.00783949, 0.00352119, 0.999963, 0.000204544], 331 | [ 0, 0, 0, 1]]) 332 | x_bin_camera = x_bin_rgb @ x_depth_rgb_bin 333 | else: 334 | # Front camera 335 | x_front_camera = RigidTransform( 336 | RotationMatrix(RollPitchYaw(-150.29508676 / 180. * np.pi, -0.49652966 / 180. * np.pi, 87.69325379 / 180. * np.pi)), 337 | [1.00847, -0.0314675, 1.12864] 338 | ) 339 | 340 | # Back Right camera 341 | x_back_right_camera = RigidTransform( 342 | RotationMatrix(RollPitchYaw(-105.81290946 / 180. * np.pi, 2.14985993, -43.7254432 / 180. * np.pi)), 343 | [-0.110748, -0.931772, 0.388191] 344 | ) 345 | 346 | # Back Left camera 347 | x_back_left_camera = RigidTransform( 348 | RotationMatrix(RollPitchYaw(-102.739428 / 180. * np.pi, -3.69469624 / 180. * np.pi, -149.1420755 / 180. * np.pi)), 349 | [-0.0533544, 1.00955, 0.449207] 350 | ) 351 | 352 | # Bin camera 353 | x_bin_camera = RigidTransform( 354 | RotationMatrix(RollPitchYaw(-164.69831287 / 180. * np.pi, -35.83297034 / 180. * np.pi, -99.44115857 / 180. 
* np.pi)), 355 | [-0.0574518, 0.874365 , 0.332985] 356 | ) 357 | 358 | # connect stationary camera pcd source 359 | camera0_pose_source = builder.AddSystem(CameraPoseInWorldSource(x_front_camera, handeye=False)) 360 | camera1_pose_source = builder.AddSystem(CameraPoseInWorldSource(x_back_right_camera, handeye=False)) 361 | camera2_pose_source = builder.AddSystem(CameraPoseInWorldSource(x_back_left_camera, handeye=False)) 362 | bin_cam_pose_source = builder.AddSystem(CameraPoseInWorldSource(x_bin_camera, handeye=False)) 363 | 364 | builder.Connect( 365 | camera0_pose_source.GetOutputPort("X_WC"), 366 | camera0_pcd.GetInputPort("camera_pose"), 367 | ) 368 | 369 | builder.Connect( 370 | camera1_pose_source.GetOutputPort("X_WC"), 371 | camera1_pcd.GetInputPort("camera_pose"), 372 | ) 373 | 374 | builder.Connect( 375 | camera2_pose_source.GetOutputPort("X_WC"), 376 | camera2_pcd.GetInputPort("camera_pose"), 377 | ) 378 | 379 | builder.Connect( 380 | bin_cam_pose_source.GetOutputPort("X_WC"), 381 | camera_bin_pcd.GetInputPort("camera_pose"), 382 | ) 383 | 384 | controller_plant = station.GetSubsystemByName( 385 | "iiwa_controller_plant_pointer_system" 386 | ).get() 387 | 388 | # Set up planner 389 | if turntable: 390 | planner = builder.AddSystem(TurntablePlanner( 391 | plant=plant, 392 | controller_plant=controller_plant, 393 | X_WC0=x_front_camera, 394 | X_WC1=x_back_left_camera, 395 | X_WC2=x_back_right_camera, 396 | X_WC_bin=x_bin_camera, 397 | meshcat=meshcat, 398 | dirstr=dirstr, 399 | models_path=os.path.join(dir_path, "scalable_real2sim", "pickplace_data_collection", "scenario_datas", models_path), 400 | gripper_model_path=gripper_model_path)) 401 | else: 402 | planner = builder.AddSystem( 403 | TwoGraspPlanner( 404 | plant=plant, 405 | controller_plant=controller_plant, 406 | X_WC0=x_front_camera, 407 | X_WC1=x_back_left_camera, 408 | X_WC2=x_back_right_camera, 409 | X_WC_bin=x_bin_camera, 410 | meshcat=meshcat, 411 | dirstr=dirstr, 412 | time_horizon=time_horizon, 413 | models_path=os.path.join(dir_path, "scalable_real2sim", "pickplace_data_collection", "scenario_datas", models_path), 414 | gripper_model_path=gripper_model_path, 415 | num_objs=num_objects 416 | ) 417 | ) 418 | 419 | if save_imgs: 420 | builder.Connect(planner.GetOutputPort("planner_state"), img_saver.GetInputPort("planner_state")) 421 | 422 | wsg_state_demux: Demultiplexer = builder.AddSystem(Demultiplexer(2, 1)) 423 | if use_hardware: 424 | # Connect the output of external station to the input of internal station 425 | builder.Connect( 426 | external_station.GetOutputPort("iiwa.position_measured"), 427 | station.GetInputPort("iiwa.position"), 428 | ) 429 | 430 | builder.Connect( 431 | external_station.GetOutputPort("wsg.state_measured"), 432 | wsg_state_demux.get_input_port(), 433 | ) 434 | builder.Connect( 435 | wsg_state_demux.get_output_port(0), 436 | station.GetInputPort("wsg.position"), 437 | ) 438 | builder.Connect( 439 | wsg_state_demux.get_output_port(0), 440 | planner.GetInputPort("wsg.position_measured"), 441 | ) 442 | else: 443 | builder.Connect( 444 | station.GetOutputPort("wsg.state_measured"), 445 | wsg_state_demux.get_input_port(), 446 | ) 447 | builder.Connect( 448 | wsg_state_demux.get_output_port(0), 449 | planner.GetInputPort("wsg.position_measured"), 450 | ) 451 | 452 | 453 | # Connect system ID data saver ports. 
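    # The SystemIDDataSaver watches the planner's state ports so that it only records
    # joint positions/torques while the system-identification grasp is being displayed,
    # and it writes the recorded trajectories to disk once the planner enters its reset
    # state.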
454 | builder.Connect(planner.GetOutputPort("planner_state"), sys_id_saver.GetInputPort("planner_state")) 455 | builder.Connect(planner.GetOutputPort("pick_state"), sys_id_saver.GetInputPort("pick_state")) 456 | if not use_hardware: 457 | builder.Connect( 458 | station.GetOutputPort("iiwa.position_measured"), 459 | sys_id_saver.GetInputPort("iiwa.position_measured"), 460 | ) 461 | builder.Connect( 462 | station.GetOutputPort("iiwa.torque_measured"), 463 | sys_id_saver.GetInputPort("iiwa.torque_measured"), 464 | ) 465 | else: 466 | builder.Connect( 467 | external_station.GetOutputPort("iiwa.position_measured"), 468 | sys_id_saver.GetInputPort("iiwa.position_measured"), 469 | ) 470 | builder.Connect( 471 | external_station.GetOutputPort("iiwa.torque_measured"), 472 | sys_id_saver.GetInputPort("iiwa.torque_measured"), 473 | ) 474 | 475 | builder.Connect( 476 | wsg_state_demux.get_output_port(0), 477 | sys_id_saver.GetInputPort("wsg.position_measured"), 478 | ) 479 | 480 | if not use_hardware: 481 | builder.Connect( 482 | station.GetOutputPort("iiwa.position_measured"), 483 | planner.GetInputPort("iiwa_position"), 484 | ) 485 | else: 486 | builder.Connect( 487 | external_station.GetOutputPort("iiwa.position_measured"), 488 | planner.GetInputPort("iiwa_position"), 489 | ) 490 | 491 | joint_traj_source: TrajectoryWithTimingInformationSource = ( 492 | builder.AddNamedSystem( 493 | "joint_traj_source", 494 | TrajectoryWithTimingInformationSource( 495 | trajectory_size=7 496 | ), 497 | ) 498 | ) 499 | 500 | builder.Connect( 501 | planner.GetOutputPort("joint_position_trajectory"), 502 | joint_traj_source.GetInputPort("trajectory"), 503 | ) 504 | if use_hardware: 505 | builder.Connect( 506 | external_station.GetOutputPort("iiwa.position_commanded"), 507 | joint_traj_source.GetInputPort("current_cmd"), 508 | ) 509 | else: 510 | builder.Connect( 511 | station.GetOutputPort("iiwa.position_measured"), 512 | joint_traj_source.GetInputPort("current_cmd"), 513 | ) 514 | 515 | if use_hardware: 516 | builder.Connect( 517 | planner.GetOutputPort("wsg_position"), 518 | external_station.GetInputPort("wsg.position"), 519 | ) 520 | else: 521 | builder.Connect( 522 | planner.GetOutputPort("wsg_position"), 523 | station.GetInputPort("wsg.position"), 524 | ) 525 | 526 | # Increase max force. 527 | wsg_force_source = builder.AddNamedSystem( 528 | "wsg_force_source", ConstantVectorSource([80.0]) # 80N is max 529 | ) 530 | builder.Connect( 531 | wsg_force_source.get_output_port(), station.GetInputPort("wsg.force_limit") 532 | ) 533 | 534 | # Set up logging for system ID. 535 | num_positions = 7 536 | measured_position_logger: VectorLogSink = builder.AddNamedSystem( 537 | "measured_position_logger", 538 | VectorLogSink(num_positions, publish_period=scenario.plant_config.time_step), 539 | ) 540 | builder.Connect( 541 | station.GetOutputPort("iiwa.position_measured"), 542 | measured_position_logger.get_input_port(), 543 | ) 544 | measured_torque_logger: VectorLogSink = builder.AddNamedSystem( 545 | "measured_torque_logger", 546 | VectorLogSink(num_positions, publish_period=scenario.plant_config.time_step), 547 | ) 548 | builder.Connect( 549 | station.GetOutputPort("iiwa.torque_measured"), 550 | measured_torque_logger.get_input_port(), 551 | ) 552 | 553 | # Set up differential inverse kinematics. 
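    # Conservative joint velocity/acceleration limits are used for the diff-IK block
    # below; it tracks the planner's X_WG pose output and can be re-initialized from
    # the measured robot state via its "use_robot_state" port.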
554 | velocity_limits = 0.2 * np.ones(7) 555 | acceleration_limits = 0.1 * np.ones(7) 556 | diff_ik = AddIiwaDifferentialIK( 557 | builder, 558 | controller_plant, 559 | frame=None, 560 | velocity_lims=velocity_limits, 561 | acceleration_lims=acceleration_limits, # doesn't actually do anything since using this stops the robot from moving??? 562 | joint_centering_gain=1.0 563 | ) 564 | builder.Connect(planner.GetOutputPort("X_WG"), diff_ik.get_input_port(0)) 565 | if use_hardware: 566 | builder.Connect( 567 | external_station.GetOutputPort("iiwa.state_estimated"), 568 | diff_ik.GetInputPort("robot_state"), 569 | ) 570 | else: 571 | builder.Connect( 572 | station.GetOutputPort("iiwa.state_estimated"), 573 | diff_ik.GetInputPort("robot_state"), 574 | ) 575 | builder.Connect( 576 | planner.GetOutputPort("reset_diff_ik"), 577 | diff_ik.GetInputPort("use_robot_state"), 578 | ) 579 | 580 | # The DiffIK and the direct position-control modes go through a PortSwitch 581 | switch = builder.AddSystem(PortSwitch(7)) 582 | builder.Connect(diff_ik.get_output_port(), switch.DeclareInputPort("diff_ik")) 583 | builder.Connect( 584 | joint_traj_source.get_output_port(), 585 | switch.DeclareInputPort("position"), 586 | ) 587 | if use_hardware: 588 | builder.Connect(switch.get_output_port(), external_station.GetInputPort("iiwa.position")) 589 | else: 590 | builder.Connect(switch.get_output_port(), station.GetInputPort("iiwa.position")) 591 | builder.Connect( 592 | planner.GetOutputPort("control_mode"), 593 | switch.get_port_selector_input_port(), 594 | ) 595 | 596 | builder.Connect( 597 | camera0_pcd.GetOutputPort("point_cloud"), 598 | planner.GetInputPort("cloud_front"), 599 | ) 600 | builder.Connect( 601 | camera1_pcd.GetOutputPort("point_cloud"), 602 | planner.GetInputPort("cloud_back_left"), 603 | ) 604 | builder.Connect( 605 | camera2_pcd.GetOutputPort("point_cloud"), 606 | planner.GetInputPort("cloud_back_right"), 607 | ) 608 | builder.Connect( 609 | camera_bin_pcd.GetOutputPort("point_cloud"), 610 | planner.GetInputPort("cloud_bin"), 611 | ) 612 | builder.Connect( 613 | station.GetOutputPort("body_poses"), 614 | planner.GetInputPort("body_poses"), 615 | ) 616 | 617 | # Build diagram 618 | diagram = builder.Build() 619 | context = diagram.CreateDefaultContext() 620 | 621 | # Simulate 622 | simulator = Simulator(diagram) 623 | simulator_context = simulator.get_mutable_context() 624 | 625 | # Remove labels of anything but mustard and gripper 626 | if not use_hardware: 627 | scene_graph = station.GetSubsystemByName("scene_graph") 628 | source_id = plant.get_source_id() 629 | scene_graph_context = scene_graph.GetMyMutableContextFromRoot(simulator_context) 630 | query_object = scene_graph.get_query_output_port().Eval(scene_graph_context) 631 | inspector = query_object.inspector() 632 | for geometry_id in inspector.GetAllGeometryIds(): 633 | properties = copy.deepcopy(inspector.GetPerceptionProperties(geometry_id)) 634 | if properties is None: 635 | continue 636 | frame_id = inspector.GetFrameId(geometry_id) 637 | body = plant.GetBodyFromFrameId(frame_id) 638 | if body.model_instance() == plant.GetModelInstanceByName("mustard_bottle"): 639 | properties.UpdateProperty("label", "id", RenderLabel(0)) # Make mustard label 0 640 | elif body.model_instance() == plant.GetModelInstanceByName("wsg"): 641 | properties.UpdateProperty("label", "id", RenderLabel(1)) # Make gripper label 1 642 | else: 643 | properties.UpdateProperty("label", "id", RenderLabel.kDontCare) 644 | scene_graph.RemoveRole(scene_graph_context, 
source_id, geometry_id, Role.kPerception)
645 |             scene_graph.AssignRole(scene_graph_context, source_id, geometry_id, properties)
646 | 
647 |     simulator.set_target_realtime_rate(1.0)
648 | 
649 |     meshcat.AddButton("Stop Simulation", "Escape")
650 |     print("Press Escape to stop the simulation")
651 |     while meshcat.GetButtonClicks("Stop Simulation") < 1 and not planner.done:
652 |         simulator.AdvanceTo(simulator.get_context().get_time() + 5000.0)
653 | 
654 |     meshcat.DeleteButton("Stop Simulation")
655 | 
656 | 
657 | if __name__ == "__main__":
658 |     parser = argparse.ArgumentParser()
659 |     parser.add_argument(
660 |         "--scenario_path",
661 |         default="scenario_data_grasping.yml",
662 |         help="yaml file with scenario",
663 |     )
664 |     parser.add_argument(
665 |         "--models_path",
666 |         default="scenario_data_grasping.dmd.yaml",
667 |         help="dmd.yaml file with scenario, used for checking for collisions",
668 |         nargs='?',
669 |     )
670 |     parser.add_argument(
671 |         "--save_dir",
672 |         default="temp",
673 |         help="directory to save images in",
674 |         nargs='?',
675 |     )
676 |     parser.add_argument(
677 |         "--num_objects",
678 |         default="1",
679 |         help="number of objects to scan",
680 |         nargs='?',
681 |     )
682 |     parser.add_argument(
683 |         "--use_hardware",
684 |         action="store_true",
685 |         help="Whether to use real-world hardware.",
686 |     )
687 |     parser.add_argument(
688 |         "--save_imgs",
689 |         action='store_true',
690 |         help="Whether to save camera images during data collection.",
691 |     )
692 |     parser.add_argument(
693 |         "--use_custom_path_planner",
694 |         action="store_true",
695 |         help="Whether to use a user-implemented path planner.",
696 |     )
697 |     parser.add_argument(
698 |         "--turntable",
699 |         action="store_true",
700 |         help="Whether to use the turntable planner.",
701 |     )
702 |     parser.add_argument(
703 |         "--time_horizon",
704 |         type=float,
705 |         default=10.0,
706 |         help="The time horizon/duration of the trajectory. Only used for Fourier "
707 |         + "series trajectories.",
708 |     )
709 |     args = parser.parse_args()
710 | 
711 |     directory_path = os.path.dirname(os.path.abspath(__file__))
712 |     gripper_model_path = "package://pickplace_data_collection/schunk_wsg_50_large_grippers_w_buffer.sdf"
713 | 
714 |     # Start the visualizer. 
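    # Example invocation (illustrative; assumes this file is run_data_collection.py at
    # the repository root and that the default scenario files are available):
    #   python run_data_collection.py --save_dir my_object --save_imgs --turntable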
715 | meshcat = StartMeshcat() 716 | 717 | save_dir_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), 'tests', args.save_dir)) 718 | start_scenario( 719 | save_dir_path, 720 | scenario_path= args.scenario_path, 721 | gripper_model_path=gripper_model_path, 722 | models_path=args.models_path, 723 | use_hardware=args.use_hardware, 724 | save_imgs=args.save_imgs, 725 | num_objects=int(args.num_objects), 726 | turntable=args.turntable, 727 | time_horizon=args.time_horizon, 728 | ) 729 | -------------------------------------------------------------------------------- /scalable_real2sim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/scalable_real2sim/__init__.py -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/scalable_real2sim/data_processing/__init__.py -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/alpha_channel.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | 11 | def load_mask(mask_path: Path) -> np.ndarray: 12 | """Loads a mask, whether it's in .npy or .png format.""" 13 | if mask_path.suffix == ".npy": 14 | mask = np.load(mask_path) 15 | elif mask_path.suffix == ".png": 16 | mask = np.asarray(Image.open(mask_path).convert("L")) 17 | else: 18 | raise ValueError(f"Unsupported mask format: {mask_path.suffix}") 19 | return mask 20 | 21 | 22 | def process_image_with_alpha( 23 | img_path: Path, mask_path: Path, out_dir_path: Path 24 | ) -> None: 25 | img = np.asarray(Image.open(img_path).convert("RGB")) 26 | mask = load_mask(mask_path) 27 | 28 | # Ensure the mask has the correct shape for concatenation. 29 | if mask.ndim == 2: 30 | mask = mask[:, :, np.newaxis] 31 | 32 | img_w_alpha = np.concatenate((img, mask), axis=-1) 33 | 34 | out_path = out_dir_path / img_path.name 35 | img_w_alpha_pil = Image.fromarray(img_w_alpha) 36 | img_w_alpha_pil.save(out_path) 37 | 38 | 39 | def add_alpha_channel(img_dir: str, mask_dir: str, out_dir: str) -> None: 40 | """Adds an alpha channel to images using corresponding masks. 41 | 42 | This function processes all PNG images in the specified image directory 43 | and combines them with masks from the specified mask directory. The 44 | resulting images with an alpha channel are saved in the specified output 45 | directory. The function ensures that the number of images matches the 46 | number of masks. 47 | 48 | Args: 49 | img_dir (str): The directory containing the input images. 50 | mask_dir (str): The directory containing the input masks (in .npy or .png format). 51 | out_dir (str): The directory where the output images with alpha channels will be 52 | saved. 53 | 54 | Raises: 55 | AssertionError: If no images or masks are found, or if the number of images 56 | does not match the number of masks. 
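
    Example (illustrative; the directory names are placeholders):
        add_alpha_channel(img_dir="rgb", mask_dir="masks", out_dir="rgb_alpha")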
57 | """ 58 | image_dir_path = Path(img_dir) 59 | mask_dir_path = Path(mask_dir) 60 | out_dir_path = Path(out_dir) 61 | 62 | image_files = sorted(list(image_dir_path.glob("*.png"))) 63 | mask_files = sorted( 64 | list(mask_dir_path.glob("*.npy")) + list(mask_dir_path.glob("*.png")) 65 | ) 66 | 67 | assert len(image_files) > 0, f"No images found in {image_dir_path}" 68 | mask_files = sorted( 69 | [mask for mask in mask_files if mask.stem in {img.stem for img in image_files}] 70 | ) 71 | assert len(mask_files) > 0, f"No matching masks found in {mask_dir_path}" 72 | assert len(image_files) == len(mask_files), "Number of images and masks must match." 73 | out_dir_path.mkdir(exist_ok=True) 74 | 75 | with concurrent.futures.ThreadPoolExecutor() as executor: 76 | list( 77 | tqdm( 78 | executor.map( 79 | lambda args: process_image_with_alpha(*args), 80 | zip(image_files, mask_files, [out_dir_path] * len(image_files)), 81 | ), 82 | total=len(image_files), 83 | desc="Adding alpha channel", 84 | ) 85 | ) 86 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/colmap.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import shutil 5 | import subprocess 6 | 7 | import numpy as np 8 | 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | 12 | def convert_bundlesdf_to_colmap_format(data_dir: str, output_dir: str) -> None: 13 | # Read intrinsic matrix 14 | cam_K_path = os.path.join(data_dir, "cam_K.txt") 15 | if not os.path.exists(cam_K_path): 16 | raise FileNotFoundError(f"Intrinsic matrix file not found: {cam_K_path}") 17 | cam_K = np.loadtxt(cam_K_path) 18 | 19 | # Extract parameters 20 | fx = cam_K[0, 0] 21 | fy = cam_K[1, 1] 22 | cx = cam_K[0, 2] 23 | cy = cam_K[1, 2] 24 | 25 | camera_id = 1 # Assume a single camera 26 | model = "PINHOLE" 27 | params = [fx, fy, cx, cy] 28 | 29 | # Get list of image files 30 | image_dir = os.path.join(data_dir, "images") 31 | image_files = sorted(glob.glob(os.path.join(image_dir, "*.png"))) 32 | if not image_files: 33 | raise FileNotFoundError(f"No images found in directory: {image_dir}") 34 | 35 | # Get image size from the first image 36 | from PIL import Image 37 | 38 | with Image.open(image_files[0]) as img: 39 | width, height = img.size 40 | 41 | # Create output directories 42 | sparse_dir = os.path.join(output_dir, "sparse", "0") 43 | sparse_bin_dir = os.path.join(output_dir, "sparse", "0_bin") 44 | os.makedirs(sparse_dir, exist_ok=True) 45 | os.makedirs(sparse_bin_dir, exist_ok=True) 46 | 47 | # Write cameras.txt 48 | cameras_txt_path = os.path.join(sparse_dir, "cameras.txt") 49 | with open(cameras_txt_path, "w") as f: 50 | f.write("# Camera list with one line of data per camera:\n") 51 | f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") 52 | f.write( 53 | f'{camera_id} {model} {int(width)} {int(height)} {" ".join(map(str, params))}\n' 54 | ) 55 | 56 | # Write images.txt 57 | images_txt_path = os.path.join(sparse_dir, "images.txt") 58 | with open(images_txt_path, "w") as f: 59 | f.write("# Image list with two lines of data per image:\n") 60 | f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, IMAGE_NAME\n") 61 | f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n") 62 | 63 | for idx, img_path in enumerate(image_files): 64 | image_id = idx + 1 65 | image_name = os.path.basename(img_path) 66 | 67 | # Read world-to-camera transformation directly from ob_in_cam (X_CW) 68 | pose_idx = 
os.path.splitext(image_name)[0] # Get '000000' from '000000.png' 69 | pose_path = os.path.join(data_dir, "poses", f"{pose_idx}.txt") 70 | if not os.path.exists(pose_path): 71 | raise FileNotFoundError(f"Pose file not found: {pose_path}") 72 | X_CW = np.loadtxt(pose_path) # Shape (4, 4) 73 | 74 | # Extract rotation and translation directly 75 | R_CW = X_CW[:3, :3] 76 | p_CW = X_CW[:3, 3] 77 | 78 | # Convert rotation matrix to quaternion 79 | rot = R.from_matrix(R_CW) 80 | quat = rot.as_quat() # Returns [qx, qy, qz, qw] 81 | qx, qy, qz, qw = quat 82 | # COLMAP expects [qw, qx, qy, qz] 83 | 84 | # Write to images.txt 85 | f.write( 86 | f"{image_id} {qw} {qx} {qy} {qz} {p_CW[0]} {p_CW[1]} {p_CW[2]} {camera_id} {image_name}\n" 87 | ) 88 | 89 | # Write empty line for 2D points (since we have none) 90 | f.write("\n") 91 | 92 | # Write empty points3D.txt 93 | points3D_txt_path = os.path.join(sparse_dir, "points3D.txt") 94 | with open(points3D_txt_path, "w") as f: 95 | f.write("# 3D point list with one line of data per point:\n") 96 | f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[]\n") 97 | f.write("\n") 98 | 99 | # Convert text format to binary using COLMAP's model_converter. 100 | subprocess.run( 101 | [ 102 | "colmap", 103 | "model_converter", 104 | "--input_path", 105 | sparse_dir, 106 | "--output_path", 107 | sparse_bin_dir, 108 | "--output_type", 109 | "BIN", 110 | ], 111 | check=True, 112 | ) 113 | 114 | if not data_dir == output_dir: 115 | # Copy images to images directory 116 | images_output_dir = os.path.join(output_dir, "images") 117 | os.makedirs(images_output_dir, exist_ok=True) 118 | 119 | for img_path in image_files: 120 | shutil.copy(img_path, images_output_dir) 121 | 122 | logging.info("Conversion to COLMAP format completed successfully.") 123 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/frosting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | import subprocess 4 | 5 | from pathlib import Path 6 | 7 | from .alpha_channel import add_alpha_channel 8 | from .colmap import convert_bundlesdf_to_colmap_format 9 | from .image_subsampling import select_and_copy_dissimilar_images 10 | from .masks import invert_masks_in_directory 11 | 12 | 13 | def process_moving_obj_data_for_sugar( 14 | input_dir: str, output_dir: str, num_images: int = 800, use_depth: bool = False 15 | ): 16 | """Process moving object/static camera data for SuGAR. 
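    Adds an alpha channel from the object masks, inverts the gripper masks, subsamples
    the most dissimilar frames, converts the result to COLMAP format, and runs the
    Frosting/SuGaR point-cloud initialization script.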
17 | 18 | Args: 19 | input_dir: Path to input directory containing: 20 | - cam_K.txt # Camera intrinsic parameters 21 | - gripper_masks/ # Masks for the gripper 22 | - masks/ # Masks for the object + gripper 23 | - ob_in_cam/ # Poses of the object in camera frame (X_CO) 24 | - rgb/ # RGB images 25 | output_dir: Path to output directory 26 | num_images: Number of images to sample (default: 800) 27 | use_depth: Whether to use depth images (default: False) 28 | """ 29 | input_dir = Path(input_dir) 30 | output_dir = Path(output_dir) 31 | 32 | # Step 0: Copy input directory to output directory 33 | if output_dir.exists(): 34 | shutil.rmtree(output_dir) 35 | shutil.copytree(input_dir, output_dir) 36 | 37 | rgb_alpha_dir = output_dir / "rgb_alpha" 38 | add_alpha_channel( 39 | img_dir=str(output_dir / "rgb"), 40 | mask_dir=str(output_dir / "masks"), 41 | out_dir=str(rgb_alpha_dir), 42 | ) 43 | 44 | gripper_masks_inverted_dir = output_dir / "gripper_masks_inverted" 45 | invert_masks_in_directory( 46 | input_dir=str(output_dir / "gripper_masks"), 47 | output_dir=str(gripper_masks_inverted_dir), 48 | ) 49 | 50 | images_dino_sampled_dir = output_dir / "images_dino_sampled" 51 | if len(list((rgb_alpha_dir).glob("*.png"))) >= num_images: 52 | select_and_copy_dissimilar_images( 53 | image_dir=str(rgb_alpha_dir), 54 | output_dir=str(images_dino_sampled_dir), 55 | K=num_images, 56 | model_name="dino", 57 | ) 58 | else: 59 | logging.info( 60 | f"Skipping image selection: not enough images " 61 | f"(found: {len(list((rgb_alpha_dir).glob('*.png')))}, required: {num_images})" 62 | ) 63 | shutil.copytree(rgb_alpha_dir, images_dino_sampled_dir) 64 | 65 | shutil.rmtree(output_dir / "gripper_masks", ignore_errors=True) 66 | shutil.rmtree(output_dir / "gripper_masks", ignore_errors=True) 67 | (output_dir / "gripper_masks_inverted").rename(output_dir / "gripper_masks") 68 | 69 | shutil.rmtree(output_dir / "images", ignore_errors=True) 70 | (output_dir / "images_dino_sampled").rename(output_dir / "images") 71 | 72 | shutil.rmtree(output_dir / "poses", ignore_errors=True) 73 | (output_dir / "ob_in_cam").rename(output_dir / "poses") 74 | 75 | convert_bundlesdf_to_colmap_format( 76 | data_dir=str(output_dir), output_dir=str(output_dir) 77 | ) 78 | 79 | compute_pcd_script = ( 80 | Path(__file__).parent.parent / "Frosting/gaussian_splatting/compute_pcd_init.sh" 81 | ) 82 | subprocess.run(["bash", str(compute_pcd_script), str(output_dir)], check=True) 83 | 84 | if use_depth: 85 | convert_depth_script = ( 86 | Path(__file__).parent.parent 87 | / "Frosting/gaussian_splatting/utils/make_depth_scales.py" 88 | ) 89 | args = [ 90 | "--base_dir", 91 | str(output_dir), 92 | "--depths_dir", 93 | str(output_dir / "depths"), 94 | ] 95 | subprocess.run(["python", str(convert_depth_script), *args], check=True) 96 | else: 97 | logging.info("Skipping depth image conversion.") 98 | 99 | logging.info("Done preprocessing data for Frosting.") 100 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/image_subsampling.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from PIL import Image 10 | from scipy.spatial.distance import cdist 11 | from tqdm import tqdm 12 | 13 | 14 | def load_and_preprocess_images(image_paths, preprocess, device): 15 | features = [] 16 | valid_image_paths = [] 17 | 18 | def 
process_image(img_path): 19 | try: 20 | img = Image.open(img_path).convert("RGB") 21 | input_tensor = preprocess(img).unsqueeze(0).to(device) 22 | return input_tensor, img_path 23 | except Exception as e: 24 | logging.error(f"Error processing {img_path}: {e}") 25 | return None, None 26 | 27 | with concurrent.futures.ThreadPoolExecutor() as executor: 28 | results = list( 29 | tqdm( 30 | executor.map(process_image, image_paths), 31 | total=len(image_paths), 32 | desc="Processing images", 33 | ) 34 | ) 35 | 36 | for feature, img_path in results: 37 | if feature is not None: 38 | features.append(feature) 39 | valid_image_paths.append(img_path) 40 | 41 | if not features: 42 | raise ValueError("No valid images were found.") 43 | return torch.cat(features, dim=0), valid_image_paths 44 | 45 | 46 | def extract_features(model_name, model, images, batch_size: int): 47 | all_features = [] 48 | 49 | # Split images into batches 50 | images_batched = torch.split(images, batch_size) 51 | 52 | with torch.no_grad(): 53 | for image_batch in tqdm(images_batched, desc="Extracting features (batches)"): 54 | if model_name == "clip": 55 | batch_features = model.encode_image(image_batch) 56 | elif model_name == "dino": 57 | batch_features = model(image_batch) 58 | else: 59 | raise ValueError(f"Unsupported model: {model_name}") 60 | 61 | # Normalize features 62 | batch_features = batch_features / batch_features.norm(dim=1, keepdim=True) 63 | 64 | # Append batch features 65 | all_features.append(batch_features.cpu().numpy()) 66 | 67 | # Concatenate all batch features 68 | return np.concatenate(all_features, axis=0) 69 | 70 | 71 | def select_most_dissimilar_images( 72 | features: np.ndarray, K: int, N: int | None = None 73 | ) -> list[int]: 74 | """ 75 | Selects the K most dissimilar images based on cosine distance between their features. 76 | 77 | Args: 78 | features (np.ndarray): An array of image features where each row corresponds to 79 | an image. 80 | K (int): The number of dissimilar images to select. 81 | N (int, optional): Maximum number of consecutive frames to skip between selected 82 | frames. 83 | 84 | Returns: 85 | list[int]: A sorted list of indices representing the selected dissimilar images. 86 | """ 87 | distance_matrix = cdist(features, features, metric="cosine") 88 | N_total = len(features) 89 | 90 | # If N is provided, first select frames with uniform spacing. 91 | if N is not None: 92 | first_round_indices = list(range(0, N_total, N + 1)) 93 | remaining_indices = set(range(N_total)) - set(first_round_indices) 94 | 95 | if len(first_round_indices) >= K: 96 | return first_round_indices[:K] 97 | 98 | selected_indices = first_round_indices 99 | num_remaining = K - len(first_round_indices) 100 | logging.info( 101 | f"Selected {len(selected_indices)} images with uniform spacing. Selecting " 102 | f"remaining {num_remaining} images based on dissimilarity..." 103 | ) 104 | else: 105 | # Start with the most dissimilar image. 
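        # (i.e., the image with the largest summed cosine distance to all others). The
        # loop below then greedily adds the image whose minimum distance to the already
        # selected set is largest (farthest-point sampling).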
106 | selected_indices = [np.argmax(np.sum(distance_matrix, axis=1))] 107 | remaining_indices = set(range(N_total)) - set(selected_indices) 108 | num_remaining = K - 1 109 | 110 | # Select the remaining images based on dissimilarity 111 | for _ in tqdm(range(num_remaining), "Selecting most dissimilar images"): 112 | min_distances = np.min( 113 | distance_matrix[list(remaining_indices)][:, selected_indices], axis=1 114 | ) 115 | next_index = list(remaining_indices)[np.argmax(min_distances)] 116 | 117 | selected_indices.append(next_index) 118 | remaining_indices.remove(next_index) 119 | 120 | return sorted(selected_indices) 121 | 122 | 123 | def select_and_copy_dissimilar_images( 124 | image_dir: str, 125 | output_dir: str, 126 | K: int, 127 | N: int | None = None, 128 | model_name: str = "dino", 129 | device: str | None = None, 130 | batch_size: int = 256, 131 | ) -> None: 132 | """Selects the K most dissimilar images from a directory and copies them to an output directory. 133 | 134 | Args: 135 | image_dir (str): The directory containing the input images. 136 | output_dir (str): The directory where the selected images will be saved. 137 | K (int): The number of dissimilar images to select. 138 | N (int, optional): Maximum number of consecutive frames to skip between selected 139 | frames. 140 | model_name (str, optional): The name of the model to use for feature extraction 141 | ('clip' or 'dino'). 142 | device (str, optional): The device to use for model inference ('cuda' or 'cpu'). 143 | If not specified, it will automatically select based on availability. 144 | batch_size (int, optional): The batch size for processing images. Default is 256. 145 | """ 146 | # Automatically select device if not specified 147 | if device is None: 148 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 149 | logging.info(f"Using device: {device}") 150 | else: 151 | device = torch.device(device) 152 | 153 | # Get all image file paths 154 | image_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif") 155 | image_paths = [ 156 | os.path.join(image_dir, fname) 157 | for fname in os.listdir(image_dir) 158 | if fname.lower().endswith(image_extensions) 159 | ] 160 | 161 | if len(image_paths) < K: 162 | logging.warning(f"Number of images ({len(image_paths)}) is less than K ({K}).") 163 | return 164 | 165 | # Load the selected model 166 | if model_name == "clip": 167 | import clip 168 | 169 | model, preprocess = clip.load("ViT-B/32", device=device) 170 | elif model_name == "dino": 171 | # Install timm if not already installed 172 | try: 173 | import timm 174 | except ImportError: 175 | raise ImportError( 176 | "timm is not installed. Please install it with `pip install timm`." 
177 | ) 178 | 179 | model = timm.create_model("vit_base_patch16_224_dino", pretrained=True) 180 | model.eval() 181 | model.to(device) 182 | 183 | # Define DINO preprocessing 184 | from torchvision import transforms 185 | 186 | preprocess = transforms.Compose( 187 | [ 188 | transforms.Resize(248), 189 | transforms.CenterCrop(224), 190 | transforms.ToTensor(), 191 | transforms.Normalize( 192 | mean=(0.485, 0.456, 0.406), # ImageNet means 193 | std=(0.229, 0.224, 0.225), # ImageNet stds 194 | ), 195 | ] 196 | ) 197 | else: 198 | raise ValueError(f"Unsupported model: {model_name}") 199 | 200 | try: 201 | # Load and preprocess images 202 | images, valid_image_paths = load_and_preprocess_images( 203 | image_paths, preprocess, device 204 | ) 205 | 206 | # Extract features 207 | features = extract_features(model_name, model, images, batch_size) 208 | 209 | # Select the K most dissimilar images 210 | selected_indices = select_most_dissimilar_images(features, K, N) 211 | selected_images = [valid_image_paths[idx] for idx in selected_indices] 212 | 213 | # Create output directory if it doesn't exist 214 | os.makedirs(output_dir, exist_ok=True) 215 | 216 | # Copy selected images to the output directory 217 | with concurrent.futures.ThreadPoolExecutor() as executor: 218 | list( 219 | tqdm( 220 | executor.map( 221 | lambda img_path: shutil.copy(img_path, output_dir), 222 | selected_images, 223 | ), 224 | total=len(selected_images), 225 | desc="Copying selected images", 226 | ) 227 | ) 228 | 229 | logging.info(f"Selected {K} most dissimilar images using {model_name.upper()}.") 230 | logging.info(f"\nSelected images have been saved to '{output_dir}' directory.") 231 | finally: 232 | # Clean up GPU memory 233 | if torch.cuda.is_available(): 234 | model.cpu() 235 | torch.cuda.empty_cache() 236 | del model, images, features 237 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/masks.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | def invert_binary_mask(input_path: str, output_path: str) -> None: 9 | """Inverts a binary mask. 10 | 11 | This function opens an image file, converts it to grayscale, inverts the pixel 12 | values, and saves the inverted image to a new file. 13 | """ 14 | # Open the image 15 | img = Image.open(input_path) 16 | 17 | # Convert the image to grayscale (if it's not already) 18 | img = img.convert("L") 19 | 20 | # Invert the binary mask 21 | inverted_img = Image.eval(img, lambda x: 255 - x) 22 | 23 | # Save the inverted image 24 | inverted_img.save(output_path) 25 | 26 | 27 | def invert_masks_in_directory(input_dir: str, output_dir: str) -> None: 28 | """Inverts all binary masks in a directory. 29 | 30 | This function processes all PNG files in the specified input directory, 31 | inverts their pixel values, and saves the inverted images to the specified 32 | output directory. 
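
    Example (illustrative; the directory names are placeholders):
        invert_masks_in_directory("gripper_masks", "gripper_masks_inverted")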
33 | """ 34 | # Create output directory if it doesn't exist 35 | if not os.path.exists(output_dir): 36 | os.makedirs(output_dir) 37 | 38 | # Get all PNG files in the input directory 39 | mask_files = [f for f in os.listdir(input_dir) if f.endswith(".png")] 40 | 41 | def process_mask(mask_file): 42 | input_path = os.path.join(input_dir, mask_file) 43 | output_path = os.path.join(output_dir, mask_file) 44 | invert_binary_mask(input_path, output_path) 45 | 46 | # Invert each mask and save to the output directory using ThreadPoolExecutor 47 | with concurrent.futures.ThreadPoolExecutor() as executor: 48 | list( 49 | tqdm( 50 | executor.map(process_mask, mask_files), 51 | total=len(mask_files), 52 | desc="Inverting masks", 53 | ) 54 | ) 55 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/nerfstudio.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import shutil 7 | 8 | import numpy as np 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | from .alpha_channel import add_alpha_channel 14 | from .image_subsampling import select_and_copy_dissimilar_images 15 | from .masks import invert_masks_in_directory 16 | 17 | 18 | def convert_txt_or_png_to_nerfstudio_depth( 19 | folder_path: str, 20 | output_folder: str, 21 | image_folder: str = None, 22 | max_depth_value: int = 65535, 23 | bit_depth: int = 16, 24 | ) -> None: 25 | """Converts depth data from .txt or .png files to the format expected by Nerfstudio. 26 | 27 | This function processes all relevant files in the specified folder, reading depth data 28 | from either .txt or .png files. The depth values are assumed to be in meters for .txt 29 | files and in millimeters for .png files. The processed depth data is saved in the 30 | specified output folder. 31 | Background depth values are set to max_depth_value. 32 | 33 | Args: 34 | folder_path (str): The path to the folder containing the input .txt or .png files. 35 | output_folder (str): The path to the folder where the output depth data will be saved. 36 | image_folder (str, optional): An optional path to an image folder (default is None). 37 | If provided, the function will adjust the depth values based on the transparency 38 | (alpha channel) of the corresponding image. 39 | max_depth_value (int, optional): The maximum depth value to consider (default is 65535). 40 | bit_depth (int, optional): The bit depth of the output data (default is 16). 41 | """ 42 | # Ensure the output folder exists 43 | os.makedirs(output_folder, exist_ok=True) 44 | 45 | # Regular expression to extract numeric part from filenames 46 | number_pattern = re.compile(r"(\d+)") 47 | 48 | # List all relevant files in the folder (both .txt and .png) 49 | files = [f for f in os.listdir(folder_path) if f.endswith((".txt", ".png"))] 50 | 51 | def process_file(filename): 52 | file_path = os.path.join(folder_path, filename) 53 | 54 | # Handle both .txt and .png files 55 | if filename.endswith(".txt"): 56 | # Extract numeric part from the filename 57 | match = number_pattern.search(filename) 58 | if not match: 59 | logging.warning(f"No numeric part found in {filename}. 
Skipping.") 60 | return 61 | numeric_part = match.group(1) 62 | 63 | # Read depth data from the text file 64 | try: 65 | depth_data_meters = np.loadtxt(file_path) 66 | except Exception as e: 67 | logging.error(f"Error reading {filename}: {e}") 68 | return 69 | elif filename.endswith(".png"): 70 | # Extract numeric part from the filename 71 | match = number_pattern.search(filename) 72 | if not match: 73 | logging.warning(f"No numeric part found in {filename}. Skipping.") 74 | return 75 | numeric_part = match.group(1) 76 | 77 | # Read depth data from the PNG file (assuming it contains depth values) 78 | try: 79 | depth_image = Image.open(file_path).convert( 80 | "I" 81 | ) # 'I' mode for 32-bit pixels 82 | depth_data_meters = ( 83 | np.array(depth_image) / 1000.0 84 | ) # Assuming PNG stores millimeters 85 | except Exception as e: 86 | logging.error(f"Error reading {filename}: {e}") 87 | return 88 | else: 89 | return # Skip any files that aren't .txt or .png 90 | 91 | # Convert depth from meters to millimeters 92 | depth_data_mm = depth_data_meters * 1000.0 93 | 94 | # Clip depth values to the specified maximum depth value 95 | depth_data_mm_clipped = np.clip(depth_data_mm, 0, max_depth_value) 96 | 97 | # If image_folder is provided, adjust depth values based on transparency 98 | if image_folder: 99 | # Look for the corresponding image file in the image folder 100 | image_file_found = False 101 | for image_filename in os.listdir(image_folder): 102 | if image_filename.endswith((".png", ".jpg", ".jpeg", ".tiff", ".bmp")): 103 | image_match = number_pattern.search(image_filename) 104 | if image_match and numeric_part == image_match.group(1): 105 | image_file_path = os.path.join(image_folder, image_filename) 106 | image_file_found = True 107 | break 108 | 109 | if not image_file_found: 110 | return 111 | 112 | # Open the corresponding image to check for transparency 113 | try: 114 | # Open the image and ensure it has an alpha channel 115 | image = Image.open(image_file_path).convert("RGBA") 116 | alpha_channel = image.split()[3] # Extract the alpha channel 117 | alpha_array = np.array(alpha_channel) 118 | 119 | # Create a mask where alpha == 0 (transparent pixels) 120 | transparent_mask = alpha_array == 0 121 | 122 | # Ensure the mask dimensions match the depth data dimensions 123 | if transparent_mask.shape != depth_data_mm_clipped.shape: 124 | logging.warning( 125 | f"Dimension mismatch between depth data and alpha " 126 | f"mask in {filename}." 127 | ) 128 | return 129 | 130 | # Set depth values to max_depth_value where the image is transparent 131 | depth_data_mm_clipped[transparent_mask] = max_depth_value 132 | 133 | except Exception as e: 134 | logging.error(f"Error processing image file {image_file_path}: {e}") 135 | return 136 | 137 | # Convert depth data to the specified bit depth 138 | if bit_depth == 16: 139 | depth_data_uint = depth_data_mm_clipped.astype(np.uint16) 140 | elif bit_depth == 32: 141 | depth_data_uint = depth_data_mm_clipped.astype(np.uint32) 142 | else: 143 | logging.error( 144 | f"Unsupported bit depth: {bit_depth}. Supported values are 16 and 32." 
145 | ) 146 | return 147 | 148 | # Construct the output filename 149 | base_filename = os.path.splitext(filename)[0] 150 | if bit_depth == 16: 151 | output_filename = base_filename + ".png" 152 | else: # For 32-bit, use TIFF format 153 | output_filename = base_filename + ".tiff" 154 | output_file_path = os.path.join(output_folder, output_filename) 155 | 156 | # Save the depth data as an image 157 | try: 158 | depth_image = Image.fromarray(depth_data_uint) 159 | if bit_depth == 16: 160 | depth_image.save(output_file_path, format="PNG") 161 | else: # Save as TIFF for 32-bit depth 162 | depth_image.save(output_file_path, format="TIFF") 163 | except Exception as e: 164 | logging.error(f"Error saving {output_filename}: {e}") 165 | 166 | # Use ThreadPoolExecutor to speed up the processing 167 | with concurrent.futures.ThreadPoolExecutor() as executor: 168 | list( 169 | tqdm( 170 | executor.map(process_file, files), 171 | total=len(files), 172 | desc="Processing files", 173 | ) 174 | ) 175 | 176 | 177 | def transform_pose_opengl_to_opencv(pose: np.ndarray) -> np.ndarray: 178 | """Converts a pose from OpenGL to OpenCV format or vice versa (inverse transform 179 | is identical). 180 | 181 | This function takes a 4x4 pose matrix in OpenGL format and converts it to OpenCV 182 | format. The conversion involves flipping the y and z axes of the pose matrix. 183 | """ 184 | flip_yz = np.eye(4) 185 | flip_yz[1, 1] = -1 186 | flip_yz[2, 2] = -1 187 | return pose @ flip_yz 188 | 189 | 190 | def convert_bundle_sdf_poses_to_nerfstudio_poses(folder_path: str) -> None: 191 | """Converts BundleSDF poses to the format expected by Nerfstudio. The poses are 192 | overwritten in place. 193 | """ 194 | with os.scandir(folder_path) as paths: 195 | for path in paths: 196 | X_CW_opencv = np.loadtxt(path.path) 197 | X_WC_opencv = np.linalg.inv(X_CW_opencv) 198 | X_WC_opengl = transform_pose_opengl_to_opencv(X_WC_opencv) 199 | np.savetxt(path.path, X_WC_opengl) 200 | 201 | 202 | def read_intrinsics(cam_K_path): 203 | """ 204 | Reads the camera intrinsics from cam_K.txt. 205 | """ 206 | K = np.loadtxt(cam_K_path) 207 | if K.shape != (3, 3): 208 | raise ValueError("Camera intrinsics should be a 3x3 matrix.") 209 | return K 210 | 211 | 212 | def read_pose(pose_path): 213 | """ 214 | Reads the camera pose from a text file. 215 | """ 216 | pose = np.loadtxt(pose_path) 217 | if pose.shape != (4, 4): 218 | raise ValueError(f"Pose file {pose_path} should contain a 4x4 matrix.") 219 | return pose 220 | 221 | 222 | def collect_files(directory, extensions): 223 | """ 224 | Collects files from a directory with specified extensions. 225 | Returns a dictionary mapping from numerical filenames to file paths. 226 | """ 227 | files = {} 228 | for filename in os.listdir(directory): 229 | name, ext = os.path.splitext(filename) 230 | if ext.lower() in extensions and name.isdigit(): 231 | idx = int(name) 232 | files[idx] = os.path.join(directory, filename) 233 | return files 234 | 235 | 236 | def downsample_indices(indices, num_images): 237 | """ 238 | Downsamples the list of indices to the desired number of images. 
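    For example, 100 indices downsampled to num_images=25 keeps every 4th index
    (0, 4, 8, ...).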
239 | """ 240 | if num_images >= len(indices) or num_images <= 0: 241 | return indices # Return all indices if num_images is invalid 242 | indices = sorted(indices) 243 | interval = len(indices) / num_images 244 | selected_indices = [indices[int(i * interval)] for i in range(num_images)] 245 | return selected_indices 246 | 247 | 248 | def create_nerfstudio_transforms_json( 249 | data_dir: str, 250 | output_path: str, 251 | use_depth: bool = False, 252 | depth_dir: str = None, 253 | use_masks: bool = False, 254 | masks_dir: str = None, 255 | num_images: int = None, 256 | ) -> None: 257 | """ 258 | Creates the transforms.json file for Nerfstudio. 259 | 260 | Args: 261 | data_dir (str): The directory containing the camera intrinsics, images, and poses. 262 | The directory should contain the following files: 263 | - cam_K.txt: The camera intrinsics. 264 | - images: The directory containing the images. 265 | - poses: The directory containing the poses. 266 | - optional: depth: The directory containing the depth files. 267 | - optional: gripper_masks: The directory containing the gripper masks. 268 | output_path (str): The path where the transforms.json file will be saved. 269 | use_depth (bool, optional): Whether to include depth information. Defaults to False. 270 | depth_dir (str, optional): The directory containing depth files. Defaults to None. 271 | use_masks (bool, optional): Whether to include mask information. Defaults to False. 272 | masks_dir (str, optional): The directory containing mask files. Defaults to None. 273 | num_images (int, optional): The number of images to include in the output. Defaults to None. 274 | """ 275 | # Paths to required files and directories 276 | cam_K_path = os.path.join(data_dir, "cam_K.txt") 277 | images_dir = os.path.join(data_dir, "images") 278 | poses_dir = os.path.join(data_dir, "poses") 279 | 280 | if use_depth: 281 | depth_dir = depth_dir if depth_dir else os.path.join(data_dir, "depth") 282 | if use_masks: 283 | masks_dir = masks_dir if masks_dir else os.path.join(data_dir, "gripper_masks") 284 | 285 | # Read camera intrinsics 286 | K = read_intrinsics(cam_K_path) 287 | fx = K[0, 0] 288 | fy = K[1, 1] 289 | cx = K[0, 2] 290 | cy = K[1, 2] 291 | 292 | # Collect image files 293 | image_files = collect_files( 294 | images_dir, extensions={".jpg", ".png", ".jpeg", ".tiff"} 295 | ) 296 | pose_files = collect_files(poses_dir, extensions={".txt"}) 297 | 298 | if use_depth: 299 | depth_files = collect_files(depth_dir, extensions={".png", ".jpg", ".exr"}) 300 | if use_masks: 301 | mask_files = collect_files(masks_dir, extensions={".png", ".jpg", ".bmp"}) 302 | 303 | # Ensure that we have matching images and poses 304 | indices = sorted(set(image_files.keys()) & set(pose_files.keys())) 305 | 306 | # Downsample indices if num_images is specified 307 | if num_images is not None: 308 | indices = downsample_indices(indices, num_images) 309 | 310 | frames = [] 311 | for idx in indices: 312 | image_path = os.path.relpath(image_files[idx], data_dir) 313 | pose_path = pose_files[idx] 314 | 315 | # Read pose 316 | transform_matrix = read_pose(pose_path) 317 | transform_matrix = transform_matrix.tolist() 318 | 319 | frame = { 320 | "file_path": image_path, 321 | "transform_matrix": transform_matrix, 322 | } 323 | 324 | if use_depth and idx in depth_files: 325 | depth_path = os.path.relpath(depth_files[idx], data_dir) 326 | frame["depth_file_path"] = depth_path 327 | 328 | if use_masks and idx in mask_files: 329 | mask_path = os.path.relpath(mask_files[idx], data_dir) 
330 | frame["mask_path"] = mask_path 331 | 332 | frames.append(frame) 333 | 334 | # Collect image dimensions from the first image 335 | if frames: 336 | sample_image_path = os.path.join(data_dir, frames[0]["file_path"]) 337 | from PIL import Image 338 | 339 | with Image.open(sample_image_path) as img: 340 | w, h = img.size 341 | else: 342 | raise ValueError("No frames available to process.") 343 | 344 | # Construct the transforms.json data 345 | transforms = { 346 | "fl_x": fx, 347 | "fl_y": fy, 348 | "cx": cx, 349 | "cy": cy, 350 | "w": w, 351 | "h": h, 352 | "camera_model": "OPENCV", 353 | "frames": frames, 354 | } 355 | 356 | # Save to output path 357 | with open(output_path, "w") as f: 358 | json.dump(transforms, f, indent=4) 359 | 360 | logging.info(f"transforms.json file saved to {output_path}") 361 | 362 | 363 | def preprocess_data_for_nerfstudio( 364 | data_dir: str, output_dir: str, num_images: int | None = None 365 | ) -> None: 366 | """Preprocesses the data for Nerfstudio. 367 | 368 | Args: 369 | data_dir (str): The data directory after BundleSDF processing. 370 | output_dir (str): The directory where the preprocessed data will be saved. 371 | num_images (int, optional): Number of images to sample. If None, uses all images. 372 | """ 373 | # Step 0: Copy input directory to output directory 374 | os.makedirs(output_dir, exist_ok=True) 375 | os.system(f"cp -r {data_dir}/* {output_dir}") 376 | 377 | # Create temporary directories for intermediate processing 378 | rgb_alpha_dir = os.path.join(output_dir, "rgb_alpha") 379 | depth_converted_dir = os.path.join(output_dir, "depth_converted") 380 | gripper_masks_inverted_dir = os.path.join(output_dir, "gripper_masks_inverted") 381 | images_sampled_dir = os.path.join(output_dir, "images_dino_sampled") 382 | 383 | # Step 1: Add alpha channel to RGB images 384 | add_alpha_channel( 385 | img_dir=os.path.join(output_dir, "rgb"), 386 | mask_dir=os.path.join(output_dir, "masks"), 387 | out_dir=rgb_alpha_dir, 388 | ) 389 | 390 | # Step 2: Convert depth data to Nerfstudio format 391 | convert_txt_or_png_to_nerfstudio_depth( 392 | folder_path=os.path.join(output_dir, "depth"), 393 | output_folder=depth_converted_dir, 394 | image_folder=rgb_alpha_dir, 395 | ) 396 | 397 | # Step 3: Convert poses to OpenGL format 398 | convert_bundle_sdf_poses_to_nerfstudio_poses( 399 | folder_path=os.path.join(output_dir, "ob_in_cam") 400 | ) 401 | 402 | # Step 4: Invert gripper masks 403 | invert_masks_in_directory( 404 | input_dir=os.path.join(output_dir, "gripper_masks"), 405 | output_dir=gripper_masks_inverted_dir, 406 | ) 407 | 408 | # Step 5: Sample most dissimilar images 409 | if num_images is not None: 410 | select_and_copy_dissimilar_images( 411 | image_dir=rgb_alpha_dir, 412 | output_dir=images_sampled_dir, 413 | K=num_images, 414 | model_name="dino", 415 | ) 416 | else: 417 | logging.info("Skipping image sampling as num_images is not specified.") 418 | shutil.copytree(rgb_alpha_dir, images_sampled_dir) 419 | 420 | # Step 6: Create filtered directories based on sampled images 421 | depth_filtered_dir = os.path.join(output_dir, "depth_filtered") 422 | poses_filtered_dir = os.path.join(output_dir, "poses_filtered") 423 | gripper_masks_filtered_dir = os.path.join(output_dir, "gripper_masks_filtered") 424 | 425 | os.makedirs(depth_filtered_dir, exist_ok=True) 426 | os.makedirs(poses_filtered_dir, exist_ok=True) 427 | os.makedirs(gripper_masks_filtered_dir, exist_ok=True) 428 | 429 | # Get base filenames from sampled images 430 | sampled_files = [ 431 | f 432 | 
for f in os.listdir(images_sampled_dir) 433 | if f.endswith((".png", ".jpg", ".jpeg")) 434 | ] 435 | basenames = [os.path.splitext(f)[0] for f in sampled_files] 436 | 437 | # Copy corresponding files 438 | for basename in basenames: 439 | if os.path.exists(os.path.join(depth_converted_dir, f"{basename}.png")): 440 | os.system(f"cp {depth_converted_dir}/{basename}.png {depth_filtered_dir}/") 441 | if os.path.exists(os.path.join(output_dir, "ob_in_cam", f"{basename}.txt")): 442 | os.system(f"cp {output_dir}/ob_in_cam/{basename}.txt {poses_filtered_dir}/") 443 | if os.path.exists(os.path.join(gripper_masks_inverted_dir, f"{basename}.png")): 444 | os.system( 445 | f"cp {gripper_masks_inverted_dir}/{basename}.png {gripper_masks_filtered_dir}/" 446 | ) 447 | 448 | # Step 7: Clean up and rename directories 449 | cleanup_dirs = [ 450 | "rgb", 451 | "rgb_alpha", 452 | "masks", 453 | "gripper_masks", 454 | "gripper_masks_inverted", 455 | "depth", 456 | "depth_converted", 457 | "ob_in_cam", 458 | ] 459 | for d in cleanup_dirs: 460 | path = os.path.join(output_dir, d) 461 | if os.path.exists(path): 462 | os.system(f"rm -rf {path}") 463 | 464 | # Rename filtered directories to final names 465 | os.system(f"mv {gripper_masks_filtered_dir} {output_dir}/gripper_masks") 466 | os.system(f"mv {depth_filtered_dir} {output_dir}/depth") 467 | os.system(f"mv {images_sampled_dir} {output_dir}/images") 468 | os.system(f"mv {poses_filtered_dir} {output_dir}/poses") 469 | 470 | # Remove any .mp4 files 471 | os.system(f"rm -f {output_dir}/*.mp4") 472 | 473 | # Step 8: Create transforms.json 474 | create_nerfstudio_transforms_json( 475 | data_dir=output_dir, 476 | output_path=os.path.join(output_dir, "transforms.json"), 477 | use_depth=True, 478 | use_masks=True, 479 | num_images=num_images, 480 | ) 481 | 482 | logging.info("Done preprocessing data for Nerfstudio.") 483 | -------------------------------------------------------------------------------- /scalable_real2sim/data_processing/neuralangelo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | 5 | from pathlib import Path 6 | from types import SimpleNamespace 7 | 8 | from scalable_real2sim.neuralangelo.projects.neuralangelo.scripts.convert_data_to_json import ( 9 | data_to_json, 10 | ) 11 | from scalable_real2sim.neuralangelo.projects.neuralangelo.scripts.generate_config import ( 12 | generate_config, 13 | ) 14 | 15 | from .alpha_channel import add_alpha_channel 16 | from .colmap import convert_bundlesdf_to_colmap_format 17 | from .image_subsampling import select_and_copy_dissimilar_images 18 | from .masks import invert_masks_in_directory 19 | 20 | 21 | def process_moving_obj_data_for_neuralangelo( 22 | input_dir: str, output_dir: str, num_images: int = 800 23 | ) -> str: 24 | """Process moving object/static camera data for Neuralangelo. 25 | 26 | Args: 27 | input_dir: Path to input directory containing: 28 | - cam_K.txt # Camera intrinsic parameters 29 | - gripper_masks/ # Masks for the gripper 30 | - masks/ # Masks for the object + gripper 31 | - ob_in_cam/ # Poses of the object in camera frame (X_CO) 32 | - rgb/ # RGB images 33 | output_dir: Path to output directory 34 | num_images: Number of images to sample (default: 800) 35 | 36 | Returns: 37 | cfg_path: Path to the generated config file. 
38 | """ 39 | input_dir = Path(input_dir) 40 | output_dir = Path(output_dir) 41 | 42 | # Step 0: Copy input directory to output directory 43 | if output_dir.exists(): 44 | shutil.rmtree(output_dir) 45 | shutil.copytree(input_dir, output_dir) 46 | 47 | rgb_alpha_dir = output_dir / "rgb_alpha" 48 | add_alpha_channel( 49 | img_dir=str(output_dir / "rgb"), 50 | mask_dir=str(output_dir / "masks"), 51 | out_dir=str(rgb_alpha_dir), 52 | ) 53 | 54 | gripper_masks_inverted_dir = output_dir / "gripper_masks_inverted" 55 | invert_masks_in_directory( 56 | input_dir=str(output_dir / "gripper_masks"), 57 | output_dir=str(gripper_masks_inverted_dir), 58 | ) 59 | 60 | images_dino_sampled_dir = output_dir / "images_dino_sampled" 61 | if len(list((rgb_alpha_dir).glob("*.png"))) >= num_images: 62 | select_and_copy_dissimilar_images( 63 | image_dir=str(rgb_alpha_dir), 64 | output_dir=str(images_dino_sampled_dir), 65 | K=num_images, 66 | model_name="dino", 67 | ) 68 | else: 69 | logging.info( 70 | f"Skipping image selection: not enough images " 71 | f"(found: {len(list((rgb_alpha_dir).glob('*.png')))}, required: {num_images})" 72 | ) 73 | shutil.copytree(rgb_alpha_dir, images_dino_sampled_dir) 74 | 75 | shutil.rmtree(output_dir / "gripper_masks", ignore_errors=True) 76 | shutil.rmtree(output_dir / "gripper_masks", ignore_errors=True) 77 | (output_dir / "gripper_masks_inverted").rename(output_dir / "gripper_masks") 78 | 79 | shutil.rmtree(output_dir / "images", ignore_errors=True) 80 | (output_dir / "images_dino_sampled").rename(output_dir / "images") 81 | 82 | shutil.rmtree(output_dir / "poses", ignore_errors=True) 83 | (output_dir / "ob_in_cam").rename(output_dir / "poses") 84 | 85 | convert_bundlesdf_to_colmap_format( 86 | data_dir=str(output_dir), output_dir=str(output_dir) 87 | ) 88 | 89 | data_to_json(args=SimpleNamespace(data_dir=str(output_dir), scene_type="object")) 90 | 91 | cfg_path = generate_config( 92 | args=SimpleNamespace( 93 | data_dir=os.path.abspath(str(output_dir)), 94 | config_path=os.path.join(output_dir, "neuralangelo_recon.yaml"), 95 | sequence_name="neuralangelo_recon", 96 | scene_type="object", 97 | auto_exposure_wb=True, 98 | val_short_size=300, 99 | ) 100 | ) 101 | 102 | logging.info("Done preprocessing data for Neuralangelo.") 103 | 104 | return os.path.abspath(cfg_path) 105 | -------------------------------------------------------------------------------- /scalable_real2sim/output/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/scalable_real2sim/output/__init__.py -------------------------------------------------------------------------------- /scalable_real2sim/output/canonicalize.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | import trimesh 6 | 7 | from scipy.spatial.transform import Rotation 8 | 9 | 10 | def get_mesh_rotation_pca(vertices: np.ndarray) -> np.ndarray: 11 | """ 12 | Returns a rotation matrix for rotating the vertices' principle component onto the 13 | z-axis and the minor component onto the x-axis. 
14 | """ 15 | cov = np.cov(vertices.T) 16 | eigval, eigvec = np.linalg.eig(cov) 17 | 18 | order = eigval.argsort() 19 | principal_component = eigvec[:, order[-1]] 20 | minor_component = eigvec[:, order[0]] 21 | 22 | # Rotate mesh to align the principal component with the z-axis and the minor 23 | # component with the x-axis. 24 | z_axis, x_axis = [0.0, 0.0, 1.0], [1.0, 0.0, 0.0] 25 | rot_components_to_axes, _ = Rotation.align_vectors( 26 | np.array([z_axis, x_axis]), np.stack([principal_component, minor_component]) 27 | ) 28 | return rot_components_to_axes.as_matrix() 29 | 30 | 31 | def get_mesh_rotation_obb(mesh: o3d.geometry.TriangleMesh) -> np.ndarray: 32 | """ 33 | Returns a rotation matrix for rotating the mesh OBB's longest side onto the z-axis 34 | and the shortest side onto the 35 | x-axis. 36 | """ 37 | obb = ( 38 | mesh.get_oriented_bounding_box() 39 | ) # Computes the OBB based on PCA of the convex hull 40 | box_points = np.asarray(obb.get_box_points()) 41 | # The order of the points stays fixed (obtained these vectors from visual analysis) 42 | largest_vec = box_points[1] - box_points[0] 43 | smallest_candidate1 = box_points[2] - box_points[0] 44 | smallest_candidate2 = box_points[3] - box_points[0] 45 | smallest_vec = ( 46 | smallest_candidate1 47 | if np.linalg.norm(smallest_candidate1) < np.linalg.norm(smallest_candidate2) 48 | else smallest_candidate2 49 | ) 50 | 51 | # Rotate mesh to align the OBB's largest side with the z-axis and the smallest side 52 | # with the x-axis. 53 | z_axis, x_axis = [0.0, 0.0, 1.0], [1.0, 0.0, 0.0] 54 | rot_components_to_axes, _ = Rotation.align_vectors( 55 | np.array([z_axis, x_axis]), np.stack([largest_vec, smallest_vec]) 56 | ) 57 | return rot_components_to_axes.as_matrix() 58 | 59 | 60 | def axis_align_mesh( 61 | mesh: o3d.geometry.TriangleMesh, 62 | viz: bool = False, 63 | use_obb: bool = False, 64 | mesh_already_at_origin: bool = False, 65 | ) -> Tuple[o3d.geometry.TriangleMesh, np.ndarray]: 66 | """ 67 | Axis aligned the mesh based on OBB if `use_obb` is true and based on PCA otherwise. 68 | """ 69 | vertices = np.asarray(mesh.vertices) 70 | 71 | if not mesh_already_at_origin: 72 | # Put at world origin. 73 | vertices_at_origin = vertices - np.mean(vertices, axis=0) 74 | else: 75 | vertices_at_origin = vertices 76 | 77 | rot = get_mesh_rotation_obb(mesh) if use_obb else get_mesh_rotation_pca(vertices) 78 | vertices_rotated = vertices_at_origin @ rot.T 79 | 80 | mesh.vertices = o3d.utility.Vector3dVector(vertices_rotated) 81 | 82 | if viz: 83 | o3d.visualization.draw_geometries( 84 | [mesh, o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)] 85 | ) 86 | 87 | return mesh, rot 88 | 89 | 90 | def canonicalize_mesh(mesh: o3d.geometry.TriangleMesh) -> o3d.geometry.TriangleMesh: 91 | """Center at origin and axis align the mesh.""" 92 | # Center at origin. 93 | vertices = np.asarray(mesh.vertices) 94 | mesh_translation = np.mean(vertices, axis=0) 95 | mesh.translate(-mesh_translation) 96 | 97 | # Axis align the mesh. 
98 | mesh, _ = axis_align_mesh( 99 | mesh, viz=False, use_obb=True, mesh_already_at_origin=True 100 | ) 101 | 102 | return mesh 103 | 104 | 105 | def canonicalize_mesh_from_file(mesh_path: str, output_path: str) -> None: 106 | """Canonicalize a mesh from a file while preserving texture and save the result.""" 107 | 108 | # Load mesh with trimesh (preserves texture/materials) 109 | mesh = trimesh.load(mesh_path, process=False) 110 | 111 | # Convert Trimesh to Open3D (only vertices and faces) 112 | o3d_mesh = o3d.geometry.TriangleMesh() 113 | o3d_mesh.vertices = o3d.utility.Vector3dVector(mesh.vertices) 114 | o3d_mesh.triangles = o3d.utility.Vector3iVector(mesh.faces) 115 | 116 | # Apply canonicalization (assumes canonicalize_mesh returns an Open3D mesh) 117 | canonicalized_mesh = canonicalize_mesh(o3d_mesh) 118 | 119 | # Convert back to Trimesh (preserving original materials) 120 | canonicalized_trimesh = trimesh.Trimesh( 121 | vertices=np.asarray(canonicalized_mesh.vertices), 122 | faces=np.asarray(canonicalized_mesh.triangles), 123 | visual=mesh.visual, # Keep original materials and textures 124 | ) 125 | 126 | # Save with texture/materials 127 | canonicalized_trimesh.export(output_path) 128 | -------------------------------------------------------------------------------- /scalable_real2sim/output/sdformat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on our script at 3 | https://github.com/RussTedrake/manipulation/blob/master/manipulation/create_sdf_from_mesh.py 4 | """ 5 | 6 | import argparse 7 | import logging 8 | import os 9 | 10 | from pathlib import Path 11 | from typing import List 12 | 13 | import coacd 14 | import numpy as np 15 | import trimesh 16 | 17 | from lxml import etree as ET 18 | 19 | 20 | def perform_convex_decomposition( 21 | mesh: trimesh.Trimesh, 22 | mesh_parts_dir_name: str, 23 | mesh_dir: Path, 24 | preview_with_trimesh: bool, 25 | use_coacd: bool = False, 26 | coacd_kwargs: dict | None = None, 27 | vhacd_kwargs: dict | None = None, 28 | ) -> List[Path]: 29 | """Given a mesh, performs a convex decomposition of it with either VHACD or CoACD. 30 | The resulting convex parts are saved in a subfolder named `_parts`. 31 | 32 | Args: 33 | mesh (trimesh.Trimesh): The mesh to decompose. 34 | mesh_parts_dir_name (str): The name of the mesh parts directory. 35 | mesh_dir (Path): The path to the directory that the mesh is stored in. This is 36 | used for creating the mesh parts directory. 37 | preview_with_trimesh (bool): Whether to open (and block on) a window to preview 38 | the decomposition. 39 | use_coacd (bool): Whether to use CoACD instead of VHACD for decomposition. 40 | coacd_kwargs (dict | None): The CoACD-specific parameters. 41 | vhacd_kwargs (dict | None): The VHACD-specific parameters. 42 | 43 | Returns: 44 | List[Path]: The paths of the convex pieces. 45 | """ 46 | # Create a subdir for the convex parts 47 | out_dir = mesh_dir / mesh_parts_dir_name 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | if preview_with_trimesh: 51 | logging.info( 52 | "Showing mesh before convex decomposition. Close the window to proceed." 53 | ) 54 | mesh.show() 55 | 56 | logging.info( 57 | "Performing convex decomposition. This might take a couple of minutes for " 58 | + "complicated meshes and fine resolution settings." 59 | ) 60 | try: 61 | # Create a copy of the mesh for decomposition. 
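        # The copy keeps only vertices and faces, so the original mesh (and any
        # visual/material data) is left untouched by the decomposition backends.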
62 | mesh_copy = trimesh.Trimesh( 63 | vertices=mesh.vertices.copy(), faces=mesh.faces.copy() 64 | ) 65 | if use_coacd: 66 | coacd.set_log_level("error") 67 | coacd_mesh = coacd.Mesh(mesh_copy.vertices, mesh_copy.faces) 68 | coacd_result = coacd.run_coacd(coacd_mesh, **(coacd_kwargs or {})) 69 | # Convert CoACD result to trimesh objects. 70 | convex_pieces = [] 71 | for vertices, faces in coacd_result: 72 | piece = trimesh.Trimesh(vertices, faces) 73 | convex_pieces.append(piece) 74 | else: 75 | vhacd_settings = vhacd_kwargs or {} 76 | convex_pieces = mesh_copy.convex_decomposition(**vhacd_settings) 77 | if not isinstance(convex_pieces, list): 78 | convex_pieces = [convex_pieces] 79 | except Exception as e: 80 | logging.error(f"Problem performing decomposition: {e}") 81 | exit(1) 82 | 83 | if preview_with_trimesh: 84 | # Display the convex decomposition, giving each a random colors 85 | for part in convex_pieces: 86 | this_color = trimesh.visual.random_color() 87 | part.visual.face_colors[:] = this_color 88 | scene = trimesh.scene.scene.Scene() 89 | for part in convex_pieces: 90 | scene.add_geometry(part) 91 | 92 | logging.info( 93 | f"Showing the mesh convex decomposition into {len(convex_pieces)} parts. " 94 | + "Close the window to proceed." 95 | ) 96 | scene.show() 97 | 98 | convex_piece_paths: List[Path] = [] 99 | for i, part in enumerate(convex_pieces): 100 | piece_name = f"convex_piece_{i:03d}.obj" 101 | path = out_dir / piece_name 102 | part.export(path) 103 | convex_piece_paths.append(path) 104 | 105 | return convex_piece_paths 106 | 107 | 108 | def create_sdf( 109 | model_name: str, 110 | mesh_parts_dir_name: str, 111 | output_path: Path, 112 | visual_mesh_path: Path, 113 | collision_mesh_path: Path, 114 | mass: float, 115 | center_of_mass: np.ndarray, 116 | moment_of_inertia: np.ndarray, 117 | use_hydroelastic: bool = False, 118 | is_compliant: bool = False, 119 | hydroelastic_modulus: float | None = None, 120 | hunt_crossley_dissipation: float | None = None, 121 | mu_dynamic: float | None = None, 122 | mu_static: float | None = None, 123 | preview_with_trimesh: bool = False, 124 | use_coacd: bool = False, 125 | coacd_kwargs: dict | None = None, 126 | vhacd_kwargs: dict | None = None, 127 | ) -> None: 128 | """Performs convex decomposition of the collision mesh and adds it to the SDFormat 129 | file with all other input properties. 130 | 131 | Args: 132 | model_name (str): The name of the model. The link will be named 133 | `_body_link`. 134 | mesh_parts_dir_name (str): The name of the mesh parts directory. 135 | output_path (Path): The path to the output SDFormat file. Must end in `.sdf`. 136 | visual_mesh_path (Path): The path to the mesh that will be used as the visual 137 | geometry. 138 | collision_mesh_path (Path): The path to the mesh that will be used for convex 139 | decomposition into collision pieces. NOTE that this mesh is expected to 140 | align with the visual mesh. 141 | mass (float): The mass in kg of the mesh. 142 | center_of_mass (np.ndarray): The center of mass of the mesh, expressed in the 143 | mesh's local frame. 144 | moment_of_inertia (np.ndarray): The moment of inertia of the mesh expressed in 145 | the mesh's local frame and about the center of mass. 146 | use_hydroelastic (bool): Whether to use Hydroelastic contact by adding Drake 147 | specific tags to the SDFormat file. 148 | is_compliant (bool): Whether the SDFormat file will be used for compliant 149 | Hydroelastic simulations. 
The object will behave as rigid Hydroelastic if this 150 | is not specified. 151 | hydroelastic_modulus (float): The Hydroelastic Modulus. This is only used if 152 | `is_compliant` is True. The default value leads to low compliance. See 153 | https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html for how 154 | to pick a value. 155 | hunt_crossley_dissipation (Union[float, None]): The optional Hydroelastic 156 | Hunt-Crossley dissipation (s/m). See 157 | https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html for how 158 | to pick a value. 159 | mu_dynamic (Union[float, None]): The coefficient of dynamic friction. 160 | mu_static (Union[float, None]): The coefficient of static friction. 161 | preview_with_trimesh (bool): Whether to open (and block on) a window to preview 162 | the decomposition. 163 | use_coacd (bool): Whether to use CoACD instead of VHACD for convex decomposition. 164 | coacd_kwargs (dict | None): The CoACD-specific parameters. 165 | vhacd_kwargs (dict | None): The VHACD-specific parameters. 166 | """ 167 | # Handle string paths. 168 | visual_mesh_path = Path(visual_mesh_path) 169 | collision_mesh_path = Path(collision_mesh_path) 170 | output_path = Path(output_path) 171 | 172 | # Validate input. 173 | if not output_path.suffix == ".sdf": 174 | raise ValueError("Output path must end in `.sdf`.") 175 | if (use_coacd and vhacd_kwargs is not None) or ( 176 | not use_coacd and coacd_kwargs is not None 177 | ): 178 | raise ValueError("Cannot use both CoACD and VHACD.") 179 | 180 | # Generate the SDFormat headers 181 | root_item = ET.Element("sdf", version="1.7", nsmap={"drake": "drake.mit.edu"}) 182 | model_item = ET.SubElement(root_item, "model", name=model_name) 183 | link_item = ET.SubElement(model_item, "link", name=f"{model_name}_body_link") 184 | pose_item = ET.SubElement(link_item, "pose") 185 | pose_item.text = "0 0 0 0 0 0" 186 | 187 | # Add the physical properties 188 | inertial_item = ET.SubElement(link_item, "inertial") 189 | mass_item = ET.SubElement(inertial_item, "mass") 190 | mass_item.text = str(mass) 191 | com_item = ET.SubElement(inertial_item, "pose") 192 | com_item.text = ( 193 | f"{center_of_mass[0]:.5f} {center_of_mass[1]:.5f} {center_of_mass[2]:.5f} 0 0 0" 194 | ) 195 | inertia_item = ET.SubElement(inertial_item, "inertia") 196 | for i in range(3): 197 | for j in range(i, 3): 198 | item = ET.SubElement(inertia_item, "i" + "xyz"[i] + "xyz"[j]) 199 | item.text = f"{moment_of_inertia[i, j]:.5e}" 200 | 201 | # Add the original mesh as the visual mesh 202 | visual_mesh_path = visual_mesh_path.relative_to(output_path.parent) 203 | visual_item = ET.SubElement(link_item, "visual", name="visual") 204 | geometry_item = ET.SubElement(visual_item, "geometry") 205 | mesh_item = ET.SubElement(geometry_item, "mesh") 206 | uri_item = ET.SubElement(mesh_item, "uri") 207 | uri_item.text = visual_mesh_path.as_posix() 208 | 209 | # Compute the convex decomposition and use it as the collision geometry 210 | collision_mesh = trimesh.load( 211 | collision_mesh_path, skip_materials=True, force="mesh" 212 | ) 213 | mesh_piece_paths = perform_convex_decomposition( 214 | mesh=collision_mesh, 215 | mesh_parts_dir_name=mesh_parts_dir_name, 216 | mesh_dir=output_path.parent, 217 | preview_with_trimesh=preview_with_trimesh, 218 | use_coacd=use_coacd, 219 | coacd_kwargs=coacd_kwargs, 220 | vhacd_kwargs=vhacd_kwargs, 221 | ) 222 | for i, mesh_piece_path in enumerate(mesh_piece_paths): 223 | mesh_piece_path = 
mesh_piece_path.relative_to(output_path.parent) 224 | collision_item = ET.SubElement( 225 | link_item, "collision", name=f"collision_{i:03d}" 226 | ) 227 | geometry_item = ET.SubElement(collision_item, "geometry") 228 | mesh_item = ET.SubElement(geometry_item, "mesh") 229 | uri_item = ET.SubElement(mesh_item, "uri") 230 | uri_item.text = mesh_piece_path.as_posix() 231 | ET.SubElement(mesh_item, "{drake.mit.edu}declare_convex") 232 | 233 | if use_hydroelastic: 234 | # Add proximity properties 235 | proximity_item = ET.SubElement( 236 | collision_item, "{drake.mit.edu}proximity_properties" 237 | ) 238 | if is_compliant: 239 | ET.SubElement(proximity_item, "{drake.mit.edu}compliant_hydroelastic") 240 | hydroelastic_moulus_item = ET.SubElement( 241 | proximity_item, "{drake.mit.edu}hydroelastic_modulus" 242 | ) 243 | hydroelastic_moulus_item.text = f"{hydroelastic_modulus:.3e}" 244 | else: 245 | ET.SubElement(proximity_item, "{drake.mit.edu}rigid_hydroelastic") 246 | if hunt_crossley_dissipation is not None: 247 | hunt_crossley_dissipation_item = ET.SubElement( 248 | proximity_item, "{drake.mit.edu}hunt_crossley_dissipation" 249 | ) 250 | hunt_crossley_dissipation_item.text = f"{hunt_crossley_dissipation:.3f}" 251 | if mu_dynamic is not None: 252 | mu_dynamic_item = ET.SubElement( 253 | proximity_item, "{drake.mit.edu}mu_dynamic" 254 | ) 255 | mu_dynamic_item.text = f"{mu_dynamic:.3f}" 256 | if mu_static is not None: 257 | mu_static_item = ET.SubElement( 258 | proximity_item, "{drake.mit.edu}mu_static" 259 | ) 260 | mu_static_item.text = f"{mu_static:.3f}" 261 | 262 | logging.info(f"Writing SDF to {output_path}") 263 | ET.ElementTree(root_item).write(output_path, pretty_print=True) 264 | 265 | 266 | if __name__ == "__main__": 267 | parser = argparse.ArgumentParser( 268 | description="Create a Drake-compatible SDFormat file for a triangle mesh." 269 | ) 270 | parser.add_argument( 271 | "--mesh", 272 | type=str, 273 | required=True, 274 | help="Path to mesh file.", 275 | ) 276 | parser.add_argument( 277 | "--mass", 278 | type=float, 279 | required=True, 280 | help="The mass in kg of the object that is represented by the mesh. This is " 281 | + "used for computing the moment of inertia.", 282 | ) 283 | parser.add_argument( 284 | "--scale", 285 | type=float, 286 | default=1.0, 287 | help="Scale factor to convert the specified mesh's coordinates to meters.", 288 | ) 289 | parser.add_argument( 290 | "--compliant", 291 | action="store_true", 292 | help="Whether the SDFormat file will be used for compliant Hydroelastic " 293 | + "simulations. The object will behave as rigid Hydroelastic if this is not " 294 | + "specified.", 295 | ) 296 | parser.add_argument( 297 | "--hydroelastic_modulus", 298 | type=float, 299 | default=1.0e8, 300 | help="The Hydroelastic Modulus. This is only used if --compliant is specified. " 301 | + "The default value leads to low compliance. See " 302 | + "https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html for " 303 | + "how to pick a value.", 304 | ) 305 | parser.add_argument( 306 | "--hunt_crossley_dissipation", 307 | type=float, 308 | default=None, 309 | help="The Hydroelastic Hunt-Crossley dissipation (s/m). 
See " 310 | + "https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html for " 311 | + "how to pick a value.", 312 | ) 313 | parser.add_argument( 314 | "--mu_dynamic", 315 | type=float, 316 | default=None, 317 | help="The coefficient of dynamic friction.", 318 | ) 319 | parser.add_argument( 320 | "--mu_static", 321 | type=float, 322 | default=None, 323 | help="The coefficient of static friction.", 324 | ) 325 | parser.add_argument( 326 | "--preview", 327 | action="store_true", 328 | help="Whether to preview the decomposition.", 329 | ) 330 | parser.add_argument( 331 | "--log_level", 332 | type=str, 333 | default="INFO", 334 | choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 335 | help="Log level.", 336 | ) 337 | 338 | # Create argument groups for VHACD and CoACD. 339 | vhacd_group = parser.add_argument_group("VHACD parameters") 340 | coacd_group = parser.add_argument_group("CoACD parameters") 341 | 342 | parser.add_argument( 343 | "--use_coacd", 344 | action="store_true", 345 | help="Use CoACD instead of VHACD for convex decomposition.", 346 | ) 347 | 348 | # CoACD arguments. 349 | coacd_group.add_argument( 350 | "--threshold", 351 | type=float, 352 | help="CoACD threshold parameter for determining concavity.", 353 | ) 354 | coacd_group.add_argument( 355 | "--preprocess_resolution", 356 | type=int, 357 | help="Resolution used in preprocessing step.", 358 | ) 359 | coacd_group.add_argument( 360 | "--coacd_resolution", 361 | type=int, 362 | help="Main resolution parameter for decomposition.", 363 | ) 364 | coacd_group.add_argument( 365 | "--mcts_nodes", 366 | type=int, 367 | help="Number of nodes for Monte Carlo Tree Search.", 368 | ) 369 | coacd_group.add_argument( 370 | "--mcts_iterations", 371 | type=int, 372 | help="Number of iterations for Monte Carlo Tree Search.", 373 | ) 374 | coacd_group.add_argument( 375 | "--mcts_max_depth", 376 | type=int, 377 | help="Maximum depth for Monte Carlo Tree Search.", 378 | ) 379 | coacd_group.add_argument( 380 | "--preprocess_mode", 381 | type=str, 382 | default="auto", 383 | choices=["auto", "voxel", "sampling"], 384 | help="CoACD preprocess mode.", 385 | ) 386 | coacd_group.add_argument( 387 | "--pca", action="store_true", help="Enable PCA pre-processing." 388 | ) 389 | 390 | # VHACD arguments. 
391 | vhacd_group.add_argument( 392 | "--vhacd_resolution", 393 | type=int, 394 | default=10000000, 395 | help="VHACD voxel resolution.", 396 | ) 397 | vhacd_group.add_argument( 398 | "--maxConvexHulls", 399 | type=int, 400 | default=64, 401 | help="VHACD maximum number of convex hulls/ mesh pieces.", 402 | ) 403 | vhacd_group.add_argument( 404 | "--minimumVolumePercentErrorAllowed", 405 | type=float, 406 | default=1.0, 407 | help="VHACD minimum allowed volume percentage error.", 408 | ) 409 | vhacd_group.add_argument( 410 | "--maxRecursionDepth", 411 | type=int, 412 | default=10, 413 | help="VHACD maximum recursion depth.", 414 | ) 415 | vhacd_group.add_argument( 416 | "--no_shrinkWrap", 417 | action="store_true", 418 | help="Whether or not to shrinkwrap the voxel positions to the source mesh on " 419 | + "output.", 420 | ) 421 | vhacd_group.add_argument( 422 | "--fillMode", 423 | type=str, 424 | default="flood", 425 | choices=["flood", "raycast", "surface"], 426 | help="VHACD maximum recursion depth.", 427 | ) 428 | vhacd_group.add_argument( 429 | "--maxNumVerticesPerCH", 430 | type=int, 431 | default=64, 432 | help="VHACD maximum number of triangles per convex hull.", 433 | ) 434 | vhacd_group.add_argument( 435 | "--no_asyncACD", 436 | action="store_true", 437 | help="Whether or not to run VHACD asynchronously, taking advantage of " 438 | + "additional cores.", 439 | ) 440 | vhacd_group.add_argument( 441 | "--minEdgeLength", 442 | type=int, 443 | default=2, 444 | help="VHACD minimum voxel patch edge length.", 445 | ) 446 | 447 | args = parser.parse_args() 448 | logging.basicConfig(level=args.log_level) 449 | 450 | # Separate VHACD and CoACD parameters. 451 | vhacd_params = ( 452 | { 453 | "resolution": args.vhacd_resolution, 454 | "maxConvexHulls": args.maxConvexHulls, 455 | "minimumVolumePercentErrorAllowed": args.minimumVolumePercentErrorAllowed, 456 | "maxRecursionDepth": args.maxRecursionDepth, 457 | "shrinkWrap": not args.no_shrinkWrap, 458 | "fillMode": args.fillMode, 459 | "maxNumVerticesPerCH": args.maxNumVerticesPerCH, 460 | "asyncACD": not args.no_asyncACD, 461 | "minEdgeLength": args.minEdgeLength, 462 | } 463 | if not args.use_coacd 464 | else None 465 | ) 466 | coacd_params = {} 467 | for param in [ 468 | "threshold", 469 | "preprocess_resolution", 470 | "coacd_resolution", 471 | "mcts_nodes", 472 | "mcts_iterations", 473 | "mcts_max_depth", 474 | "preprocess_mode", 475 | ]: 476 | value = getattr(args, param) 477 | if value is not None: 478 | key = "resolution" if param == "coacd_resolution" else param 479 | coacd_params[key] = value 480 | 481 | # create_sdf_from_mesh( 482 | # mesh_path=mesh_path, 483 | # mass=mass, 484 | # scale=args.scale, 485 | # is_compliant=is_compliant, 486 | # hydroelastic_modulus=hydroelastic_modulus, 487 | # hunt_crossley_dissipation=hunt_crossley_dissipation, 488 | # mu_dynamic=mu_dynamic, 489 | # mu_static=mu_static, 490 | # preview_with_trimesh=args.preview, 491 | # use_coacd=args.use_coacd, 492 | # coacd_kwargs=coacd_params, 493 | # vhacd_kwargs=vhacd_params, 494 | # ) 495 | -------------------------------------------------------------------------------- /scalable_real2sim/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nepfaff/scalable-real2sim/edee99c65583c636ea401c64dfa8883e8e4dae5d/scalable_real2sim/segmentation/__init__.py -------------------------------------------------------------------------------- 
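A minimal usage sketch for `create_sdf` from `sdformat.py` above, assuming the module is importable as `scalable_real2sim.output.sdformat`. All paths, the model name, the mass value, and the CoACD threshold are illustrative placeholders, and the inertial quantities are derived with trimesh under a uniform-density assumption:

```python
from pathlib import Path

import trimesh

from scalable_real2sim.output.sdformat import create_sdf

# Placeholder paths/values; substitute the outputs of the reconstruction pipeline.
asset_dir = Path("assets/example_object")
visual_mesh_path = asset_dir / "visual.obj"
collision_mesh_path = asset_dir / "collision.obj"
mass = 0.5  # kg, e.g. from the payload identification stage.

# Uniform-density inertial parameters computed from the (watertight) collision mesh.
mesh = trimesh.load(collision_mesh_path, force="mesh")
mesh.density = mass / mesh.volume

create_sdf(
    model_name="example_object",
    mesh_parts_dir_name="example_object_parts",
    output_path=asset_dir / "example_object.sdf",
    visual_mesh_path=visual_mesh_path,
    collision_mesh_path=collision_mesh_path,
    mass=mass,
    center_of_mass=mesh.center_mass,
    moment_of_inertia=mesh.moment_inertia,
    use_hydroelastic=True,
    is_compliant=False,
    mu_dynamic=0.5,
    mu_static=0.6,
    use_coacd=True,
    coacd_kwargs={"threshold": 0.05},
)
```

In the full pipeline, the mass and inertia would typically come from the physical-parameter identification stage rather than a uniform-density guess.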
/scalable_real2sim/segmentation/detect_object.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from PIL import Image 4 | from transformers import AutoProcessor, LlavaForConditionalGeneration 5 | 6 | 7 | def detect_object(image_path: str) -> str: 8 | # Load LLaVA model and processor 9 | processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True) 10 | model = LlavaForConditionalGeneration.from_pretrained( 11 | "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16, device_map="auto" 12 | ) 13 | 14 | # Load and process image 15 | raw_image = Image.open(image_path).convert("RGB") 16 | 17 | # Define conversation using the correct format 18 | conversation = [ 19 | { 20 | "role": "user", 21 | "content": [ 22 | { 23 | "type": "text", 24 | "text": "What is this object? Example: 'A red tomatoe soup can'.", 25 | }, 26 | {"type": "image"}, 27 | ], 28 | } 29 | ] 30 | prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) 31 | 32 | # Process inputs 33 | inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to( 34 | "cuda", torch.float16 35 | ) 36 | 37 | # Generate response 38 | outputs = model.generate( 39 | **inputs, 40 | max_new_tokens=30, 41 | num_beams=3, 42 | ) 43 | response = processor.decode(outputs[0], skip_special_tokens=True) 44 | 45 | # Clean up the response to get only the object name 46 | response = response.split("ASSISTANT:")[-1].strip() 47 | if "is" in response.lower(): # Remove any "is" or "is a" or similar phrases 48 | response = response.split("is")[-1].strip() 49 | response = ( 50 | response.strip('." ').lstrip("a ").lstrip("an ") 51 | ) # Remove articles and punctuation 52 | 53 | return response 54 | -------------------------------------------------------------------------------- /scalable_real2sim/segmentation/segment_moving_object_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from 3 | https://github.com/IDEA-Research/Grounded-SAM-2/blob/main/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py 4 | """ 5 | 6 | import glob 7 | import logging 8 | import os 9 | import shutil 10 | 11 | from itertools import chain 12 | 13 | import cv2 14 | import numpy as np 15 | import torch 16 | 17 | from PIL import Image 18 | from sam2.build_sam import build_sam2, build_sam2_video_predictor 19 | from sam2.sam2_image_predictor import SAM2ImagePredictor 20 | from tqdm import tqdm 21 | from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor 22 | 23 | """ 24 | Hyperparameters for Grounding and Tracking 25 | """ 26 | PROMPT_TYPE_FOR_VIDEO = "point" # Choose from ["point", "box", "mask"] 27 | OFFLOAD_VIDEO_TO_CPU = True # Prevents OOM for large videos but is slower. 
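# Offloading the inference state (next flag) trades additional speed for lower GPU memory use.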
28 | OFFLOAD_STATE_TO_CPU = True 29 | 30 | 31 | def convert_png_to_jpg(input_folder, output_folder): 32 | # Create the output folder if it doesn't exist 33 | os.makedirs(output_folder, exist_ok=True) 34 | 35 | # Loop through all files in the input folder 36 | for filename in os.listdir(input_folder): 37 | if filename.endswith(".png"): 38 | # Full path to the input image 39 | input_path = os.path.join(input_folder, filename) 40 | 41 | # Open the image using PIL 42 | with Image.open(input_path) as img: 43 | # Convert the image to RGB mode (JPG doesn't support transparency) 44 | img = img.convert("RGB") 45 | 46 | # Generate the output filename with .jpg extension 47 | output_filename = os.path.splitext(filename)[0] + ".jpg" 48 | 49 | # Full path to the output image 50 | output_path = os.path.join(output_folder, output_filename) 51 | 52 | # Save the image in JPG format 53 | img.save(output_path, "JPEG") 54 | 55 | 56 | def sample_points_from_masks(masks, num_points): 57 | """ 58 | Sample points from masks and return their absolute coordinates. 59 | 60 | Args: 61 | masks: np.array with shape (n, h, w) 62 | num_points: int 63 | 64 | Returns: 65 | points: np.array with shape (n, points, 2) 66 | """ 67 | n, h, w = masks.shape 68 | points = [] 69 | 70 | for i in range(n): 71 | # Find the valid mask points 72 | indices = np.argwhere(masks[i] == 1) 73 | # Convert from (y, x) to (x, y) 74 | indices = indices[:, ::-1] 75 | 76 | if len(indices) == 0: 77 | # If there are no valid points, append an empty array 78 | points.append(np.array([])) 79 | continue 80 | 81 | # Resampling if there's not enough points 82 | if len(indices) < num_points: 83 | sampled_indices = np.random.choice(len(indices), num_points, replace=True) 84 | else: 85 | sampled_indices = np.random.choice(len(indices), num_points, replace=False) 86 | 87 | sampled_points = indices[sampled_indices] 88 | points.append(sampled_points) 89 | 90 | # Convert to np.array 91 | points = np.array(points, dtype=np.float32) 92 | return points 93 | 94 | 95 | def segment_moving_obj_data( 96 | rgb_dir: str, 97 | output_dir: str, 98 | txt_prompt: str | None = None, 99 | txt_prompt_index: int = 0, 100 | neg_txt_prompt: str | None = None, 101 | num_neg_frames: int = 10, 102 | debug_dir: str | None = None, 103 | gui_frames: list[str] | None = None, 104 | ): 105 | # Ensure mutual exclusivity between GUI and text prompts 106 | if gui_frames is not None: 107 | if txt_prompt is not None or neg_txt_prompt is not None: 108 | raise ValueError("Cannot use both GUI frames and text prompts.") 109 | else: 110 | if txt_prompt is None: 111 | raise ValueError( 112 | "Text prompt must be provided if GUI frames are not specified." 113 | ) 114 | 115 | sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" 116 | model_cfg = "sam2_hiera_l.yaml" 117 | 118 | # Download checkpoint if not exist. 
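    # Uses wget, so it must be available on PATH; the file is cached in ./checkpoints/
    # and reused on subsequent runs.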
119 | if not os.path.exists(sam2_checkpoint): 120 | logging.info("Downloading sam2_hiera_large.pt checkpoint...") 121 | BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" 122 | sam2_hiera_l_url = f"{BASE_URL}sam2_hiera_large.pt" 123 | status = os.system(f"wget {sam2_hiera_l_url} -P ./checkpoints/") 124 | if status != 0: 125 | raise RuntimeError("Failed to download the checkpoint.") 126 | 127 | """ 128 | Step 1: Environment settings and model initialization for SAM 2 129 | """ 130 | # Use bfloat16 for the entire script 131 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 132 | 133 | if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8: 134 | # Turn on tfloat32 for Ampere GPUs 135 | torch.backends.cuda.matmul.allow_tf32 = True 136 | torch.backends.cudnn.allow_tf32 = True 137 | 138 | # Initialize SAM image predictor and video predictor models 139 | video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint) 140 | sam2_image_model = build_sam2(model_cfg, sam2_checkpoint) 141 | image_predictor = SAM2ImagePredictor(sam2_image_model) 142 | 143 | # Build Grounding DINO from Hugging Face (used only if not using GUI) 144 | if gui_frames is None: 145 | model_id = "IDEA-Research/grounding-dino-tiny" 146 | device = "cuda" if torch.cuda.is_available() else "cpu" 147 | processor = AutoProcessor.from_pretrained(model_id) 148 | grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained( 149 | model_id 150 | ).to(device) 151 | 152 | # Convert PNG to JPG as required for video predictor. 153 | jpg_dir = os.path.join(rgb_dir, "jpg") 154 | 155 | if gui_frames is not None: 156 | # Find minimum frame number from gui_frames 157 | min_frame_num = min(int(frame) for frame in gui_frames) 158 | 159 | # Only convert PNGs with numbers >= min_frame_num 160 | if os.path.exists(jpg_dir): 161 | shutil.rmtree(jpg_dir) 162 | os.makedirs(jpg_dir) 163 | for filename in os.listdir(rgb_dir): 164 | if filename.endswith(".png"): 165 | frame_num = int(os.path.splitext(filename)[0]) 166 | if frame_num >= min_frame_num: 167 | input_path = os.path.join(rgb_dir, filename) 168 | with Image.open(input_path) as img: 169 | img = img.convert("RGB") 170 | output_filename = os.path.splitext(filename)[0] + ".jpg" 171 | output_path = os.path.join(jpg_dir, output_filename) 172 | img.save(output_path, "JPEG") 173 | else: 174 | # If not using GUI, convert all PNGs 175 | convert_png_to_jpg(rgb_dir, jpg_dir) 176 | 177 | # Scan all the JPEG frame names in this directory 178 | frame_names = [ 179 | p 180 | for p in os.listdir(jpg_dir) 181 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 182 | ] 183 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 184 | frame_count = len(frame_names) 185 | 186 | # Create a mapping from image names to frame indices 187 | frame_name_to_idx = { 188 | os.path.splitext(p)[0]: idx for idx, p in enumerate(frame_names) 189 | } 190 | 191 | # Initialize video predictor state 192 | inference_state = video_predictor.init_state( 193 | video_path=jpg_dir, 194 | offload_video_to_cpu=OFFLOAD_VIDEO_TO_CPU, 195 | offload_state_to_cpu=OFFLOAD_STATE_TO_CPU, 196 | ) 197 | 198 | # Clear the debug directory before writing 199 | if debug_dir is not None: 200 | if not os.path.exists(debug_dir): 201 | os.makedirs(debug_dir) 202 | else: 203 | # Clear the debug directory 204 | for item in os.listdir(debug_dir): 205 | item_path = os.path.join(debug_dir, item) 206 | try: 207 | if os.path.isfile(item_path) or 
os.path.islink(item_path): 208 | os.remove(item_path) # Remove file or symlink 209 | elif os.path.isdir(item_path): 210 | shutil.rmtree(item_path) # Remove directory 211 | except Exception as e: 212 | logging.error(f"Error deleting {item_path}: {e}") 213 | 214 | """ 215 | Step 2: Prompt Grounding DINO for box coordinates or collect points via GUI 216 | """ 217 | 218 | object_id_counter = 1 # Start object IDs from 1 219 | txt_prompt_counter = 0 220 | neg_txt_prompt_counter = 0 221 | 222 | if gui_frames is not None: 223 | for frame_name in gui_frames: 224 | if frame_name not in frame_name_to_idx: 225 | raise ValueError(f"Frame {frame_name} not found in the RGB directory.") 226 | frame_idx = frame_name_to_idx[frame_name] 227 | 228 | # Read the frame 229 | img_path = os.path.join(jpg_dir, frame_names[frame_idx]) 230 | image = cv2.imread(img_path) 231 | image_display = image.copy() 232 | 233 | positive_points = [] 234 | negative_points = [] 235 | 236 | def mouse_callback(event, x, y, flags, param): 237 | image_display = param["image"].copy() # Get fresh copy of image 238 | scale = param["scale"] 239 | orig_x = int(x / scale) 240 | orig_y = int(y / scale) 241 | 242 | # Draw all existing points 243 | for px, py in positive_points: 244 | px, py = int(px * scale), int(py * scale) 245 | cv2.circle( 246 | image_display, 247 | (px, py), 248 | radius=5, 249 | color=(0, 255, 0), 250 | thickness=-1, 251 | ) 252 | for px, py in negative_points: 253 | px, py = int(px * scale), int(py * scale) 254 | cv2.circle( 255 | image_display, 256 | (px, py), 257 | radius=5, 258 | color=(0, 0, 255), 259 | thickness=-1, 260 | ) 261 | 262 | if event == cv2.EVENT_LBUTTONDOWN: 263 | positive_points.append([orig_x, orig_y]) 264 | cv2.circle( 265 | image_display, (x, y), radius=5, color=(0, 255, 0), thickness=-1 266 | ) 267 | elif event == cv2.EVENT_RBUTTONDOWN: 268 | negative_points.append([orig_x, orig_y]) 269 | cv2.circle( 270 | image_display, (x, y), radius=5, color=(0, 0, 255), thickness=-1 271 | ) 272 | 273 | # Add instructions text 274 | instructions = [ 275 | "Left click: Add positive point (green)", 276 | "Right click: Add negative point (red)", 277 | "u: Undo last point", 278 | "r: Reset all points", 279 | "q: Finish frame", 280 | ] 281 | y_offset = 30 282 | for inst in instructions: 283 | cv2.putText( 284 | image_display, 285 | inst, 286 | (10, y_offset), 287 | cv2.FONT_HERSHEY_SIMPLEX, 288 | 0.6, 289 | (255, 255, 255), 290 | 2, 291 | ) 292 | y_offset += 25 293 | 294 | cv2.imshow("Frame", image_display) 295 | 296 | # Use a default screen height (80% of 1080p) 297 | target_height = int(1080 * 0.8) 298 | 299 | # Calculate scaling factor to fit image to target height 300 | scale = min(1.0, target_height / image.shape[0]) 301 | 302 | # Resize image for display 303 | image_display = cv2.resize(image.copy(), None, fx=scale, fy=scale) 304 | 305 | cv2.namedWindow("Frame") 306 | cv2.setMouseCallback( 307 | "Frame", mouse_callback, {"image": image, "scale": scale} 308 | ) 309 | cv2.imshow("Frame", image_display) 310 | 311 | logging.info( 312 | f"Annotating frame {frame_name}. Left click to add positive points, " 313 | "right click to add negative points. Press 'q' to finish." 
314 | ) 315 | 316 | while True: 317 | key = cv2.waitKey(1) & 0xFF 318 | if key == ord("q"): 319 | break 320 | elif key == ord("u"): # Undo last point 321 | if len(positive_points) + len(negative_points) > 0: 322 | if len(negative_points) > 0: 323 | negative_points.pop() 324 | else: 325 | positive_points.pop() 326 | mouse_callback( 327 | None, 0, 0, None, {"image": image, "scale": scale} 328 | ) 329 | elif key == ord("r"): # Reset points 330 | positive_points.clear() 331 | negative_points.clear() 332 | mouse_callback(None, 0, 0, None, {"image": image, "scale": scale}) 333 | 334 | cv2.destroyAllWindows() 335 | 336 | # Convert points to numpy arrays 337 | positive_points = np.array(positive_points, dtype=np.float32) 338 | negative_points = np.array(negative_points, dtype=np.float32) 339 | 340 | # Combine points and labels 341 | if len(positive_points) > 0 and len(negative_points) > 0: 342 | point_coords = np.concatenate( 343 | [positive_points, negative_points], axis=0 344 | ) 345 | point_labels = np.concatenate( 346 | [np.ones(len(positive_points)), np.zeros(len(negative_points))] 347 | ) 348 | elif len(positive_points) > 0: 349 | point_coords = positive_points 350 | point_labels = np.ones(len(positive_points)) 351 | elif len(negative_points) > 0: 352 | point_coords = negative_points 353 | point_labels = np.zeros(len(negative_points)) 354 | else: 355 | raise ValueError(f"No points provided for frame '{frame_name}'.") 356 | 357 | # Set image for image predictor 358 | image_predictor.set_image(image) 359 | 360 | # Predict mask using the points 361 | masks, scores, logits = image_predictor.predict( 362 | point_coords=point_coords, 363 | point_labels=point_labels, 364 | box=None, 365 | multimask_output=False, 366 | ) 367 | # Convert the mask shape to (n, H, W) 368 | if masks.ndim == 4: 369 | masks = masks.squeeze(1) 370 | 371 | # Process the detection results 372 | OBJECTS = [object_id_counter] # Assign unique object ID 373 | 374 | # Save debug images with points 375 | if debug_dir is not None: 376 | img_path = os.path.join(jpg_dir, frame_names[frame_idx]) 377 | image_debug = cv2.imread(img_path) 378 | # Draw the query points 379 | for point in positive_points: 380 | x, y = point.astype(int) 381 | cv2.circle( 382 | image_debug, (x, y), radius=5, color=(0, 255, 0), thickness=-1 383 | ) 384 | for point in negative_points: 385 | x, y = point.astype(int) 386 | cv2.circle( 387 | image_debug, (x, y), radius=5, color=(0, 0, 255), thickness=-1 388 | ) 389 | # Save the image 390 | save_name = f"gui_prompt_{frame_name}.jpg" 391 | save_path = os.path.join(debug_dir, save_name) 392 | cv2.imwrite(save_path, image_debug) 393 | 394 | """ 395 | Step 3: Register each object's positive points to video predictor with separate add_new_points call 396 | """ 397 | 398 | if PROMPT_TYPE_FOR_VIDEO == "point": 399 | for obj_id in OBJECTS: 400 | labels = point_labels.astype(np.int32) 401 | points = point_coords 402 | ( 403 | _, 404 | out_obj_ids, 405 | out_mask_logits, 406 | ) = video_predictor.add_new_points_or_box( 407 | inference_state=inference_state, 408 | frame_idx=frame_idx, 409 | obj_id=obj_id, 410 | points=points, 411 | labels=labels, 412 | ) 413 | else: 414 | raise NotImplementedError( 415 | "For GUI input, only point prompts are supported." 
416 | ) 417 | 418 | object_id_counter += 1 # Increment object ID for next object 419 | 420 | else: 421 | """ 422 | Step 2: Prompt Grounding DINO for box coordinates 423 | """ 424 | 425 | # Function to get DINO boxes 426 | def get_dino_boxes(text, frame_idx): 427 | img_path = os.path.join(jpg_dir, frame_names[frame_idx]) 428 | image = Image.open(img_path) 429 | 430 | inputs = processor(images=image, text=text, return_tensors="pt").to(device) 431 | with torch.no_grad(): 432 | outputs = grounding_model(**inputs) 433 | 434 | results = processor.post_process_grounded_object_detection( 435 | outputs, 436 | inputs.input_ids, 437 | box_threshold=0.4, 438 | text_threshold=0.3, 439 | target_sizes=[image.size[::-1]], 440 | ) 441 | 442 | input_boxes = results[0]["boxes"].cpu().numpy() 443 | confidences = results[0]["scores"].cpu().numpy().tolist() 444 | class_names = results[0]["labels"] 445 | return input_boxes, confidences, class_names 446 | 447 | input_boxes, confidences, class_names = get_dino_boxes( 448 | txt_prompt, txt_prompt_index 449 | ) 450 | 451 | assert ( 452 | len(input_boxes) > 0 453 | ), "No results found for the text prompt. Make sure that the prompt ends with a dot '.'!" 454 | 455 | # Prompt SAM image predictor to get the mask for the object 456 | img_path = os.path.join(jpg_dir, frame_names[txt_prompt_index]) 457 | image = Image.open(img_path) 458 | image_predictor.set_image(np.array(image.convert("RGB"))) 459 | 460 | # Process the detection results 461 | OBJECTS = class_names 462 | 463 | # Prompt SAM 2 image predictor to get the mask for the object 464 | masks, scores, logits = image_predictor.predict( 465 | point_coords=None, 466 | point_labels=None, 467 | box=input_boxes, 468 | multimask_output=False, 469 | ) 470 | # Convert the mask shape to (n, H, W) 471 | if masks.ndim == 4: 472 | masks = masks.squeeze(1) 473 | 474 | """ 475 | Step 3: Register each object's positive points to video predictor with separate add_new_points call 476 | """ 477 | 478 | if PROMPT_TYPE_FOR_VIDEO == "point": 479 | # sample the positive points from mask for each object 480 | all_sample_points = sample_points_from_masks(masks=masks, num_points=10) 481 | 482 | # Save debug images with bounding boxes and query points for txt_prompt 483 | if debug_dir is not None: 484 | img_path = os.path.join(jpg_dir, frame_names[txt_prompt_index]) 485 | image = cv2.imread(img_path) 486 | # Draw the boxes 487 | for box in input_boxes: 488 | x_min, y_min, x_max, y_max = box.astype(int) 489 | cv2.rectangle( 490 | image, 491 | (x_min, y_min), 492 | (x_max, y_max), 493 | color=(0, 255, 0), 494 | thickness=2, 495 | ) 496 | # Draw the query points 497 | for points in all_sample_points: 498 | for point in points: 499 | x, y = point.astype(int) 500 | cv2.circle( 501 | image, (x, y), radius=5, color=(0, 255, 0), thickness=-1 502 | ) 503 | # Save the image 504 | save_name = f"txt_prompt_{txt_prompt_counter}.jpg" 505 | save_path = os.path.join(debug_dir, save_name) 506 | cv2.imwrite(save_path, image) 507 | txt_prompt_counter += 1 508 | 509 | for object_id, (label, points) in enumerate( 510 | zip(OBJECTS, all_sample_points), start=object_id_counter 511 | ): 512 | # label one means positive (do mask), label zero means negative (don't mask) 513 | labels = np.ones((points.shape[0]), dtype=np.int32) 514 | _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box( 515 | inference_state=inference_state, 516 | frame_idx=txt_prompt_index, 517 | obj_id=object_id, 518 | points=points, 519 | labels=labels, 520 | ) 521 | 
object_id_counter += len(OBJECTS) 522 | 523 | # Handle negative prompts if provided 524 | if neg_txt_prompt is not None: 525 | image_predictor.reset_predictor() 526 | 527 | neg_id_start_orig = neg_id_start = object_id_counter 528 | for idx in tqdm( 529 | np.linspace(0, frame_count - 1, num_neg_frames, dtype=int), 530 | desc="Adding negative", 531 | leave=False, 532 | ): 533 | neg_input_boxes, _, neg_class_names = get_dino_boxes( 534 | neg_txt_prompt, idx 535 | ) 536 | if len(neg_input_boxes) == 0: 537 | continue 538 | 539 | img_path = os.path.join(jpg_dir, frame_names[idx]) 540 | image = Image.open(img_path) 541 | image_predictor.set_image(np.array(image.convert("RGB"))) 542 | 543 | # prompt SAM image predictor to get the mask for the negative object 544 | neg_masks, _, _ = image_predictor.predict( 545 | point_coords=None, 546 | point_labels=None, 547 | box=neg_input_boxes, 548 | multimask_output=False, 549 | ) 550 | # convert the mask shape to (n, H, W) 551 | if neg_masks.ndim == 4: 552 | neg_masks = neg_masks.squeeze(1) 553 | 554 | if PROMPT_TYPE_FOR_VIDEO == "point": 555 | # sample the negative points from mask for each object 556 | num_points = 1 557 | neg_all_sample_points = sample_points_from_masks( 558 | masks=neg_masks, num_points=num_points 559 | ) 560 | 561 | # Save debug images with bounding boxes and query points for neg_txt_prompt 562 | if debug_dir is not None: 563 | img_path_cv = os.path.join(jpg_dir, frame_names[idx]) 564 | image_cv = cv2.imread(img_path_cv) 565 | # Draw the boxes 566 | for box in neg_input_boxes: 567 | x_min, y_min, x_max, y_max = box.astype(int) 568 | cv2.rectangle( 569 | image_cv, 570 | (x_min, y_min), 571 | (x_max, y_max), 572 | color=(0, 0, 255), 573 | thickness=2, 574 | ) 575 | # Draw the query points 576 | for points in neg_all_sample_points: 577 | for point in points: 578 | x, y = point.astype(int) 579 | cv2.circle( 580 | image_cv, 581 | (x, y), 582 | radius=5, 583 | color=(0, 0, 255), 584 | thickness=-1, 585 | ) 586 | # Save the image 587 | save_name = f"neg_txt_prompt_{neg_txt_prompt_counter}.jpg" 588 | save_path = os.path.join(debug_dir, save_name) 589 | cv2.imwrite(save_path, image_cv) 590 | neg_txt_prompt_counter += 1 591 | 592 | for object_id, (label, points) in enumerate( 593 | zip(neg_class_names, neg_all_sample_points), start=neg_id_start 594 | ): 595 | # label zero means negative (don't mask) 596 | labels = np.zeros((points.shape[0]), dtype=np.int32) 597 | _, _, _ = video_predictor.add_new_points_or_box( 598 | inference_state=inference_state, 599 | frame_idx=txt_prompt_index, 600 | obj_id=object_id, 601 | points=points, 602 | labels=labels, 603 | ) 604 | neg_id_start += len(neg_class_names) 605 | # Handle 'box' and 'mask' prompts similarly 606 | object_id_counter = neg_id_start 607 | 608 | # Clear GPU memory. 
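    # Move the image predictor off the GPU and drop it before video propagation so
    # that it does not compete with the video predictor for memory.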
609 | image_predictor.model.cpu() 610 | del image_predictor 611 | 612 | """ 613 | Step 4: Propagate the video predictor to get the segmentation results for each frame 614 | """ 615 | video_segments = {} # Contains the per-frame segmentation results 616 | for ( 617 | out_frame_idx, 618 | out_obj_ids, 619 | out_mask_logits, 620 | ) in video_predictor.propagate_in_video(inference_state): 621 | video_segments[out_frame_idx] = { 622 | out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() 623 | for i, out_obj_id in enumerate(out_obj_ids) 624 | } 625 | 626 | """ 627 | Step 5: Save masks and overlay images 628 | """ 629 | if not os.path.exists(output_dir): 630 | os.makedirs(output_dir) 631 | # else: 632 | # png_files = glob.glob(os.path.join(output_dir, "*.png")) 633 | # jpg_files = glob.glob(os.path.join(output_dir, "*.jpg")) 634 | 635 | # Loop through the list of .png and .jpg files and delete them. 636 | # for file in chain(png_files, jpg_files): 637 | # try: 638 | # os.remove(file) 639 | # except Exception as e: 640 | # logging.error(f"Error deleting {file}: {e}") 641 | 642 | # Create 'mask_overlayed' subfolder inside debug_dir 643 | if debug_dir is not None: 644 | overlay_dir = os.path.join(debug_dir, "mask_overlayed") 645 | if not os.path.exists(overlay_dir): 646 | os.makedirs(overlay_dir) 647 | else: 648 | # Clear the overlay directory 649 | files = glob.glob(os.path.join(overlay_dir, "*")) 650 | for f in files: 651 | try: 652 | os.remove(f) 653 | except Exception as e: 654 | logging.error(f"Error deleting {f}: {e}") 655 | 656 | for frame_idx, segments in video_segments.items(): 657 | if gui_frames is None and neg_txt_prompt is not None: 658 | pos_segments = {k: v for k, v in segments.items() if k < neg_id_start_orig} 659 | else: 660 | pos_segments = segments 661 | 662 | if len(pos_segments) == 0: 663 | continue # Skip if there are no positive segments 664 | 665 | masks = list(pos_segments.values()) 666 | masks = np.concatenate(masks, axis=0) 667 | 668 | # Save masks 669 | union_mask = np.any(masks, axis=0) 670 | union_mask_8bit = (union_mask.astype(np.uint8)) * 255 671 | 672 | # Get the corresponding image name and change the extension to .png 673 | image_name = frame_names[frame_idx] 674 | mask_name = os.path.splitext(image_name)[0] + ".png" 675 | 676 | cv2.imwrite(os.path.join(output_dir, mask_name), union_mask_8bit) 677 | 678 | # If debug_dir is specified, save the overlaid image 679 | if debug_dir is not None: 680 | # Read the original image 681 | img_path = os.path.join(jpg_dir, image_name) 682 | image = cv2.imread(img_path) 683 | if image is None: 684 | logging.warning( 685 | f"Could not read image at {img_path}. Skipping overlay." 
686 | ) 687 | continue 688 | 689 | # Create a color mask (reddish color) 690 | color_mask = np.zeros_like(image) 691 | color_mask[:, :, 2] = 150 # Red channel intensity 692 | color_mask[:, :, 1] = 0 # Green channel 693 | color_mask[:, :, 0] = 0 # Blue channel 694 | 695 | # Apply the mask to the color mask 696 | mask = union_mask.astype(bool) 697 | alpha = 0.5 # Transparency factor 698 | 699 | if mask.any(): 700 | try: 701 | # Blend the original image and the color mask where mask is True 702 | image[mask] = cv2.addWeighted( 703 | image[mask], 1 - alpha, color_mask[mask], alpha, 0 704 | ) 705 | except Exception as e: 706 | logging.error(f"Error blending mask on image '{image_name}': {e}") 707 | continue 708 | else: 709 | logging.warning(f"No mask to overlay on image '{image_name}'.") 710 | continue 711 | 712 | # Save the overlaid image 713 | overlay_name = os.path.splitext(image_name)[0] + ".jpg" 714 | cv2.imwrite(os.path.join(overlay_dir, overlay_name), image) 715 | 716 | # Delete the jpg folder. 717 | try: 718 | shutil.rmtree(jpg_dir) 719 | except Exception as e: 720 | logging.error(f"Error deleting {jpg_dir}: {e}") 721 | 722 | logging.info("Done segmenting moving object data.") 723 | -------------------------------------------------------------------------------- /scripts/compute_geometric_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for computing geometric error metrics between a GT and reconstructed visual 3 | mesh. 4 | 5 | The meshes are canonicalized and aligned with ICP before computing the metrics. 6 | """ 7 | 8 | import argparse 9 | 10 | import numpy as np 11 | import open3d as o3d 12 | 13 | from scipy.spatial import KDTree 14 | 15 | from scalable_real2sim.output.canonicalize import canonicalize_mesh 16 | 17 | 18 | def compute_chamfer_distance(points1: np.ndarray, points2: np.ndarray) -> float: 19 | """ 20 | Compute the chamfer distance between two point clouds. 21 | """ 22 | # Build kd-trees for fast nearest neighbor queries. 23 | tree1 = KDTree(points1) 24 | tree2 = KDTree(points2) 25 | 26 | # Find nearest neighbors in both directions. 27 | dist1, _ = tree1.query(points2) 28 | dist2, _ = tree2.query(points1) 29 | 30 | # Compute mean and max chamfer distances. 31 | distance = (np.mean(dist1) + np.mean(dist2)) / 2.0 32 | 33 | return distance 34 | 35 | 36 | def sample_points_with_curvature( 37 | mesh: o3d.geometry.TriangleMesh, num_points: int 38 | ) -> np.ndarray: 39 | """ 40 | Sample points from the mesh surface, weighted by curvature. 41 | """ 42 | # Compute vertex normals and curvature. 43 | mesh.compute_vertex_normals() 44 | mesh.compute_triangle_normals() 45 | curvatures = np.linalg.norm(np.asarray(mesh.vertex_normals), axis=1) 46 | 47 | # Normalize curvatures to use as probabilities. 48 | probabilities = curvatures / np.sum(curvatures) 49 | 50 | # Weighted random sampling of vertices. 51 | vertices = np.asarray(mesh.vertices) 52 | sampled_indices = np.random.choice(len(vertices), size=num_points, p=probabilities) 53 | sampled_points = vertices[sampled_indices] 54 | 55 | return sampled_points 56 | 57 | 58 | def sample_points_from_mesh( 59 | mesh: o3d.geometry.TriangleMesh, num_points: int, use_curvature: bool = False 60 | ) -> np.ndarray: 61 | """ 62 | Sample points from a mesh using either uniform or curvature-based sampling. 
63 | """ 64 | if use_curvature: 65 | points = sample_points_with_curvature(mesh, num_points) 66 | else: 67 | pcd = mesh.sample_points_uniformly(number_of_points=num_points) 68 | points = np.asarray(pcd.points) 69 | return points 70 | 71 | 72 | def prepare_mesh(mesh_path: str) -> o3d.geometry.TriangleMesh: 73 | """ 74 | Load and canonicalize a mesh. 75 | """ 76 | # Load and get main cluster. 77 | mesh = o3d.io.read_triangle_mesh(mesh_path) 78 | 79 | mesh = canonicalize_mesh(mesh) 80 | 81 | return mesh 82 | 83 | 84 | def compute_fpfh_features( 85 | pcd: o3d.geometry.PointCloud, voxel_size: float 86 | ) -> o3d.pipelines.registration.Feature: 87 | """ 88 | Compute FPFH features for a point cloud. 89 | """ 90 | radius_normal = voxel_size * 2 91 | radius_feature = voxel_size * 5 92 | 93 | pcd.estimate_normals( 94 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30) 95 | ) 96 | 97 | fpfh = o3d.pipelines.registration.compute_fpfh_feature( 98 | pcd, o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100) 99 | ) 100 | return fpfh 101 | 102 | 103 | def refine_alignment_with_icp( 104 | mesh1: o3d.geometry.TriangleMesh, mesh2: o3d.geometry.TriangleMesh 105 | ) -> o3d.geometry.TriangleMesh: 106 | """ 107 | Refine the alignment of two meshes using feature matching and ICP. 108 | """ 109 | # Sample points and create point clouds. 110 | source_pcd = mesh1.sample_points_uniformly(number_of_points=5000) 111 | target_pcd = mesh2.sample_points_uniformly(number_of_points=5000) 112 | 113 | # Voxel downsampling. 114 | voxel_size = 0.005 115 | source_down = source_pcd.voxel_down_sample(voxel_size) 116 | target_down = target_pcd.voxel_down_sample(voxel_size) 117 | 118 | # Estimate normals. 119 | source_down.estimate_normals( 120 | search_param=o3d.geometry.KDTreeSearchParamHybrid( 121 | radius=voxel_size * 2, max_nn=30 122 | ) 123 | ) 124 | target_down.estimate_normals( 125 | search_param=o3d.geometry.KDTreeSearchParamHybrid( 126 | radius=voxel_size * 2, max_nn=30 127 | ) 128 | ) 129 | 130 | # Initial alignment using center of mass. 131 | source_center = source_down.get_center() 132 | target_center = target_down.get_center() 133 | initial_translation = target_center - source_center 134 | 135 | init_transform = np.eye(4) 136 | init_transform[:3, 3] = initial_translation 137 | 138 | # Compute FPFH features. 139 | source_fpfh = compute_fpfh_features(source_down, voxel_size) 140 | target_fpfh = compute_fpfh_features(target_down, voxel_size) 141 | 142 | # Global registration using RANSAC. 143 | result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 144 | source=source_down, 145 | target=target_down, 146 | source_feature=source_fpfh, 147 | target_feature=target_fpfh, 148 | mutual_filter=False, 149 | max_correspondence_distance=voxel_size * 5, 150 | estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint( 151 | False 152 | ), 153 | ransac_n=3, 154 | checkers=[ 155 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.5), 156 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance( 157 | voxel_size * 5 158 | ), 159 | ], 160 | criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(1000000, 500), 161 | ) 162 | 163 | # Use RANSAC result if successful, otherwise use initial translation. 164 | initial_alignment = ( 165 | result_ransac.transformation if result_ransac.fitness > 0 else init_transform 166 | ) 167 | 168 | # Refine with ICP. 
169 | reg_p2p = o3d.pipelines.registration.registration_icp( 170 | source_down, 171 | target_down, 172 | max_correspondence_distance=voxel_size * 5, 173 | init=initial_alignment, 174 | estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(), 175 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria( 176 | max_iteration=500, 177 | relative_fitness=1e-7, 178 | relative_rmse=1e-7, 179 | ), 180 | ) 181 | 182 | # Transform the source mesh. 183 | mesh1.transform(reg_p2p.transformation) 184 | return mesh1 185 | 186 | 187 | def compute_f_score( 188 | points1: np.ndarray, points2: np.ndarray, threshold: float = 0.01 189 | ) -> tuple[float, float, float]: 190 | """ 191 | Compute F-score between two point clouds. 192 | points1 is treated as prediction, points2 as ground truth. 193 | """ 194 | tree2 = KDTree(points2) 195 | tree1 = KDTree(points1) 196 | 197 | # For precision: How many points in points1 (prediction) are close to points2 (GT)? 198 | dist_p = tree2.query(points1)[0] 199 | precision = np.mean(dist_p < threshold) 200 | 201 | # For recall: How many points in points2 (GT) are close to points1 (prediction)? 202 | dist_r = tree1.query(points2)[0] 203 | recall = np.mean(dist_r < threshold) 204 | 205 | # F-score is the harmonic mean. 206 | f_score = 2 * precision * recall / (precision + recall + 1e-8) 207 | 208 | return f_score, precision, recall 209 | 210 | 211 | def compute_normal_consistency( 212 | mesh1: o3d.geometry.TriangleMesh, 213 | mesh2: o3d.geometry.TriangleMesh, 214 | num_points: int = 10000, 215 | ) -> float: 216 | """ 217 | Compute normal consistency between two meshes. 218 | Samples points and their normals, then compares normal directions at closest points. 219 | """ 220 | # Compute vertex normals for both meshes. 221 | mesh1.compute_vertex_normals() 222 | mesh2.compute_vertex_normals() 223 | 224 | # Sample points and get their normals directly from the mesh. 225 | pcd1 = mesh1.sample_points_poisson_disk(num_points) 226 | pcd2 = mesh2.sample_points_poisson_disk(num_points) 227 | 228 | points1 = np.asarray(pcd1.points) 229 | points2 = np.asarray(pcd2.points) 230 | normals1 = np.asarray(pcd1.normals) 231 | normals2 = np.asarray(pcd2.normals) 232 | 233 | # Find closest point correspondences. 234 | tree = KDTree(points2) 235 | _, indices = tree.query(points1) 236 | 237 | # Compute absolute dot product between corresponding normals. 238 | normal_consistency = np.mean(np.abs(np.sum(normals1 * normals2[indices], axis=1))) 239 | 240 | return normal_consistency 241 | 242 | 243 | def compute_iou( 244 | mesh1: o3d.geometry.TriangleMesh, 245 | mesh2: o3d.geometry.TriangleMesh, 246 | resolution: int = 100, 247 | ) -> float: 248 | """ 249 | Compute IoU (Intersection over Union) using voxel occupancy. 250 | """ 251 | # Compute voxel size based on combined bounds of both meshes 252 | combined_bounds = np.vstack( 253 | [ 254 | mesh1.get_min_bound(), 255 | mesh1.get_max_bound(), 256 | mesh2.get_min_bound(), 257 | mesh2.get_max_bound(), 258 | ] 259 | ) 260 | combined_size = np.linalg.norm( 261 | combined_bounds.max(axis=0) - combined_bounds.min(axis=0) 262 | ) 263 | voxel_size = combined_size / resolution 264 | 265 | # Create voxel grids for both meshes using the same voxel size 266 | voxel1 = o3d.geometry.VoxelGrid.create_from_triangle_mesh( 267 | mesh1, voxel_size=voxel_size 268 | ) 269 | voxel2 = o3d.geometry.VoxelGrid.create_from_triangle_mesh( 270 | mesh2, voxel_size=voxel_size 271 | ) 272 | 273 | # Get voxel indices. 
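    # Note: create_from_triangle_mesh voxelizes the mesh surface only, so this IoU
    # compares occupied surface voxels rather than filled volumes.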
274 | voxels1 = set( 275 | (v.grid_index[0], v.grid_index[1], v.grid_index[2]) for v in voxel1.get_voxels() 276 | ) 277 | voxels2 = set( 278 | (v.grid_index[0], v.grid_index[1], v.grid_index[2]) for v in voxel2.get_voxels() 279 | ) 280 | 281 | # Compute intersection and union. 282 | intersection = len(voxels1.intersection(voxels2)) 283 | union = len(voxels1.union(voxels2)) 284 | 285 | iou = intersection / (union + 1e-8) 286 | return iou 287 | 288 | 289 | def main(): 290 | parser = argparse.ArgumentParser() 291 | parser.add_argument("gt_mesh", type=str, help="Path to the ground truth mesh") 292 | parser.add_argument("pred_mesh", type=str, help="Path to the predicted mesh") 293 | parser.add_argument( 294 | "--num_points", 295 | type=int, 296 | default=1000000, 297 | help="Number of points to sample from each mesh", 298 | ) 299 | parser.add_argument( 300 | "--use_curvature", 301 | action="store_true", 302 | help="Use curvature-based sampling for the meshes", 303 | ) 304 | parser.add_argument( 305 | "--vis", action="store_true", help="Visualize the aligned meshes" 306 | ) 307 | parser.add_argument( 308 | "--f_score_threshold", 309 | type=float, 310 | default=0.001, 311 | help="Distance threshold for F-score computation", 312 | ) 313 | parser.add_argument( 314 | "--iou_resolution", 315 | type=int, 316 | default=25, 317 | help="Resolution for IoU computation", 318 | ) 319 | args = parser.parse_args() 320 | num_points = args.num_points 321 | use_curvature = args.use_curvature 322 | f_score_threshold = args.f_score_threshold 323 | iou_resolution = args.iou_resolution 324 | 325 | # Prepare both meshes. 326 | gt_mesh = prepare_mesh(args.gt_mesh) 327 | raw_pred_mesh = prepare_mesh(args.pred_mesh) 328 | 329 | # Refine alignment using ICP. 330 | pred_mesh = refine_alignment_with_icp(raw_pred_mesh, gt_mesh) 331 | 332 | # Sample points after ICP refinement. 333 | gt_points = sample_points_from_mesh(gt_mesh, num_points, use_curvature) 334 | pred_points = sample_points_from_mesh(pred_mesh, num_points, use_curvature) 335 | 336 | # Compute all metrics. 337 | chamfer_distance = compute_chamfer_distance(pred_points, gt_points) 338 | f_score, precision, recall = compute_f_score( 339 | pred_points, gt_points, f_score_threshold 340 | ) 341 | normal_consistency = compute_normal_consistency(pred_mesh, gt_mesh) 342 | iou = compute_iou(pred_mesh, gt_mesh, resolution=iou_resolution) 343 | 344 | # Print results. 345 | print("\nMesh Comparison Metrics:") 346 | print(f"Chamfer Distance (mm): {chamfer_distance * 1000:.6f}") 347 | print(f"F-score: {f_score:.6f}") 348 | print(f" Precision: {precision:.6f} (% of predicted points matched)") 349 | print(f" Recall: {recall:.6f} (% of ground truth points matched)") 350 | print(f"Normal Consistency: {normal_consistency:.6f}") 351 | print(f"IoU: {iou:.6f}") 352 | 353 | if args.vis: 354 | # Create point clouds for visualization. 
--------------------------------------------------------------------------------
/scripts/segment_moving_obj_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import shutil
5 | 
6 | from scalable_real2sim.data_processing.image_subsampling import (
7 |     select_and_copy_dissimilar_images,
8 | )
9 | from scalable_real2sim.segmentation.segment_moving_object_data import (
10 |     segment_moving_obj_data,
11 | )
12 | 
13 | 
14 | def downsample_images(rgb_dir: str, num_images: int) -> None:
15 |     # Check if we need to subsample images.
16 |     rgb_files = sorted(os.listdir(rgb_dir))
17 |     if len(rgb_files) > num_images:
18 |         print(f"Found {len(rgb_files)} images, subsampling to {num_images}...")
19 | 
20 |         # Move original images to rgb_original.
21 |         rgb_original_dir = os.path.join(rgb_dir, "..", "rgb_original")
22 |         os.makedirs(rgb_original_dir, exist_ok=True)
23 |         for f in rgb_files:
24 |             shutil.move(os.path.join(rgb_dir, f), os.path.join(rgb_original_dir, f))
25 | 
26 |         num_uniform_frames = num_images // 2  # Big gaps cause tracking to fail
27 |         max_frame_gap = math.floor(len(rgb_files) / num_uniform_frames)
28 |         print(f"Using a maximum frame gap of {max_frame_gap} for image downsampling.")
29 |         select_and_copy_dissimilar_images(
30 |             image_dir=rgb_original_dir,
31 |             output_dir=rgb_dir,
32 |             K=num_images,
33 |             N=max_frame_gap,
34 |             model_name="dino",
35 |         )
36 |         print("Image subsampling complete.")
37 |     else:
38 |         print("No image subsampling needed.")
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     parser = argparse.ArgumentParser()
43 |     parser.add_argument(
44 |         "rgb_dir", type=str, help="Path to the folder containing RGB frames"
45 |     )
46 |     parser.add_argument(
47 |         "output_dir",
48 |         type=str,
49 |         help="Path to the folder to save binary masks where 1 indicates the object "
50 |         "of interest and 0 indicates the background.",
51 |     )
52 |     parser.add_argument(
53 |         "--txt_prompt",
54 |         type=str,
55 |         help="Text prompt to use for grounding the object of interest.",
56 |         default=None,
57 |     )
58 |     parser.add_argument(
59 |         "--txt_prompt_index",
60 |         type=int,
61 |         help="Index of the frame to use for grounding the object of interest.",
62 |         default=0,
63 |     )
64 |     parser.add_argument(
65 |         "--neg_txt_prompt",
66 |         type=str,
67 |         help="Text prompt to use for grounding the object to ignore.",
68 |         default=None,
69 |     )
70 |     parser.add_argument(
71 |         "--num_neg_frames",
72 |         type=int,
73 |         help="Number of frames to add negatives to (uniformly spaced). Increasing this "
74 |         "might lead to OOM.",
75 |         default=10,
76 |     )
77 |     parser.add_argument(
78 |         "--debug_dir",
79 |         type=str,
80 |         help="Path to the folder to save images with predicted bounding boxes and "
81 |         "query points overlaid.",
82 |         default=None,
83 |     )
84 |     parser.add_argument(
85 |         "--gui_frames",
86 |         type=str,
87 |         nargs="*",
88 |         help="List of RGB image names (without extension) to provide GUI labels for. "
89 |         "NOTE: Segmentation will start from the first frame in the list. Hence, if you "
90 |         "want to segment all images, you should specify the first frame.",
91 |         default=None,
92 |     )
93 |     parser.add_argument(
94 |         "--num_images",
95 |         type=int,
96 |         help="Number of images to subsample to.",
97 |         default=None,
98 |     )
99 |     args = parser.parse_args()
100 |     rgb_dir = args.rgb_dir
101 |     output_dir = args.output_dir
102 |     txt_prompt = args.txt_prompt
103 |     txt_prompt_index = args.txt_prompt_index
104 |     neg_txt_prompt = args.neg_txt_prompt
105 |     num_neg_frames = args.num_neg_frames
106 |     debug_dir = args.debug_dir
107 |     gui_frames = args.gui_frames
108 |     num_images = args.num_images
109 | 
110 |     if num_images is not None:
111 |         downsample_images(rgb_dir, num_images)
112 | 
113 |     segment_moving_obj_data(
114 |         rgb_dir=rgb_dir,
115 |         output_dir=output_dir,
116 |         txt_prompt=txt_prompt,
117 |         txt_prompt_index=txt_prompt_index,
118 |         neg_txt_prompt=neg_txt_prompt,
119 |         num_neg_frames=num_neg_frames,
120 |         debug_dir=debug_dir,
121 |         gui_frames=gui_frames,
122 |     )
123 | 
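A possible invocation of the segmentation script above (the directory paths and the text prompt are placeholders; --num_images is optional and triggers the DINO-based subsampling in downsample_images()):

    python scripts/segment_moving_obj_data.py data/scene/rgb data/scene/masks \
        --txt_prompt "object" --num_images 300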
" 89 | "NOTE: Segmentation will start from the first frame in the list. Hence, if you " 90 | "want to segment all images, you should specify the first frame.", 91 | default=None, 92 | ) 93 | parser.add_argument( 94 | "--num_images", 95 | type=int, 96 | help="Number of images to subsample to.", 97 | default=None, 98 | ) 99 | args = parser.parse_args() 100 | rgb_dir = args.rgb_dir 101 | output_dir = args.output_dir 102 | txt_prompt = args.txt_prompt 103 | txt_prompt_index = args.txt_prompt_index 104 | neg_txt_prompt = args.neg_txt_prompt 105 | num_neg_frames = args.num_neg_frames 106 | debug_dir = args.debug_dir 107 | gui_frames = args.gui_frames 108 | num_images = args.num_images 109 | 110 | if num_images is not None: 111 | downsample_images(rgb_dir, num_images) 112 | 113 | segment_moving_obj_data( 114 | rgb_dir=rgb_dir, 115 | output_dir=output_dir, 116 | txt_prompt=txt_prompt, 117 | txt_prompt_index=txt_prompt_index, 118 | neg_txt_prompt=neg_txt_prompt, 119 | num_neg_frames=num_neg_frames, 120 | debug_dir=debug_dir, 121 | gui_frames=gui_frames, 122 | ) 123 | -------------------------------------------------------------------------------- /scripts/subtract_masks.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script subtracts binary masks from two folders. 3 | This can be useful for subtracting the object masks from the gripper masks if the 4 | gripper masks include the object as can happen in bad lighting conditions or when the 5 | object has colors that are similar to the gripper. 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import cv2 12 | import numpy as np 13 | 14 | from tqdm import tqdm 15 | 16 | 17 | def subtract_masks(folder_a: str, folder_b: str, output_folder: str) -> None: 18 | """ 19 | Subtracts binary masks from two folders and saves the result in an output folder. 20 | 21 | This function reads binary mask images from two specified folders, 22 | subtracts the masks, and saves the resulting masks in the output folder. 23 | If both masks are white (255), the result will be set to black (0). 24 | The function only processes files that are common to both folders. 25 | 26 | Args: 27 | folder_a (str): Path to the first folder containing binary masks. 28 | folder_b (str): Path to the second folder containing binary masks. 29 | output_folder (str): Path to the folder where the output masks will be saved. 
30 | """ 31 | os.makedirs(output_folder, exist_ok=True) 32 | 33 | files_a = set(os.listdir(folder_a)) 34 | files_b = set(os.listdir(folder_b)) 35 | 36 | common_files = files_a.intersection(files_b) 37 | 38 | for filename in tqdm(common_files): 39 | path_a = os.path.join(folder_a, filename) 40 | path_b = os.path.join(folder_b, filename) 41 | 42 | mask_a = cv2.imread(path_a, cv2.IMREAD_GRAYSCALE) 43 | mask_b = cv2.imread(path_b, cv2.IMREAD_GRAYSCALE) 44 | 45 | if mask_a is None or mask_b is None: 46 | print(f"Skipping {filename}: Could not load image") 47 | continue 48 | 49 | # Ensure same size 50 | if mask_a.shape != mask_b.shape: 51 | print( 52 | f"Skipping {filename}: Shape mismatch {mask_a.shape} vs {mask_b.shape}" 53 | ) 54 | continue 55 | 56 | # Subtract masks: If both are white, set to black 57 | result = np.where((mask_a == 255) & (mask_b == 255), 0, mask_a) 58 | 59 | output_path = os.path.join(output_folder, filename) 60 | cv2.imwrite(output_path, result) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser(description="Subtract binary masks") 65 | parser.add_argument("folder_a", type=str, help="Path to first folder") 66 | parser.add_argument("folder_b", type=str, help="Path to second folder") 67 | parser.add_argument( 68 | "output_folder", type=str, help="Path to output folder that is mask_a - mask_b." 69 | ) 70 | 71 | args = parser.parse_args() 72 | subtract_masks(args.folder_a, args.folder_b, args.output_folder) 73 | --------------------------------------------------------------------------------