├── .gitignore ├── LICENSE.txt ├── README.md ├── configs ├── dataset │ ├── cmap_dataset.yaml │ └── pretrain_dataset.yaml ├── log.yaml ├── model.yaml ├── pretrain.yaml ├── train.yaml └── validate.yaml ├── data_utils ├── CMapDataset.py ├── PretrainDataset.py ├── filter_dataset.py ├── generate_pc.py └── removed_links.json ├── model ├── encoder.py ├── latent_encoder.py ├── mlp.py ├── module.py ├── network.py └── transformer.py ├── pipeline.jpg ├── pretrain.py ├── requirements.txt ├── scripts ├── download_ckpt.sh ├── download_data.sh ├── example_isaac.py ├── example_pretrain.py └── pretrain_order.py ├── train.py ├── utils ├── controller.py ├── func_utils.py ├── hand_model.py ├── mesh_utils.py ├── multilateration.py ├── optimization.py ├── pretrain_utils.py ├── rotation.py ├── se3_transform.py └── vis_utils.py ├── validate.py ├── validation ├── __init__.py ├── asset_info.py ├── isaac_main.py ├── isaac_validator.py └── validate_utils.py └── visualization ├── vis_controller.py ├── vis_dataset.py ├── vis_hand_joint.py ├── vis_hand_link.py ├── vis_optimization.py ├── vis_pretrain.py └── vis_validation.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | .idea/ 3 | logs/ 4 | lightning_logs/ 5 | 6 | ckpt/ 7 | data/ 8 | output/ 9 | outputs/ 10 | tmp/ 11 | validate_output/ 12 | vis_info/ 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhenyu Wei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 
10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # $\mathcal{D(R,O)}$ Grasp 2 | 3 | Official Code Repository for **$\mathcal{D(R,O)}$ Grasp: A Unified Representation of Robot and Object Interaction for Cross-Embodiment Dexterous Grasping**. 4 | 5 | [Zhenyu Wei](https://zhenyuwei2003.github.io/)1,2\*, [Zhixuan Xu](https://ariszxxu.github.io/)1\*, [Jingxiang Guo](https://borisguo6.github.io)1, [Yiwen Hou](https://houyiwen.github.io/)1, [Chongkai Gao](https://chongkaigao.com/)1, Zhehao Cai1, Jiayu Luo1, [Lin Shao](https://linsats.github.io/)1 6 | 7 | 1National University of Singapore, 2Shanghai Jiao Tong University 8 | 9 | * denotes equal contribution 10 | 11 |

12 | 13 | Paper arXiv 14 | 15 | 16 | Paper PDF 17 | 18 | 19 | Project Page 20 | 21 |

22 |
23 | main 24 |
25 | 26 | 27 | In this paper, we present $\mathcal{D(R,O)}$ Grasp, a novel framework that models the interaction between the robotic hand in its grasping pose and the object, enabling broad generalization across various robot hands and object geometries. Our model takes the robot hand’s description and object point cloud as inputs and efficiently predicts kinematically valid and stable grasps, demonstrating strong adaptability to diverse robot embodiments and object geometries. 28 | 29 | ---------------- 30 | 31 | ## Prerequisites: 32 | 33 | - Python 3.8 34 | - PyTorch >= 2.3.0 35 | 36 | ## Get Started 37 | 38 | ### 1. Create Python Environment 39 | 40 | ```bash 41 | conda create -n dro python==3.8 42 | conda activate dro 43 | ``` 44 | 45 | ### 2. Install Isaac Gym Environment (Optional) 46 | 47 | You don't need to install Isaac Gym for training and pretraining. If evaluating grasps in Isaac Gym isn't required, you can skip this step. 48 | 49 | Download [Isaac Gym](https://developer.nvidia.com/isaac-gym/download) from the official website, then: 50 | 51 | ```bash 52 | tar -xvf IsaacGym_Preview_4_Package.tar.gz 53 | cd isaacgym/python 54 | pip install -e . 55 | ``` 56 | 57 | ### 3. Install Packages 58 | 59 | Change to the project directory, then run: 60 | 61 | ```bash 62 | pip install -r requirements.txt 63 | ``` 64 | 65 | ### 4. Weights & Biases (Optional) 66 | 67 | This project uses Weights & Biases to monitor loss curves. If you haven't used it before, refer to the [W&B Tutorials](https://docs.wandb.ai/tutorials/). Alternatively, you can disable the related sections in `train.py` and `pretrain.py`.
68 | 69 | ## Example 70 | 71 | Download our [checkpoint models](https://github.com/zhenyuwei2003/DRO-Grasp/releases/tag/v1.0) and unzip the contents into the `ckpt/` folder, or simply execute: 72 | 73 | ```bash 74 | bash scripts/download_ckpt.sh 75 | ``` 76 | 77 | To verify that the Isaac Gym environment is correctly installed and to evaluate the performance of our model, run `python scripts/example_isaac.py`. You can also run `python scripts/example_pretrain.py` to obtain the matching order of our pretrained model, which is a good indicator of its effectiveness. You can visualize the correspondence matching results by running `python visualization/vis_pretrain.py`. 78 | 79 | ## How to use? 80 | 81 | ### Pretraining 82 | 83 | You need to modify the configuration file based on your requirements. Below are the key parameters commonly adjusted in the `configs/` folder: 84 | 85 | - `pretrain.yaml` 86 | - `name`: Specify the pretraining model name. 87 | - `gpu`: Set the GPU ID based on the available GPU device(s). 88 | - `training/max_epochs`: Define the number of pretraining epochs. 89 | - `dataset/pretrain_dataset.yaml` 90 | - `robot_names`: Provide the list of robot names to be used for pretraining. 91 | 92 | After updating the config file, simply run: 93 | 94 | ```bash 95 | python pretrain.py 96 | ``` 97 | 98 | To assess the performance of the pretrained model (lower matching order values indicate better performance), run the command below, where `--pretrain_ckpt` is your model name, `--data_num` is the number of grasps per robot, `--epoch_list` lists the epochs of the pretrained model to test, and `--robot_names` lists the robots: 99 | 100 | ```bash 101 | python scripts/pretrain_order.py \ 102 | --pretrain_ckpt pretrain_3robots \ 103 | --data_num 200 \ 104 | --epoch_list 10,20,30,40,50 \ 105 | --robot_names barrett,allegro,shadowhand 106 | ``` 107 | 108 | ### Training 109 | 110 | You need to modify the configuration file based on your requirements.
Below are the key parameters commonly adjusted in the `configs/` folder: 111 | 112 | - `train.yaml` 113 | - `name`: Specify the training model name. 114 | - `gpu`: Set the GPU ID based on the available GPU device(s). 115 | - `training/max_epochs`: Define the number of training epochs. 116 | - `model.yaml` 117 | - `pretrain`: Specify the name of the pretrained model, which should be placed in the `ckpt/` folder. 118 | - `dataset/cmap_dataset.yaml` 119 | - `robot_names`: Provide the list of robot names to be used for training. 120 | - `batch_size`: Set the dataloader batch size as large as possible. Note that each unit of batch size consumes roughly 4 GB of GPU memory, and the batch size cannot be set to 1 during training due to batch normalization requirements. 121 | - `object_pc_type`: Use `random` for the main experiments and `partial` for partial object point cloud input. This parameter should remain the same during training and validation. 122 | 123 | After updating the config file, simply run: 124 | 125 | ``` 126 | python train.py 127 | ``` 128 | 129 | ### Validation 130 | 131 | You need to modify the configuration file based on your requirements. Below are the key parameters commonly adjusted in the `configs/` folder: 132 | 133 | - `validate.yaml` 134 | - `name`: Specify the model name you want to validate. 135 | - `gpu`: Set the GPU ID based on the available GPU device(s). 136 | - `split_batch_size`: Set the number of grasps to run in parallel in Isaac Gym, constrained by GPU memory. Maximize this value to speed up the validation process. 137 | - `validate_epochs`: Specify the list of epochs of the trained model to validate. 138 | - `dataset/batch_size`: The total number of grasps for each `(robot, object)` combination to validate. 139 | - `dataset/cmap_dataset.yaml` 140 | - `robot_names`: Provide the list of robot names to be used for validation. 141 | - `batch_size`: Ignored; overridden by `validate.yaml`. 142 | - `object_pc_type`: Keep this the same as during training.
143 | 144 | After updating the config file, simply run: 145 | 146 | ``` 147 | python validate.py 148 | ``` 149 | 150 | ## Dataset 151 | 152 | You can download our filtered dataset, URDF files and point clouds [here](https://github.com/zhenyuwei2003/DRO-Grasp/releases/tag/v1.0) and unzip the contents into the `data/` folder, or simply execute: 153 | 154 | ```bash 155 | bash scripts/download_data.sh 156 | ``` 157 | 158 | The original `MultiDex` and `CMapDataset` are also included for pretraining purposes. For more details on these datasets, refer to [GenDexGrasp](https://github.com/tengyu-liu/GenDexGrasp). 159 | 160 | ## Repository Structure 161 | 162 | ```bash 163 | DRO-Grasp 164 | ├── ckpt # Checkpoint models stored here 165 | ├── configs # Configuration files 166 | ├── data # Datasets downloaded and stored here 167 | ├── data_utils # Dataset classes and related scripts 168 | ├── model # Network architecture code 169 | ├── output # Saved model checkpoints and log files 170 | ├── tmp # Temporary files generated during subprocess execution 171 | ├── scripts 172 | │   ├── download_ckpt.sh # Download checkpoint models 173 | │   ├── download_data.sh # Download data 174 | │   ├── example_isaac.py # Test Isaac Gym environment and evaluate our model's performance 175 | │   ├── example_pretrain.py # Evaluate our pretrained model's performance 176 | │   └── pretrain_order.py # Evaluate performance of pretrained model 177 | ├── utils # Various utility scripts 178 | ├── validate # Scripts for validating in Isaac Gym 179 | ├── validate_output # Validation results saved here 180 | ├── vis_info # Visualization information saved during validation 181 | └── visualization 182 | │ ├── vis_controller.py # Visualize the effect of the grasp controller 183 | │ ├── vis_dataset.py # Visualize grasps from the dataset 184 | │ ├── vis_hand_joint.py # Visualize hand joint movements 185 | │ ├── vis_hand_link.py # Visualize hand links 186 | │ ├── vis_pretrain.py # Visualize pretrain matching 
results 187 | │ └── vis_validation.py # Visualize validation results 188 | ├── pretrain.py # Main script for pretraining 189 | ├── train.py # Main script for training 190 | └── validate.py # Main script for validation 191 | ``` 192 | 193 | ## Steps to Apply our Method to a New Hand 194 | 195 | 1. Modify your hand's URDF. You can refer to an existing URDF file for guidance on making modifications. 196 | - Add virtual joints between the world and the robot to represent the base link transform. Ensure these joint names start with `virtual` so the controller can ignore them. 197 | - Use `visualization/vis_optimization.py` to view the optimization result. If no fixed joint exists at the tip links, the optimized point cloud may not align with the target point cloud at those links. In this case, add extra links beyond each tip link and ensure their names begin with `extra` so their transforms are processed before optimization. 198 | 2. Add the hand's URDF and mesh paths to `data/data_urdf/robot/urdf_assets_meta.json`. 199 | 3. Specify redundant link names in `data_utils/removed_links.json`. You can visualize the links using `visualization/vis_hand_link.py` to identify which links are irrelevant for contact. 200 | 4. Use `data_utils/generate_pc.py` to sample point clouds for each robot link and save them. 201 | 5. Annotate the `link_dir` for all links in the `get_link_dir()` function of `utils/controller.py`. You can adjust the direction of each link using `vis_hand_direction()` in `visualization/vis_controller.py`, ensuring that the arrow points in the correct direction of motion when the joint value increases. 202 | 6. Pretrain and train the model using your own grasp dataset.
203 | 204 | ## Citation 205 | 206 | If you find our codes or models useful in your work, please cite [our paper](https://arxiv.org/abs/2410.01702): 207 | 208 | ``` 209 | @article{wei2024dro, 210 | title={D(R,O) Grasp: A Unified Representation of Robot and Object Interaction for Cross-Embodiment Dexterous Grasping}, 211 | author={Wei, Zhenyu and Xu, Zhixuan and Guo, Jingxiang and Hou, Yiwen and Gao, Chongkai and Cai, Zhehao and Luo, Jiayu and Shao, Lin}, 212 | journal={arXiv preprint arXiv:2410.01702}, 213 | year={2024} 214 | } 215 | ``` 216 | 217 | ## Contact 218 | 219 | If you have any questions, feel free to contact me through email ([Zhenyu_Wei@sjtu.edu.cn](mailto:Zhenyu_Wei@sjtu.edu.cn))! 220 | -------------------------------------------------------------------------------- /configs/dataset/cmap_dataset.yaml: -------------------------------------------------------------------------------- 1 | robot_names: 2 | - 'barrett' 3 | - 'allegro' 4 | - 'shadowhand' 5 | 6 | debug_object_names: null # use part of the dataset, for debug only 7 | 8 | batch_size: 8 # 1 batch_size ≈ 4 GB GPU memory during training 9 | num_workers: 16 10 | object_pc_type: 'random' # 'fixed', 'random', 'partial' -------------------------------------------------------------------------------- /configs/dataset/pretrain_dataset.yaml: -------------------------------------------------------------------------------- 1 | robot_names: 2 | - 'barrett' 3 | - 'allegro' 4 | - 'shadowhand' 5 | 6 | batch_size: 32 7 | num_workers: 16 8 | -------------------------------------------------------------------------------- /configs/log.yaml: -------------------------------------------------------------------------------- 1 | output_dir: 'output/${name}' 2 | log_dir: '${output_dir}/log' 3 | 4 | hydra: 5 | run: 6 | dir: '${log_dir}/hydra' 7 | sweep: 8 | dir: '${log_dir}/hydra-multirun/' 9 | 10 | wandb: 11 | save_dir: '${log_dir}' 12 | project: 'DROGrasp' 13 | 14 | lightning: 15 | checkpoint_dir: 
'${log_dir}/checkpoints' -------------------------------------------------------------------------------- /configs/model.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | emb_dim: 512 3 | latent_dim: 64 4 | 5 | pretrain: pretrain_3robots.pth 6 | 7 | center_pc: True 8 | block_computing: True -------------------------------------------------------------------------------- /configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - log.yaml 3 | - model.yaml 4 | - dataset: pretrain_dataset.yaml 5 | - _self_ 6 | 7 | name: 'pretrain_3robots' 8 | 9 | seed: 42 10 | gpu: 11 | - 0 12 | - 1 13 | - 2 14 | - 3 15 | - 4 16 | - 5 17 | - 6 18 | - 7 19 | 20 | load_from_checkpoint: False 21 | log_every_n_steps: 5 22 | 23 | wandb: # override log.yaml 24 | project: 'DROGrasp-Pretrain' 25 | 26 | model: # override model.yaml 27 | encoder: 28 | pretrain: null 29 | 30 | training: 31 | max_epochs: 100 32 | 33 | save_dir: '${output_dir}/state_dict' 34 | save_every_n_epoch: 5 35 | 36 | lr: 1e-4 37 | temperature: 0.1 38 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - log.yaml 3 | - model.yaml 4 | - dataset: cmap_dataset.yaml 5 | - _self_ 6 | 7 | name: 'model_3robots' 8 | 9 | seed: 42 10 | gpu: 11 | - 0 12 | - 1 13 | - 2 14 | - 3 15 | - 4 16 | - 5 17 | - 6 18 | - 7 19 | 20 | load_from_checkpoint: False 21 | log_every_n_steps: 5 22 | 23 | training: 24 | max_epochs: 200 25 | 26 | save_dir: '${output_dir}/state_dict' 27 | save_every_n_epoch: 5 28 | 29 | lr: 1e-4 30 | 31 | loss_kl: True 32 | loss_kl_weight: 0.01 33 | 34 | loss_r: True 35 | loss_r_weight: 1 36 | 37 | loss_se3: True 38 | loss_se3_weight: 0.01 39 | 40 | loss_depth: True 41 | loss_depth_weight: 1 42 | 
-------------------------------------------------------------------------------- /configs/validate.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model.yaml 3 | - dataset: cmap_dataset.yaml 4 | - _self_ 5 | 6 | model: # override model.yaml 7 | encoder: 8 | pretrain: null 9 | 10 | dataset: # override cmap_dataset.yaml 11 | batch_size: 100 # total grasps of (robot, object) combination 12 | 13 | split_batch_size: 25 # limited by GPU memory 14 | gpu: 0 15 | name: '3robots' 16 | validate_epochs: 17 | - 10 18 | - 20 19 | - 30 20 | - 40 21 | # - 50 22 | # - 60 23 | # - 70 24 | # - 80 25 | # - 90 26 | # - 100 27 | -------------------------------------------------------------------------------- /data_utils/CMapDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import math 5 | import hydra 6 | import random 7 | import trimesh 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(ROOT_DIR) 13 | 14 | from utils.hand_model import create_hand_model 15 | 16 | 17 | class CMapDataset(Dataset): 18 | def __init__( 19 | self, 20 | batch_size: int, 21 | robot_names: list = None, 22 | is_train: bool = True, 23 | debug_object_names: list = None, 24 | num_points: int = 512, 25 | object_pc_type: str = 'random' 26 | ): 27 | self.batch_size = batch_size 28 | self.robot_names = robot_names if robot_names is not None \ 29 | else ['barrett', 'allegro', 'shadowhand'] 30 | self.is_train = is_train 31 | self.num_points = num_points 32 | self.object_pc_type = object_pc_type 33 | 34 | self.hands = {} 35 | self.dofs = [] 36 | for robot_name in self.robot_names: 37 | self.hands[robot_name] = create_hand_model(robot_name, torch.device('cpu')) 38 | self.dofs.append(math.sqrt(self.hands[robot_name].dof)) 39 | 40 | split_json_path = 
os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/split_train_validate_objects.json') 41 | dataset_split = json.load(open(split_json_path)) 42 | self.object_names = dataset_split['train'] if is_train else dataset_split['validate'] 43 | if debug_object_names is not None: 44 | print("!!! Using debug objects !!!") 45 | self.object_names = debug_object_names 46 | 47 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 48 | metadata = torch.load(dataset_path)['metadata'] 49 | self.metadata = [m for m in metadata if m[1] in self.object_names and m[2] in self.robot_names] 50 | if not self.is_train: 51 | self.combination = [] 52 | for robot_name in self.robot_names: 53 | for object_name in self.object_names: 54 | self.combination.append((robot_name, object_name)) 55 | self.combination = sorted(self.combination) 56 | # print(len(self.metadata)) 57 | # print(len(self.combination)) 58 | 59 | self.object_pcs = {} 60 | if self.object_pc_type != 'fixed': 61 | for object_name in self.object_names: 62 | name = object_name.split('+') 63 | mesh_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 64 | mesh = trimesh.load_mesh(mesh_path) 65 | object_pc, _ = mesh.sample(65536, return_index=True) 66 | self.object_pcs[object_name] = torch.tensor(object_pc, dtype=torch.float32) 67 | else: 68 | print("!!! 
Using fixed object pcs !!!") 69 | 70 | def __getitem__(self, index): 71 | """ 72 | Train: sample a batch of data 73 | Validate: get (robot, object) from index, sample a batch of data 74 | """ 75 | if self.is_train: 76 | robot_name_batch = [] 77 | object_name_batch = [] 78 | robot_links_pc_batch = [] 79 | robot_pc_initial_batch = [] 80 | robot_pc_target_batch = [] 81 | object_pc_batch = [] 82 | dro_gt_batch = [] 83 | initial_q_batch = [] 84 | target_q_batch = [] 85 | for idx in range(self.batch_size): 86 | robot_name = random.choice(self.robot_names) 87 | robot_name_batch.append(robot_name) 88 | hand = self.hands[robot_name] 89 | metadata_robot = [(m[0], m[1]) for m in self.metadata if m[2] == robot_name] 90 | 91 | target_q, object_name = random.choice(metadata_robot) 92 | target_q_batch.append(target_q) 93 | object_name_batch.append(object_name) 94 | 95 | robot_links_pc_batch.append(hand.links_pc) 96 | 97 | if self.object_pc_type == 'fixed': 98 | name = object_name.split('+') 99 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 100 | object_pc = torch.load(object_path)[:, :3] 101 | elif self.object_pc_type == 'random': 102 | indices = torch.randperm(65536)[:self.num_points] 103 | object_pc = self.object_pcs[object_name][indices] 104 | object_pc += torch.randn(object_pc.shape) * 0.002 105 | else: # 'partial', remove 50% points 106 | indices = torch.randperm(65536)[:self.num_points * 2] 107 | object_pc = self.object_pcs[object_name][indices] 108 | direction = torch.randn(3) 109 | direction = direction / torch.norm(direction) 110 | proj = object_pc @ direction 111 | _, indices = torch.sort(proj) 112 | indices = indices[self.num_points:] 113 | object_pc = object_pc[indices] 114 | 115 | object_pc_batch.append(object_pc) 116 | 117 | robot_pc_target = hand.get_transformed_links_pc(target_q)[:, :3] 118 | robot_pc_target_batch.append(robot_pc_target) 119 | initial_q = hand.get_initial_q(target_q) 120 | initial_q_batch.append(initial_q) 
121 | robot_pc_initial = hand.get_transformed_links_pc(initial_q)[:, :3] 122 | robot_pc_initial_batch.append(robot_pc_initial) 123 | 124 | dro = torch.cdist(robot_pc_target, object_pc, p=2) 125 | dro_gt_batch.append(dro) 126 | 127 | robot_pc_initial_batch = torch.stack(robot_pc_initial_batch) 128 | robot_pc_target_batch = torch.stack(robot_pc_target_batch) 129 | object_pc_batch = torch.stack(object_pc_batch) 130 | dro_gt_batch = torch.stack(dro_gt_batch) 131 | 132 | B, N = self.batch_size, self.num_points 133 | assert robot_pc_initial_batch.shape == (B, N, 3),\ 134 | f"Expected: {(B, N, 3)}, Actual: {robot_pc_initial_batch.shape}" 135 | assert robot_pc_target_batch.shape == (B, N, 3),\ 136 | f"Expected: {(B, N, 3)}, Actual: {robot_pc_target_batch.shape}" 137 | assert object_pc_batch.shape == (B, N, 3),\ 138 | f"Expected: {(B, N, 3)}, Actual: {object_pc_batch.shape}" 139 | assert dro_gt_batch.shape == (B, N, N),\ 140 | f"Expected: {(B, N, N)}, Actual: {dro_gt_batch.shape}" 141 | 142 | return { 143 | 'robot_name': robot_name_batch, # list(len = B): str 144 | 'object_name': object_name_batch, # list(len = B): str 145 | 'robot_links_pc': robot_links_pc_batch, # list(len = B): dict, {link_name: (N_link, 3)} 146 | 'robot_pc_initial': robot_pc_initial_batch, 147 | 'robot_pc_target': robot_pc_target_batch, 148 | 'object_pc': object_pc_batch, 149 | 'dro_gt': dro_gt_batch, 150 | 'initial_q': initial_q_batch, 151 | 'target_q': target_q_batch 152 | } 153 | else: # validate 154 | robot_name, object_name = self.combination[index] 155 | hand = self.hands[robot_name] 156 | 157 | initial_q_batch = torch.zeros([self.batch_size, hand.dof], dtype=torch.float32) 158 | robot_pc_batch = torch.zeros([self.batch_size, self.num_points, 3], dtype=torch.float32) 159 | object_pc_batch = torch.zeros([self.batch_size, self.num_points, 3], dtype=torch.float32) 160 | 161 | for batch_idx in range(self.batch_size): 162 | initial_q = hand.get_initial_q() 163 | robot_pc = 
hand.get_transformed_links_pc(initial_q)[:, :3] 164 | 165 | if self.object_pc_type == 'partial': 166 | indices = torch.randperm(65536)[:self.num_points * 2] 167 | object_pc = self.object_pcs[object_name][indices] 168 | direction = torch.randn(3) 169 | direction = direction / torch.norm(direction) 170 | proj = object_pc @ direction 171 | _, indices = torch.sort(proj) 172 | indices = indices[self.num_points:] 173 | object_pc = object_pc[indices] 174 | else: 175 | name = object_name.split('+') 176 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 177 | object_pc = torch.load(object_path)[:, :3] 178 | 179 | initial_q_batch[batch_idx] = initial_q 180 | robot_pc_batch[batch_idx] = robot_pc 181 | object_pc_batch[batch_idx] = object_pc 182 | 183 | B, N, DOF = self.batch_size, self.num_points, len(hand.pk_chain.get_joint_parameter_names()) 184 | assert initial_q_batch.shape == (B, DOF), \ 185 | f"Expected: {(B, DOF)}, Actual: {initial_q_batch.shape}" 186 | assert robot_pc_batch.shape == (B, N, 3), \ 187 | f"Expected: {(B, N, 3)}, Actual: {robot_pc_batch.shape}" 188 | assert object_pc_batch.shape == (B, N, 3), \ 189 | f"Expected: {(B, N, 3)}, Actual: {object_pc_batch.shape}" 190 | 191 | return { 192 | 'robot_name': robot_name, # str 193 | 'object_name': object_name, # str 194 | 'initial_q': initial_q_batch, 195 | 'robot_pc': robot_pc_batch, 196 | 'object_pc': object_pc_batch 197 | } 198 | 199 | def __len__(self): 200 | if self.is_train: 201 | return math.ceil(len(self.metadata) / self.batch_size) 202 | else: 203 | return len(self.combination) 204 | 205 | 206 | def custom_collate_fn(batch): 207 | return batch[0] 208 | 209 | 210 | def create_dataloader(cfg, is_train): 211 | dataset = CMapDataset( 212 | batch_size=cfg.batch_size, 213 | robot_names=cfg.robot_names, 214 | is_train=is_train, 215 | debug_object_names=cfg.debug_object_names, 216 | object_pc_type=cfg.object_pc_type 217 | ) 218 | dataloader = DataLoader( 219 | dataset, 220 | 
batch_size=1, 221 | collate_fn=custom_collate_fn, 222 | num_workers=cfg.num_workers, 223 | shuffle=is_train 224 | ) 225 | return dataloader 226 | -------------------------------------------------------------------------------- /data_utils/PretrainDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(ROOT_DIR) 10 | 11 | from utils.hand_model import create_hand_model 12 | 13 | 14 | class PretrainDataset(Dataset): 15 | def __init__(self, robot_names: list = None): 16 | self.robot_names = robot_names if robot_names is not None \ 17 | else ['barrett', 'allegro', 'shadowhand'] 18 | 19 | self.dataset_len = 0 20 | self.robot_len = {} 21 | self.hands = {} 22 | self.dofs = [] 23 | self.dataset = {} 24 | for robot_name in self.robot_names: 25 | self.hands[robot_name] = create_hand_model(robot_name, torch.device('cpu')) 26 | self.dofs.append(len(self.hands[robot_name].pk_chain.get_joint_parameter_names())) 27 | self.dataset[robot_name] = [] 28 | 29 | dataset_path = os.path.join(ROOT_DIR, f'data/MultiDex_filtered/{robot_name}/{robot_name}.pt') 30 | dataset = torch.load(dataset_path) 31 | metadata = dataset['metadata'] 32 | self.dataset[robot_name].extend(metadata) 33 | self.dataset_len += len(metadata) 34 | self.robot_len[robot_name] = len(metadata) 35 | 36 | def __getitem__(self, index): 37 | robot_name = random.choices(self.robot_names, weights=self.dofs, k=1)[0] 38 | 39 | hand = self.hands[robot_name] 40 | dataset = self.dataset[robot_name] 41 | target_q, _, _ = random.choice(dataset) 42 | 43 | robot_pc_1 = hand.get_transformed_links_pc(target_q)[:, :3] 44 | initial_q = hand.get_initial_q(target_q) 45 | robot_pc_2 = hand.get_transformed_links_pc(initial_q)[:, :3] 46 | 47 | return { 48 | 'robot_pc_1': 
robot_pc_1, 49 | 'robot_pc_2': robot_pc_2, 50 | } 51 | 52 | def __len__(self): 53 | return self.dataset_len 54 | 55 | 56 | def create_dataloader(cfg): 57 | dataset = PretrainDataset(cfg.robot_names) 58 | dataloader = DataLoader( 59 | dataset, 60 | batch_size=cfg.batch_size, 61 | num_workers=cfg.num_workers, 62 | persistent_workers=True 63 | ) 64 | return dataloader 65 | -------------------------------------------------------------------------------- /data_utils/filter_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import warnings 6 | import torch 7 | import multiprocessing as mp 8 | import subprocess 9 | from termcolor import cprint 10 | from tqdm import tqdm 11 | import trimesh 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(ROOT_DIR) 15 | 16 | from utils.hand_model import create_hand_model 17 | 18 | 19 | def worker(robot_name, object_name, batch_size, gpu): 20 | args = [ 21 | 'python', 22 | os.path.join(ROOT_DIR, 'validation/isaac_main.py'), 23 | '--mode', 'filter', 24 | '--robot_name', robot_name, 25 | '--object_name', object_name, 26 | '--batch_size', str(batch_size), 27 | '--gpu', gpu 28 | ] 29 | # for arg in args: 30 | # print(arg, end=' ') 31 | # print('\n') 32 | start_time = time.time() 33 | ret = subprocess.run(args, capture_output=True, text=True) 34 | end_time = time.time() 35 | info = ret.stdout.strip().splitlines()[-1] 36 | cprint(f'{info:80}', 'light_blue', end=' ') 37 | cprint(f'time: {end_time - start_time:.2f} s', 'yellow') 38 | if not info.startswith('<'): 39 | cprint(ret.stderr.strip(), 'red') 40 | 41 | 42 | def filter_dataset(gpu_list): 43 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 44 | info = torch.load(dataset_path, map_location='cpu')['info'] 45 | if 'cmap_func' in info: 46 | del info['cmap_func'] 47 | 48 | pool = 
mp.Pool(processes=len(gpu_list)) 49 | 50 | gpu_index = 0 51 | for robot_name in info.keys(): 52 | for object_name in info[robot_name]['num_per_object'].keys(): 53 | batch_size = info[robot_name]['num_per_object'][object_name] 54 | gpu = gpu_list[gpu_index % len(gpu_list)] 55 | pool.apply_async(worker, args=(robot_name, object_name, batch_size, gpu)) 56 | gpu_index += 1 57 | 58 | pool.close() 59 | pool.join() 60 | 61 | 62 | def post_process(with_heatmap=False): 63 | """ 64 | dataset = { 65 | 'info': { 66 | : { 67 | 'robot_name': str, 68 | 'num_total': int, 69 | 'num_upper_object': int, 70 | 'num_per_object': { 71 | : int, 72 | ... 73 | } 74 | }, 75 | ... 76 | } 77 | 'metadata': [(, , ), ...] 78 | } 79 | """ 80 | 81 | if with_heatmap: 82 | hands = {} 83 | object_pcs = {} 84 | object_normals = {} 85 | 86 | info = {} 87 | metadata = [] 88 | for robot_name in ['allegro', 'barrett', 'ezgripper', 'robotiq_3finger', 'shadowhand']: 89 | num_total = 0 90 | num_upper_object = 0 91 | num_per_object = {} 92 | 93 | metadata_dir = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/{robot_name}') 94 | object_names = os.listdir(metadata_dir) 95 | for file_name in tqdm(sorted(object_names)): 96 | object_name, _, success_num = file_name.rpartition('_') 97 | success_num = int(success_num[:-3]) # remove '.pt' 98 | 99 | num_total += success_num 100 | if success_num > num_upper_object: 101 | num_upper_object = success_num 102 | num_per_object[object_name] = success_num 103 | 104 | q = torch.load(os.path.join(metadata_dir, file_name)) 105 | for q_idx in range(q.shape[0]): 106 | if with_heatmap: # compute heatmap use GenDexGrasp method to keep consistency 107 | if robot_name in hands: 108 | hand = hands[robot_name] 109 | else: 110 | hand = create_hand_model(robot_name) 111 | hands[robot_name] = hand 112 | robot_pc = hand.get_transformed_links_pc(q[q_idx])[:, :3] 113 | 114 | if object_name not in object_pcs: 115 | name = object_name.split('+') 116 | object_path = os.path.join(ROOT_DIR, 
f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 117 | mesh = trimesh.load_mesh(object_path) 118 | object_pc, face_indices = mesh.sample(2048, return_index=True) 119 | object_pc = torch.tensor(object_pc, dtype=torch.float32) 120 | object_normal = torch.tensor(mesh.face_normals[face_indices], dtype=torch.float32) 121 | object_pcs[object_name] = object_pc 122 | object_normals[object_name] = object_normal 123 | else: 124 | object_pc = object_pcs[object_name] 125 | object_normal = object_normals[object_name] 126 | 127 | n_robot = robot_pc.shape[0] 128 | n_object = object_pc.shape[0] 129 | 130 | robot_pc = robot_pc.unsqueeze(0).repeat(n_object, 1, 1) 131 | object_pc = object_pc.unsqueeze(0).repeat(n_robot, 1, 1).transpose(0, 1) 132 | object_normal = object_normal.unsqueeze(0).repeat(n_robot, 1, 1).transpose(0, 1) 133 | 134 | object_hand_dist = (robot_pc - object_pc).norm(dim=2) 135 | object_hand_align = ((robot_pc - object_pc) * object_normal).sum(dim=2) 136 | object_hand_align /= (object_hand_dist + 1e-5) 137 | 138 | object_hand_align_dist = object_hand_dist * torch.exp(1 - object_hand_align) 139 | contact_dist = torch.sqrt(object_hand_align_dist.min(dim=1)[0]) 140 | contact_value_current = 1 - 2 * (torch.sigmoid(10 * contact_dist) - 0.5) 141 | heapmap = contact_value_current.unsqueeze(-1) 142 | 143 | metadata.append((heapmap, q[q_idx], object_name, robot_name)) 144 | else: 145 | metadata.append((q[q_idx], object_name, robot_name)) 146 | 147 | info[robot_name] = { 148 | 'robot_name': robot_name, 149 | 'num_total': num_total, 150 | 'num_upper_object': num_upper_object, 151 | 'num_per_object': num_per_object 152 | } 153 | 154 | dataset = { 155 | 'info': info, 156 | 'metadata': metadata 157 | } 158 | if with_heatmap: 159 | torch.save(dataset, os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset_heatmap.pt')) 160 | torch.save(object_pcs, os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/object_point_clouds.pt')) 161 | else: 162 | torch.save(dataset, 
os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt')) 163 | 164 | print("Post process done!") 165 | 166 | 167 | if __name__ == '__main__': 168 | warnings.simplefilter(action='ignore', category=FutureWarning) 169 | 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--gpu_list', # input format like '--gpu_list 0,1,2,3,4,5,6,7' 172 | default=['0', '1', '2', '3', '4', '5', '6', '7'], 173 | type=lambda string: string.split(',')) 174 | parser.add_argument('--print_info', action='store_true') 175 | parser.add_argument('--post_process', action='store_true') 176 | parser.add_argument('--with_heatmap', action='store_true') 177 | args = parser.parse_args() 178 | 179 | assert not (args.print_info and args.post_process) 180 | if args.print_info: 181 | # dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 182 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 183 | info = torch.load(dataset_path, map_location=torch.device('cpu'))['info'] 184 | if 'cmap_func' in info: 185 | print(f"cmap_func: {info['cmap_func']}") 186 | del info['cmap_func'] 187 | 188 | for robot_name in info.keys(): 189 | print(f"********************************") 190 | print(f"robot_name: {info[robot_name]['robot_name']}") 191 | print(f"num_total: {info[robot_name]['num_total']}") 192 | print(f"num_upper_object: {info[robot_name]['num_upper_object']}") 193 | print(f"num_per_object: {len(info[robot_name]['num_per_object'])}") 194 | for object_name in sorted(info[robot_name]['num_per_object'].keys()): 195 | print(f" {object_name}: {info[robot_name]['num_per_object'][object_name]}") 196 | print(f"********************************") 197 | elif args.post_process: 198 | post_process(args.with_heatmap) 199 | else: 200 | filter_dataset(args.gpu_list) 201 | -------------------------------------------------------------------------------- /data_utils/generate_pc.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import time 5 | import viser 6 | import torch 7 | import trimesh 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.hand_model import create_hand_model 13 | 14 | 15 | def generate_object_pc(args): 16 | """ object/{contactdb, ycb}/.pt: (num_points, 6), point xyz + normal """ 17 | for dataset_type in ['contactdb', 'ycb']: 18 | input_dir = str(os.path.join(ROOT_DIR, args.object_source_path, dataset_type)) 19 | output_dir = str(os.path.join(ROOT_DIR, args.save_path, 'object', dataset_type)) 20 | os.makedirs(output_dir, exist_ok=True) 21 | 22 | for object_name in os.listdir(input_dir): 23 | if not os.path.isdir(os.path.join(input_dir, object_name)): # skip json file 24 | continue 25 | print(f'Processing {dataset_type}/{object_name}...') 26 | mesh_path = os.path.join(input_dir, object_name, f'{object_name}.stl') 27 | mesh = trimesh.load_mesh(mesh_path) 28 | object_pc, face_indices = mesh.sample(args.num_points, return_index=True) 29 | object_pc = torch.tensor(object_pc, dtype=torch.float32) 30 | normals = torch.tensor(mesh.face_normals[face_indices], dtype=torch.float32) 31 | object_pc_normals = torch.cat([object_pc, normals], dim=-1) 32 | torch.save(object_pc_normals, os.path.join(output_dir, f'{object_name}.pt')) 33 | 34 | print("\nGenerating object point cloud finished.") 35 | 36 | 37 | def generate_robot_pc(args): 38 | output_dir = str(os.path.join(ROOT_DIR, args.save_path, 'robot')) 39 | output_path = str(os.path.join(output_dir, f'{args.robot_name}.pt')) 40 | os.makedirs(output_dir, exist_ok=True) 41 | 42 | hand = create_hand_model(args.robot_name, torch.device('cpu'), args.num_points) 43 | links_pc = hand.vertices 44 | sampled_pc, sampled_pc_index = hand.get_sampled_pc(num_points=args.num_points) 45 | 46 | filtered_links_pc = {} 47 | for link_index, 
(link_name, points) in enumerate(links_pc.items()): 48 | mask = [i % args.num_points for i in sampled_pc_index 49 | if link_index * args.num_points <= i < (link_index + 1) * args.num_points] 50 | links_pc[link_name] = torch.tensor(points, dtype=torch.float32) 51 | filtered_links_pc[link_name] = torch.tensor(points[mask], dtype=torch.float32) 52 | print(f"[{link_name}] original shape: {links_pc[link_name].shape}, filtered shape: {filtered_links_pc[link_name].shape}") 53 | 54 | data = { 55 | 'original': links_pc, 56 | 'filtered': filtered_links_pc 57 | } 58 | torch.save(data, output_path) 59 | print("\nGenerating robot point cloud finished.") 60 | 61 | server = viser.ViserServer(host='127.0.0.1', port=8080) 62 | server.scene.add_point_cloud( 63 | 'point cloud', 64 | sampled_pc[:, :3].numpy(), 65 | point_size=0.001, 66 | point_shape="circle", 67 | colors=(0, 0, 200) 68 | ) 69 | while True: 70 | time.sleep(1) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--type', default='robot', type=str) 76 | parser.add_argument('--save_path', default='data/PointCloud/', type=str) 77 | parser.add_argument('--num_points', default=512, type=int) 78 | # for object pc generation 79 | parser.add_argument('--object_source_path', default='data/data_urdf/object', type=str) 80 | # for robot pc generation 81 | parser.add_argument('--robot_name', default='shadowhand', type=str) 82 | args = parser.parse_args() 83 | 84 | if args.type == 'robot': 85 | generate_robot_pc(args) 86 | elif args.type == 'object': 87 | generate_object_pc(args) 88 | else: 89 | raise NotImplementedError 90 | -------------------------------------------------------------------------------- /data_utils/removed_links.json: -------------------------------------------------------------------------------- 1 | { 2 | "allegro": [], 3 | "barrett": [], 4 | "ezgripper": ["base_link"], 5 | "robotiq_3finger": [], 6 | "shadowhand": ["forearm", "wrist", "ffknuckle", 
"mfknuckle", "rfknuckle", "lfknuckle", "thbase"], 7 | "leaphand": [] 8 | } 9 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def knn(x, k): 7 | inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x) 8 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 9 | pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous() 10 | 11 | idx = pairwise_distance.topk(k=k, dim=-1)[1] 12 | return idx 13 | 14 | 15 | def get_graph_feature(x, k=20): 16 | idx = knn(x, k=k) 17 | batch_size, num_points, _ = idx.size() 18 | _, num_dims, _ = x.size() 19 | 20 | idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points 21 | idx = idx + idx_base 22 | idx = idx.view(-1) 23 | 24 | x = x.transpose(2, 1).contiguous() 25 | feature = x.view(batch_size * num_points, -1)[idx, :] 26 | feature = feature.view(batch_size, num_points, k, num_dims) 27 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 28 | 29 | feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2).contiguous() 30 | 31 | return feature 32 | 33 | 34 | class Encoder(nn.Module): 35 | """ 36 | The implementation is based on the DGCNN model 37 | (https://github.com/WangYueFt/dgcnn/blob/f765b469a67730658ba554e97dc11723a7bab628/pytorch/model.py#L88), 38 | and https://github.com/r-pad/taxpose/blob/0c4298fa0486fd09e63bf24d618a579b66ba0f18/third_party/dcp/model.py#L282. 39 | 40 | Further explanation can be found in Appendix F.1 of https://arxiv.org/pdf/2410.01702. 
41 | """ 42 | 43 | def __init__(self, emb_dim=512): 44 | super(Encoder, self).__init__() 45 | 46 | self.conv1 = nn.Sequential( 47 | nn.Conv2d(6, 64, kernel_size=1, bias=False), 48 | nn.BatchNorm2d(64), 49 | nn.LeakyReLU(negative_slope=0.2) 50 | ) 51 | self.conv2 = nn.Sequential( 52 | nn.Conv2d(64, 64, kernel_size=1, bias=False), 53 | nn.BatchNorm2d(64), 54 | nn.LeakyReLU(negative_slope=0.2) 55 | ) 56 | self.conv3 = nn.Sequential( 57 | nn.Conv2d(64, 128, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(128), 59 | nn.LeakyReLU(negative_slope=0.2) 60 | ) 61 | self.conv4 = nn.Sequential( 62 | nn.Conv2d(128, 256, kernel_size=1, bias=False), 63 | nn.BatchNorm2d(256), 64 | nn.LeakyReLU(negative_slope=0.2) 65 | ) 66 | self.conv5 = nn.Sequential( 67 | nn.Conv2d(256, 512, kernel_size=1, bias=False), 68 | nn.BatchNorm2d(512), 69 | nn.LeakyReLU(negative_slope=0.2) 70 | ) 71 | self.conv6 = nn.Sequential( 72 | nn.Conv1d(1536, emb_dim, kernel_size=1, bias=False), 73 | nn.BatchNorm1d(emb_dim), 74 | nn.LeakyReLU(negative_slope=0.2) 75 | ) 76 | 77 | def forward(self, x): 78 | x = x.permute(0, 2, 1) # (B, N, 3) -> (B, 3, N) 79 | B, _, N = x.size() 80 | 81 | x = get_graph_feature(x, k=32) # (B, 6, N, K) 82 | 83 | x = self.conv1(x) # (B, 64, N, K) 84 | x1 = x.max(dim=-1, keepdim=False)[0] # (B, 64, N) 85 | 86 | x = self.conv2(x) # (B, 64, N, K) 87 | x2 = x.max(dim=-1, keepdim=False)[0] # (B, 64, N) 88 | 89 | x = self.conv3(x) # (B, 128, N, K) 90 | x3 = x.max(dim=-1, keepdim=False)[0] # (B, 128, N) 91 | 92 | x = self.conv4(x) # (B, 256, N, K) 93 | x4 = x.max(dim=-1, keepdim=False)[0] # (B, 256, N) 94 | 95 | x = self.conv5(x) # (B, 512, N, K) 96 | x5 = x.max(dim=-1, keepdim=False)[0] # (B, 512, N) 97 | 98 | global_feat = x5.mean(dim=-1, keepdim=True).repeat(1, 1, N) # (B, 512, 1) -> (B, 512, N) 99 | 100 | x = torch.cat((x1, x2, x3, x4, x5, global_feat), dim=1) # (B, 1536, N) 101 | x = self.conv6(x).view(B, -1, N) # (B, 512, N) 102 | 103 | return x.permute(0, 2, 1) # (B, D, N) -> (B, N, D) 
104 | 105 | 106 | class CvaeEncoder(nn.Module): 107 | """ 108 | The implementation is based on the DGCNN model 109 | (https://github.com/WangYueFt/dgcnn/blob/f765b469a67730658ba554e97dc11723a7bab628/pytorch/model.py#L88). 110 | 111 | The only modification made is to enable the input to include additional features. 112 | """ 113 | 114 | def __init__(self, emb_dims, output_channels, feat_dim=0): 115 | super(CvaeEncoder, self).__init__() 116 | self.feat_dim = feat_dim 117 | 118 | self.bn1 = nn.BatchNorm2d(64) 119 | self.bn2 = nn.BatchNorm2d(64) 120 | self.bn3 = nn.BatchNorm2d(128) 121 | self.bn4 = nn.BatchNorm2d(256) 122 | self.bn5 = nn.BatchNorm1d(emb_dims) 123 | 124 | self.conv1 = nn.Sequential( 125 | nn.Conv2d(6 + feat_dim, 64, kernel_size=1, bias=False), 126 | self.bn1, 127 | nn.LeakyReLU(negative_slope=0.2) 128 | ) 129 | self.conv2 = nn.Sequential( 130 | nn.Conv2d(64, 64, kernel_size=1, bias=False), 131 | self.bn2, 132 | nn.LeakyReLU(negative_slope=0.2) 133 | ) 134 | self.conv3 = nn.Sequential( 135 | nn.Conv2d(64,128, kernel_size=1, bias=False), 136 | self.bn3, 137 | nn.LeakyReLU(negative_slope=0.2) 138 | ) 139 | self.conv4 = nn.Sequential( 140 | nn.Conv2d(128, 256, kernel_size=1, bias=False), 141 | self.bn4, 142 | nn.LeakyReLU(negative_slope=0.2) 143 | ) 144 | self.conv5 = nn.Sequential( 145 | nn.Conv1d(512, emb_dims, kernel_size=1, bias=False), 146 | self.bn5, 147 | nn.LeakyReLU(negative_slope=0.2) 148 | ) 149 | self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False) 150 | self.bn6 = nn.BatchNorm1d(512) 151 | self.dp1 = nn.Dropout(p=0.5) 152 | self.linear2 = nn.Linear(512, 256) 153 | self.bn7 = nn.BatchNorm1d(256) 154 | self.dp2 = nn.Dropout(p=0.5) 155 | self.linear3 = nn.Linear(256, output_channels) 156 | 157 | def forward(self, x): 158 | x = x.permute(0, 2, 1) 159 | B, D, N = x.size() 160 | x_k = get_graph_feature(x[:, :3, :]) # B, 6, N, K 161 | x_feat = x[:, 3:, :].unsqueeze(-1).repeat(1, 1, 1, 20) if self.feat_dim != 0 else None # K = 20 162 | x = 
torch.cat([x_k, x_feat], dim=1) if self.feat_dim != 0 else x_k # (B, 6 + feat_dim, N, K) 163 | 164 | x = self.conv1(x) 165 | x1 = x.max(dim=-1, keepdim=True)[0] 166 | 167 | x = self.conv2(x) 168 | x2 = x.max(dim=-1, keepdim=True)[0] 169 | 170 | x = self.conv3(x) 171 | x3 = x.max(dim=-1, keepdim=True)[0] 172 | 173 | x = self.conv4(x) 174 | x4 = x.max(dim=-1, keepdim=True)[0] 175 | 176 | x = torch.cat((x1, x2, x3, x4), dim=1)[..., 0] 177 | 178 | x = self.conv5(x) 179 | x1 = F.adaptive_max_pool1d(x, 1).view(B, -1) 180 | x2 = F.adaptive_avg_pool1d(x, 1).view(B, -1) 181 | x = torch.cat((x1, x2), 1) 182 | 183 | x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) 184 | x = self.dp1(x) 185 | x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) 186 | x = self.dp2(x) 187 | x = self.linear3(x) 188 | 189 | return x # (B, output_channels) 190 | -------------------------------------------------------------------------------- /model/latent_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResnetBlockFC(nn.Module): 7 | """ 8 | Fully connected ResNet Block class. 
9 | Args: 10 | size_in (int): input dimension 11 | size_out (int): output dimension 12 | size_h (int): hidden dimension 13 | """ 14 | def __init__(self, size_in, size_out=None, size_h=None): 15 | super().__init__() 16 | if size_out is None: 17 | size_out = size_in 18 | 19 | if size_h is None: 20 | size_h = min(size_in, size_out) 21 | 22 | self.size_in = size_in 23 | self.size_h = size_h 24 | self.size_out = size_out 25 | 26 | self.fc_0 = nn.Linear(size_in, size_h) 27 | self.fc_1 = nn.Linear(size_h, size_out) 28 | self.actvn = nn.ReLU() 29 | 30 | if size_in == size_out: 31 | self.shortcut = None 32 | else: 33 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 34 | nn.init.zeros_(self.fc_1.weight) 35 | 36 | def forward(self, x, final_nl=False): 37 | net = self.fc_0(self.actvn(x)) 38 | dx = self.fc_1(self.actvn(net)) 39 | if self.shortcut is not None: 40 | x_s = self.shortcut(x) 41 | else: 42 | x_s = x 43 | x_out = x_s + dx 44 | if final_nl: 45 | return F.leaky_relu(x_out, negative_slope=0.2) 46 | return x_out 47 | 48 | 49 | class LatentEncoder(nn.Module): 50 | def __init__(self, in_dim, dim, out_dim): 51 | super().__init__() 52 | self.block = ResnetBlockFC(size_in=in_dim, size_out=dim, size_h=dim) 53 | self.fc_mu = nn.Linear(dim, out_dim) 54 | self.fc_logvar = nn.Linear(dim, out_dim) 55 | 56 | def forward(self, x): 57 | x = self.block(x, final_nl=True) 58 | return self.fc_mu(x), self.fc_logvar(x) 59 | -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced from https://github.com/r-pad/taxpose. 3 | 4 | The only modification made is adjusting the relative imports to enhance the clarity of the file structure. 
5 | """ 6 | 7 | from typing import Callable, List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class MLPKernel(nn.Module): 15 | def __init__(self, feature_dim): 16 | super().__init__() 17 | self.feature_dim = feature_dim 18 | self.mlp = MLP(2 * feature_dim, [300, 100, 1]) 19 | 20 | def forward(self, x1, x2): 21 | v1 = self.mlp(torch.cat([x1, x2], dim=-1)) 22 | v2 = self.mlp(torch.cat([x2, x1], dim=-1)) 23 | return F.softplus((v1 + v2) / 2) 24 | 25 | 26 | class MLP(nn.Sequential): 27 | """ 28 | This block implements the multi-layer perceptron (MLP) module. 29 | 30 | Args: 31 | in_channels (int): Number of channels of the input 32 | hidden_channels (List[int]): List of the hidden channel dimensions 33 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` 34 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` 35 | inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. 36 | Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. 37 | bias (bool): Whether to use bias in the linear layer. Default ``True`` 38 | dropout (float): The probability for the dropout layer. 
Default: 0.0 39 | """ 40 | 41 | def __init__( 42 | self, 43 | in_channels: int, 44 | hidden_channels: List[int], 45 | norm_layer: Optional[Callable[..., torch.nn.Module]] = None, 46 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, 47 | inplace: Optional[bool] = None, 48 | bias: bool = True, 49 | dropout: float = 0.0, 50 | ): 51 | # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: 52 | # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py 53 | params = {} if inplace is None else {"inplace": inplace} 54 | 55 | layers: List[nn.Module] = [] 56 | in_dim = in_channels 57 | for hidden_dim in hidden_channels[:-1]: 58 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) 59 | if norm_layer is not None: 60 | layers.append(norm_layer(hidden_dim)) 61 | if activation_layer is not None: 62 | layers.append(activation_layer(**params)) 63 | layers.append(torch.nn.Dropout(dropout, **params)) 64 | in_dim = hidden_dim 65 | 66 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) 67 | layers.append(torch.nn.Dropout(dropout, **params)) 68 | 69 | super().__init__(*layers) 70 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | import pytorch_lightning as pl 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.se3_transform import compute_link_pose 13 | from utils.multilateration import multilateration 14 | from utils.func_utils import calculate_depth 15 | from utils.pretrain_utils import dist2weight, infonce_loss, mean_order 16 | 17 | 18 | class TrainingModule(pl.LightningModule): 19 | def __init__(self, cfg, 
network, epoch_idx): 20 | super().__init__() 21 | self.cfg = cfg 22 | self.network = network 23 | self.epoch_idx = epoch_idx 24 | 25 | self.lr = cfg.lr 26 | 27 | os.makedirs(self.cfg.save_dir, exist_ok=True) 28 | 29 | def ddp_print(self, *args, **kwargs): 30 | if self.global_rank == 0: 31 | print(*args, **kwargs) 32 | 33 | def training_step(self, batch, batch_idx): 34 | object_name = batch['object_name'] 35 | robot_links_pc = batch['robot_links_pc'] 36 | robot_pc_initial = batch['robot_pc_initial'] 37 | robot_pc_target = batch['robot_pc_target'] 38 | object_pc = batch['object_pc'] 39 | dro_gt = batch['dro_gt'] 40 | 41 | network_output = self.network( 42 | robot_pc_initial, 43 | object_pc, 44 | robot_pc_target 45 | ) 46 | 47 | dro = network_output['dro'] 48 | mu = network_output['mu'] 49 | logvar = network_output['logvar'] 50 | 51 | mlat_pc = multilateration(dro, object_pc) 52 | transforms, transformed_pc = compute_link_pose(robot_links_pc, mlat_pc) 53 | 54 | loss = 0. 55 | 56 | if self.cfg.loss_kl: 57 | loss_kl = - 0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp()) 58 | loss_kl = torch.sqrt(1 + loss_kl ** 2) - 1 59 | loss_kl = loss_kl * self.cfg.loss_kl_weight 60 | self.log('loss_kl', loss_kl, prog_bar=True) 61 | loss += loss_kl 62 | 63 | if self.cfg.loss_r: 64 | loss_r = nn.L1Loss()(dro, dro_gt) 65 | loss_r = loss_r * self.cfg.loss_r_weight 66 | self.log('loss_r', loss_r, prog_bar=True) 67 | loss += loss_r 68 | 69 | if self.cfg.loss_se3: 70 | transforms_gt, transformed_pc_gt = compute_link_pose(robot_links_pc, robot_pc_target) 71 | loss_se3 = 0. 72 | for idx in range(len(transforms)): # iteration over batch 73 | transform = transforms[idx] 74 | transform_gt = transforms_gt[idx] 75 | loss_se3_item = 0. 
76 | for link_name in transform: 77 | rel_translation = transform[link_name][:3, 3] - transform_gt[link_name][:3, 3] 78 | rel_rotation = transform[link_name][:3, :3].mT @ transform_gt[link_name][:3, :3] 79 | rel_rotation_trace = torch.clamp(torch.trace(rel_rotation), -1, 3) 80 | rel_angle = torch.acos((rel_rotation_trace - 1) / 2) 81 | loss_se3_item += torch.mean(torch.norm(rel_translation, dim=-1) + rel_angle) 82 | loss_se3 += loss_se3_item / len(transform) 83 | loss_se3 = loss_se3 / len(transforms) * self.cfg.loss_se3_weight 84 | self.log('loss_se3', loss_se3, prog_bar=True) 85 | loss += loss_se3 86 | 87 | if self.cfg.loss_depth: 88 | loss_depth = calculate_depth(transformed_pc, object_name) 89 | loss_depth = loss_depth * self.cfg.loss_depth_weight 90 | self.log('loss_depth', loss_depth, prog_bar=True) 91 | loss += loss_depth 92 | 93 | self.log("loss", loss, prog_bar=True) 94 | return loss 95 | 96 | def on_after_backward(self): 97 | """ 98 | For unknown reasons, there is a small chance that the gradients in CVAE may become NaN during backpropagation. 99 | In such cases, skip the iteration. 
100 | """ 101 | for param in self.network.parameters(): 102 | if param.grad is not None and torch.isnan(param.grad).any(): 103 | param.grad = None 104 | 105 | def on_train_epoch_end(self): 106 | self.epoch_idx += 1 107 | self.ddp_print(f"Training epoch: {self.epoch_idx}") 108 | if self.epoch_idx % self.cfg.save_every_n_epoch == 0: 109 | self.ddp_print(f"Saving state_dict at epoch: {self.epoch_idx}") 110 | torch.save(self.network.state_dict(), f'{self.cfg.save_dir}/epoch_{self.epoch_idx}.pth') 111 | 112 | def configure_optimizers(self): 113 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 114 | return optimizer 115 | 116 | 117 | class PretrainingModule(pl.LightningModule): 118 | def __init__(self, cfg, encoder): 119 | super().__init__() 120 | self.cfg = cfg 121 | self.encoder = encoder 122 | 123 | self.lr = cfg.lr 124 | self.temperature = cfg.temperature 125 | 126 | self.epoch_idx = 0 127 | os.makedirs(self.cfg.save_dir, exist_ok=True) 128 | 129 | def ddp_print(self, *args, **kwargs): 130 | if self.global_rank == 0: 131 | print(*args, **kwargs) 132 | 133 | def training_step(self, batch, batch_idx): 134 | robot_pc_1 = batch['robot_pc_1'] 135 | robot_pc_2 = batch['robot_pc_2'] 136 | 137 | robot_pc_1 = robot_pc_1 - robot_pc_1.mean(dim=1, keepdims=True) 138 | robot_pc_2 = robot_pc_2 - robot_pc_2.mean(dim=1, keepdims=True) 139 | 140 | phi_1 = self.encoder(robot_pc_1) # (B, N, 3) -> (B, N, D) 141 | phi_2 = self.encoder(robot_pc_2) # (B, N, 3) -> (B, N, D) 142 | 143 | weights = dist2weight(robot_pc_1, func=lambda x: torch.tanh(10 * x)) 144 | loss, similarity = infonce_loss( 145 | phi_1, phi_2, weights=weights, temperature=self.temperature 146 | ) 147 | mean_order_error = mean_order(similarity) 148 | 149 | self.log("mean_order", mean_order_error) 150 | self.log("loss", loss, prog_bar=True) 151 | 152 | return loss 153 | 154 | def on_train_epoch_end(self): 155 | self.epoch_idx += 1 156 | self.ddp_print(f"Training epoch: {self.epoch_idx}") 157 | if self.epoch_idx 
% self.cfg.save_every_n_epoch == 0: 158 | self.ddp_print(f"Saving state_dict at epoch: {self.epoch_idx}") 159 | torch.save(self.encoder.state_dict(), f'{self.cfg.save_dir}/epoch_{self.epoch_idx}.pth') 160 | 161 | 162 | def configure_optimizers(self): 163 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 164 | return optimizer 165 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(ROOT_DIR) 8 | 9 | from model.encoder import Encoder, CvaeEncoder 10 | from model.transformer import Transformer 11 | from model.latent_encoder import LatentEncoder 12 | from model.mlp import MLPKernel 13 | 14 | 15 | def create_encoder_network(emb_dim, pretrain=None, device=torch.device('cpu')) -> nn.Module: 16 | encoder = Encoder(emb_dim=emb_dim) 17 | if pretrain is not None: 18 | print(f"******** Load embedding network pretrain from <{pretrain}> ********") 19 | encoder.load_state_dict( 20 | torch.load( 21 | os.path.join(ROOT_DIR, f"ckpt/pretrain/{pretrain}"), 22 | map_location=device 23 | ) 24 | ) 25 | return encoder 26 | 27 | 28 | class Network(nn.Module): 29 | def __init__(self, cfg, mode): 30 | super(Network, self).__init__() 31 | self.cfg = cfg 32 | self.mode = mode 33 | 34 | self.encoder_robot = create_encoder_network(emb_dim=cfg.emb_dim, pretrain=cfg.pretrain) 35 | self.encoder_object = create_encoder_network(emb_dim=cfg.emb_dim) 36 | 37 | self.transformer_robot = Transformer(emb_dim=cfg.emb_dim) 38 | self.transformer_object = Transformer(emb_dim=cfg.emb_dim) 39 | 40 | # CVAE encoder 41 | self.point_encoder = CvaeEncoder(emb_dims=cfg.emb_dim, output_channels=2 * cfg.latent_dim, feat_dim=cfg.emb_dim) 42 | self.latent_encoder = LatentEncoder(in_dim=2*cfg.latent_dim, 
dim=4*cfg.latent_dim, out_dim=cfg.latent_dim) 43 | 44 | self.kernel = MLPKernel(cfg.emb_dim + cfg.latent_dim) 45 | 46 | def forward(self, robot_pc, object_pc, target_pc=None): 47 | if self.cfg.center_pc: # zero-mean the robot point cloud 48 | robot_pc = robot_pc - robot_pc.mean(dim=1, keepdim=True) 49 | 50 | # point cloud encoder 51 | robot_embedding = self.encoder_robot(robot_pc) 52 | object_embedding = self.encoder_object(object_pc) 53 | 54 | if self.cfg.pretrain is not None: 55 | robot_embedding = robot_embedding.detach() 56 | 57 | # point cloud transformer 58 | transformer_robot_outputs = self.transformer_robot(robot_embedding, object_embedding) 59 | transformer_object_outputs = self.transformer_object(object_embedding, robot_embedding) 60 | robot_embedding_tf = robot_embedding + transformer_robot_outputs["src_embedding"] 61 | object_embedding_tf = object_embedding + transformer_object_outputs["src_embedding"] 62 | 63 | # CVAE encoder 64 | if self.mode == 'train': 65 | grasp_pc = torch.cat([target_pc, object_pc], dim=1) 66 | grasp_emb = torch.cat([robot_embedding_tf, object_embedding_tf], dim=1) 67 | latent = self.point_encoder(torch.cat([grasp_pc, grasp_emb], -1)) 68 | mu, logvar = self.latent_encoder(latent) 69 | z_dist = torch.distributions.normal.Normal(mu, torch.exp(0.5 * logvar)) 70 | z = z_dist.rsample() # (B, latent_dim) 71 | else: 72 | mu, logvar = None, None 73 | z = torch.randn(robot_pc.shape[0], self.cfg.latent_dim).to(robot_pc.device) 74 | z = z.unsqueeze(dim=1).repeat(1, robot_embedding_tf.shape[1], 1) # (B, N, latent_dim) 75 | 76 | Phi_A = torch.cat([robot_embedding_tf, z], dim=-1) # (B, N, emb_dim + latent_dim) 77 | Phi_B = torch.cat([object_embedding_tf, z], dim=-1) # (B, N, emb_dim + latent_dim) 78 | 79 | # Compute D(R,O) matrix 80 | if self.cfg.block_computing: # use matrix block computation to save GPU memory 81 | B, N, D = Phi_A.shape 82 | block_num = 4 # experimental result, reaching a balance between speed and GPU memory 83 | N_block = N 
// block_num 84 | assert N % N_block == 0, 'Unable to perform block computation.' 85 | 86 | dro = torch.zeros([B, N, N], dtype=torch.float32, device=Phi_A.device) 87 | for A_i in range(block_num): 88 | Phi_A_block = Phi_A[:, A_i * N_block: (A_i + 1) * N_block, :] # (B, N_block, D) 89 | for B_i in range(block_num): 90 | Phi_B_block = Phi_B[:, B_i * N_block: (B_i + 1) * N_block, :] # (B, N_block, D) 91 | 92 | Phi_A_r = Phi_A_block.unsqueeze(2).repeat(1, 1, N_block, 1).reshape(B * N_block * N_block, D) 93 | Phi_B_r = Phi_B_block.unsqueeze(1).repeat(1, N_block, 1, 1).reshape(B * N_block * N_block, D) 94 | 95 | dro[:, A_i * N_block: (A_i + 1) * N_block, B_i * N_block: (B_i + 1) * N_block] \ 96 | = self.kernel(Phi_A_r, Phi_B_r).reshape(B, N_block, N_block) 97 | else: 98 | Phi_A_r = ( 99 | Phi_A.unsqueeze(2) 100 | .repeat(1, 1, Phi_A.shape[1], 1) 101 | .reshape(Phi_A.shape[0] * Phi_A.shape[1] * Phi_A.shape[1], Phi_A.shape[2]) 102 | ) 103 | Phi_B_r = ( 104 | Phi_B.unsqueeze(1) 105 | .repeat(1, Phi_B.shape[1], 1, 1) 106 | .reshape(Phi_B.shape[0] * Phi_B.shape[1] * Phi_B.shape[1], Phi_B.shape[2]) 107 | ) 108 | dro = self.kernel(Phi_A_r, Phi_B_r).reshape(Phi_A.shape[0], Phi_A.shape[1], Phi_B.shape[1]) 109 | 110 | outputs = { 111 | 'dro': dro, 112 | 'mu': mu, 113 | 'logvar': logvar, 114 | } 115 | return outputs 116 | 117 | 118 | def create_network(cfg, mode): 119 | network = Network( 120 | cfg=cfg, 121 | mode=mode 122 | ) 123 | return network 124 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced from https://github.com/r-pad/taxpose, which builds upon 3 | the transformer model from https://github.com/WangYueFt/dcp/blob/master/model.py. 4 | 5 | The only modification made is adjusting the relative imports to enhance the clarity of the file structure. 
6 | """ 7 | 8 | import math 9 | import copy 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class Transformer(nn.Module): 16 | def __init__( 17 | self, 18 | emb_dim=512, 19 | n_blocks=1, 20 | dropout=0.0, 21 | ff_dims=1024, 22 | n_heads=4, 23 | bidirectional=False, 24 | ): 25 | super(Transformer, self).__init__() 26 | self.emb_dim = emb_dim 27 | self.N = n_blocks 28 | self.dropout = dropout 29 | self.ff_dims = ff_dims 30 | self.n_heads = n_heads 31 | self.bidirectional = bidirectional 32 | c = copy.deepcopy 33 | attn = MultiHeadedAttention(self.n_heads, self.emb_dim) 34 | ff = PositionwiseFeedForward(self.emb_dim, self.ff_dims, self.dropout) 35 | self.model = EncoderDecoder( 36 | Encoder( 37 | EncoderLayer(self.emb_dim, c(attn), c(ff), self.dropout), 38 | self.N 39 | ), 40 | Decoder( 41 | DecoderLayer(self.emb_dim, c(attn), c(attn), c(ff), self.dropout), 42 | self.N, 43 | ), 44 | nn.Sequential(), 45 | nn.Sequential(), 46 | nn.Sequential(), 47 | ) 48 | 49 | def forward(self, *input): 50 | src = input[0] 51 | tgt = input[1] 52 | src_embedding = self.model(tgt, src, None, None) 53 | src_attn = self.model.decoder.layers[-1].src_attn.attn 54 | 55 | outputs = {"src_embedding": src_embedding, "src_attn": src_attn} 56 | 57 | if self.bidirectional: 58 | tgt_embedding = ( 59 | self.model(src, tgt, None, None) 60 | ) 61 | tgt_attn = self.model.decoder.layers[-1].src_attn.attn 62 | 63 | outputs = { 64 | **outputs, 65 | "tgt_embedding": tgt_embedding, 66 | "tgt_attn": tgt_attn, 67 | } 68 | 69 | return outputs 70 | 71 | def clones(module, N): 72 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 73 | 74 | def attention(query, key, value, mask=None, dropout=None): 75 | d_k = query.size(-1) 76 | scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / math.sqrt(d_k) 77 | if mask is not None: 78 | scores = scores.masked_fill(mask == 0, -1e9) 79 | p_attn = F.softmax(scores, dim=-1) 80 | return 
torch.matmul(p_attn, value), p_attn 81 | 82 | class LayerNorm(nn.Module): 83 | def __init__(self, features, eps=1e-6): 84 | super(LayerNorm, self).__init__() 85 | self.a_2 = nn.Parameter(torch.ones(features)) 86 | self.b_2 = nn.Parameter(torch.zeros(features)) 87 | self.eps = eps 88 | 89 | def forward(self, x): 90 | mean = x.mean(-1, keepdim=True) 91 | std = x.std(-1, keepdim=True) 92 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 93 | 94 | class EncoderDecoder(nn.Module): 95 | """ 96 | A standard Encoder-Decoder architecture. Base for this and many 97 | other models. 98 | """ 99 | 100 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 101 | super(EncoderDecoder, self).__init__() 102 | self.encoder = encoder 103 | self.decoder = decoder 104 | self.src_embed = src_embed 105 | self.tgt_embed = tgt_embed 106 | self.generator = generator 107 | 108 | def forward(self, src, tgt, src_mask, tgt_mask): 109 | """Take in and process masked src and target sequences.""" 110 | return self.decode(self.encode(src, src_mask), src_mask, 111 | tgt, tgt_mask) 112 | 113 | def encode(self, src, src_mask): 114 | return self.encoder(self.src_embed(src), src_mask) 115 | 116 | def decode(self, memory, src_mask, tgt, tgt_mask): 117 | return self.generator(self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)) 118 | 119 | class Decoder(nn.Module): 120 | """Generic N layer decoder with masking.""" 121 | 122 | def __init__(self, layer, N): 123 | super(Decoder, self).__init__() 124 | self.layers = clones(layer, N) 125 | self.norm = LayerNorm(layer.size) 126 | 127 | def forward(self, x, memory, src_mask, tgt_mask): 128 | for layer in self.layers: 129 | x = layer(x, memory, src_mask, tgt_mask) 130 | return self.norm(x) 131 | 132 | class DecoderLayer(nn.Module): 133 | """Decoder is made of self-attn, src-attn, and feed forward (defined below)""" 134 | 135 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 136 | super(DecoderLayer, 
self).__init__() 137 | self.size = size 138 | self.self_attn = self_attn 139 | self.src_attn = src_attn 140 | self.feed_forward = feed_forward 141 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 142 | 143 | def forward(self, x, memory, src_mask, tgt_mask): 144 | """Follow Figure 1 (right) for connections.""" 145 | m = memory 146 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 147 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 148 | return self.sublayer[2](x, self.feed_forward) 149 | 150 | class SublayerConnection(nn.Module): 151 | def __init__(self, size, dropout=None): 152 | super(SublayerConnection, self).__init__() 153 | self.norm = LayerNorm(size) 154 | 155 | def forward(self, x, sublayer): 156 | return x + sublayer(self.norm(x)) 157 | 158 | class Encoder(nn.Module): 159 | def __init__(self, layer, N): 160 | super(Encoder, self).__init__() 161 | self.layers = clones(layer, N) 162 | self.norm = LayerNorm(layer.size) 163 | 164 | def forward(self, x, mask): 165 | for layer in self.layers: 166 | x = layer(x, mask) 167 | return self.norm(x) 168 | 169 | class EncoderLayer(nn.Module): 170 | def __init__(self, size, self_attn, feed_forward, dropout): 171 | super(EncoderLayer, self).__init__() 172 | self.self_attn = self_attn 173 | self.feed_forward = feed_forward 174 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 175 | self.size = size 176 | 177 | def forward(self, x, mask): 178 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 179 | return self.sublayer[1](x, self.feed_forward) 180 | 181 | class MultiHeadedAttention(nn.Module): 182 | def __init__(self, h, d_model, dropout=0.1): 183 | """Take in model size and number of heads.""" 184 | super(MultiHeadedAttention, self).__init__() 185 | assert d_model % h == 0 186 | # We assume d_v always equals d_k 187 | self.d_k = d_model // h 188 | self.h = h 189 | self.linears = clones(nn.Linear(d_model, d_model), 4) 190 | self.attn = None 
191 | self.dropout = None 192 | 193 | def forward(self, query, key, value, mask=None): 194 | """Implements Figure 2""" 195 | if mask is not None: 196 | # Same mask applied to all h heads. 197 | mask = mask.unsqueeze(1) 198 | nbatches = query.size(0) 199 | 200 | # 1) Do all the linear projections in batch from d_model => h x d_k 201 | query, key, value = \ 202 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous() 203 | for l, x in zip(self.linears, (query, key, value))] 204 | 205 | # 2) Apply attention on all the projected vectors in batch. 206 | x, self.attn = attention(query, key, value, mask=mask, 207 | dropout=self.dropout) 208 | 209 | # 3) "Concat" using a view and apply a final linear. 210 | x = x.transpose(1, 2).contiguous() \ 211 | .view(nbatches, -1, self.h * self.d_k) 212 | return self.linears[-1](x) 213 | 214 | 215 | class PositionwiseFeedForward(nn.Module): 216 | """Implements FFN equation.""" 217 | 218 | def __init__(self, d_model, d_ff, dropout=0.1): 219 | super(PositionwiseFeedForward, self).__init__() 220 | self.w_1 = nn.Linear(d_model, d_ff) 221 | self.norm = nn.Sequential() # nn.BatchNorm1d(d_ff) 222 | self.w_2 = nn.Linear(d_ff, d_model) 223 | self.dropout = None 224 | 225 | def forward(self, x): 226 | return self.w_2(self.norm(F.relu(self.w_1(x)).transpose(2, 1).contiguous()).transpose(2, 1).contiguous()) 227 | -------------------------------------------------------------------------------- /pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenyuwei2003/DRO-Grasp/b312055b4a20f73ddfeb3ffc8a1a6c80d48bbe31/pipeline.jpg -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | import hydra 5 | from omegaconf import OmegaConf 6 | import torch 7 | import pytorch_lightning as pl 8 | from 
pytorch_lightning.loggers import WandbLogger 9 | 10 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from model.module import PretrainingModule 14 | from model.network import create_encoder_network 15 | from data_utils.PretrainDataset import create_dataloader 16 | 17 | 18 | @hydra.main(version_base="1.2", config_path="configs", config_name="pretrain") 19 | def main(cfg): 20 | print("******************************** [Config] ********************************") 21 | print(OmegaConf.to_yaml(cfg)) 22 | print("******************************** [Config] ********************************") 23 | 24 | pl.seed_everything(cfg.seed) 25 | 26 | logger = WandbLogger( 27 | name=cfg.name, 28 | save_dir=cfg.wandb.save_dir, 29 | project=cfg.wandb.project 30 | ) 31 | trainer = pl.Trainer( 32 | logger=logger, 33 | accelerator='gpu', 34 | devices=cfg.gpu, 35 | log_every_n_steps=cfg.log_every_n_steps, 36 | max_epochs=cfg.training.max_epochs 37 | ) 38 | 39 | dataloader = create_dataloader(cfg.dataset) 40 | encoder = create_encoder_network(cfg.model.emb_dim) 41 | model = PretrainingModule( 42 | cfg=cfg.training, 43 | encoder=encoder 44 | ) 45 | model.train() 46 | 47 | trainer.fit(model, dataloader) 48 | 49 | 50 | if __name__ == "__main__": 51 | torch.set_float32_matmul_precision("high") 52 | torch.autograd.set_detect_anomaly(True) 53 | torch.cuda.empty_cache() 54 | torch.multiprocessing.set_sharing_strategy("file_system") 55 | warnings.simplefilter(action='ignore', category=FutureWarning) 56 | main() 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohappyeyeballs==2.4.0 3 | aiohttp==3.10.5 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | arm_pytorch_utilities==0.4.3 7 | async-timeout==4.0.3 8 | attrs==24.2.0 9 | beautifulsoup4==4.12.3 10 | cachetools==5.5.0 11 | 
certifi==2024.8.30 12 | charset-normalizer==3.3.2 13 | click==8.1.7 14 | contourpy==1.1.1 15 | cvxpy==1.5.2 16 | cvxpylayers==0.1.6 17 | cycler==0.12.1 18 | docker-pycreds==0.4.0 19 | filelock==3.15.4 20 | fonttools==4.53.1 21 | frozenlist==1.4.1 22 | fsspec==2024.6.1 23 | gdown==5.2.0 24 | gitdb==4.0.11 25 | GitPython==3.1.43 26 | h5py==3.11.0 27 | hydra-core==1.3.2 28 | idna==3.8 29 | importlib_resources==6.4.4 30 | Jinja2==3.1.4 31 | kiwisolver==1.4.5 32 | lightning-utilities==0.11.6 33 | lxml==5.3.0 34 | MarkupSafe==2.1.5 35 | matplotlib==3.7.5 36 | mpmath==1.3.0 37 | multidict==6.0.5 38 | networkx==3.1 39 | numpy==1.24.4 40 | nvidia-cublas-cu12==12.1.3.1 41 | nvidia-cuda-cupti-cu12==12.1.105 42 | nvidia-cuda-nvrtc-cu12==12.1.105 43 | nvidia-cuda-runtime-cu12==12.1.105 44 | nvidia-cudnn-cu12==9.1.0.70 45 | nvidia-cufft-cu12==11.0.2.54 46 | nvidia-curand-cu12==10.3.2.106 47 | nvidia-cusolver-cu12==11.4.5.107 48 | nvidia-cusparse-cu12==12.1.0.106 49 | nvidia-ml-py==12.535.161 50 | nvidia-nccl-cu12==2.20.5 51 | nvidia-nvjitlink-cu12==12.6.68 52 | nvidia-nvtx-cu12==12.1.105 53 | nvitop==1.3.2 54 | omegaconf==2.3.0 55 | packaging==24.1 56 | pillow==10.4.0 57 | platformdirs==4.2.2 58 | protobuf==5.28.0 59 | psutil==6.0.0 60 | pyparsing==3.1.4 61 | PySocks==1.7.1 62 | python-dateutil==2.9.0.post0 63 | pytorch-lightning==2.4.0 64 | pytorch-seed==0.2.0 65 | pytorch_kinematics==0.7.4 66 | PyYAML==6.0.2 67 | requests==2.32.3 68 | scipy==1.10.1 69 | sentry-sdk==2.13.0 70 | setproctitle==1.3.3 71 | six==1.16.0 72 | smmap==5.0.1 73 | soupsieve==2.6 74 | sympy==1.13.2 75 | termcolor==2.4.0 76 | torch==2.4.1 77 | torchmetrics==1.4.1 78 | torchvision==0.19.1 79 | tqdm==4.66.5 80 | trimesh==4.4.8 81 | triton==3.0.0 82 | typing_extensions==4.12.2 83 | urllib3==2.2.2 84 | viser==0.2.1 85 | wandb==0.17.8 86 | yarl==1.9.6 87 | zipp==3.20.1 88 | -------------------------------------------------------------------------------- /scripts/download_ckpt.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ckpt 4 | 5 | cd ckpt 6 | wget https://github.com/zhenyuwei2003/DRO-Grasp/releases/download/v1.0/ckpt.zip 7 | unzip ckpt.zip 8 | rm ckpt.zip 9 | cd .. 10 | 11 | echo "Download checkpoint models finished!" 12 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data 4 | 5 | cd data 6 | wget https://github.com/zhenyuwei2003/DRO-Grasp/releases/download/v1.0/data.zip 7 | unzip data.zip 8 | rm data.zip 9 | cd .. 10 | 11 | echo "Download data finished!" 12 | -------------------------------------------------------------------------------- /scripts/example_isaac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import warnings 5 | import numpy as np 6 | from tqdm import tqdm 7 | from termcolor import cprint 8 | from types import SimpleNamespace 9 | 10 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from model.network import create_network 14 | from data_utils.CMapDataset import create_dataloader 15 | from utils.multilateration import multilateration 16 | from utils.se3_transform import compute_link_pose 17 | from utils.optimization import * 18 | from utils.hand_model import create_hand_model 19 | from validation.validate_utils import validate_isaac 20 | 21 | 22 | gpu = 0 23 | device = torch.device(f'cuda:{gpu}') 24 | ckpt_name = 'model_3robots' # 'model_3robots_partial', 'model_allegro', 'model_barrett', 'model_shadowhand' 25 | batch_size = 10 26 | 27 | 28 | def main(): 29 | network = create_network( 30 | SimpleNamespace(**{ 31 | 'emb_dim': 512, 32 | 'latent_dim': 64, 33 | 'pretrain': None, 34 | 'center_pc': True, 35 | 'block_computing': True 36 | 
}), 37 | mode='validate' 38 | ).to(device) 39 | network.load_state_dict(torch.load(f"ckpt/model/{ckpt_name}.pth", map_location=device)) 40 | network.eval() 41 | dataloader = create_dataloader( 42 | SimpleNamespace(**{ 43 | 'batch_size': batch_size, 44 | 'robot_names': ['barrett', 'allegro', 'shadowhand'], 45 | 'debug_object_names': None, 46 | 'object_pc_type': 'random' if ckpt_name != 'model_3robots_partial' else 'partial', 47 | 'num_workers': 16 48 | }), 49 | is_train=False 50 | ) 51 | 52 | global_robot_name = None 53 | hand = None 54 | all_success_q = [] 55 | time_list = [] 56 | success_num = 0 57 | total_num = 0 58 | for i, data in enumerate(dataloader): 59 | robot_name = data['robot_name'] 60 | object_name = data['object_name'] 61 | 62 | if robot_name != global_robot_name: 63 | if global_robot_name is not None: 64 | all_success_q = torch.cat(all_success_q, dim=0) 65 | diversity_std = torch.std(all_success_q, dim=0).mean() 66 | times = np.array(time_list) 67 | time_mean = np.mean(times) 68 | time_std = np.std(times) 69 | 70 | success_rate = success_num / total_num * 100 71 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 72 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 73 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 74 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 75 | 76 | all_success_q = [] 77 | time_list = [] 78 | success_num = 0 79 | total_num = 0 80 | hand = create_hand_model(robot_name, device) 81 | global_robot_name = robot_name 82 | 83 | predict_q_list = [] 84 | for data_idx in tqdm(range(batch_size)): 85 | initial_q = data['initial_q'][data_idx: data_idx + 1].to(device) 86 | robot_pc = data['robot_pc'][data_idx: data_idx + 1].to(device) 87 | object_pc = data['object_pc'][data_idx: data_idx + 1].to(device) 88 | 89 | with torch.no_grad(): 90 | dro = network(robot_pc, object_pc)['dro'].detach() 91 | 92 | mlat_pc = multilateration(dro, object_pc) 93 | transform, _ = 
compute_link_pose(hand.links_pc, mlat_pc, is_train=False) 94 | optim_transform = process_transform(hand.pk_chain, transform) 95 | 96 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 97 | start_time = time.time() 98 | predict_q = optimization(hand.pk_chain, layer, initial_q, optim_transform) 99 | end_time = time.time() 100 | # print(f"[{data_count}/{batch_size}] Optimization time: {end_time - start_time:.4f} s") 101 | time_list.append(end_time - start_time) 102 | 103 | predict_q_list.append(predict_q) 104 | 105 | predict_q_batch = torch.cat(predict_q_list, dim=0) 106 | 107 | success, isaac_q = validate_isaac(robot_name, object_name, predict_q_batch, gpu=gpu) 108 | succ_num = success.sum().item() if success is not None else -1 109 | success_q = predict_q_batch[success] 110 | all_success_q.append(success_q) 111 | 112 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ') 113 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green') 114 | success_num += succ_num 115 | total_num += batch_size 116 | 117 | all_success_q = torch.cat(all_success_q, dim=0) 118 | diversity_std = torch.std(all_success_q, dim=0).mean() 119 | 120 | times = np.array(time_list) 121 | time_mean = np.mean(times) 122 | time_std = np.std(times) 123 | 124 | success_rate = success_num / total_num * 100 125 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 126 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 127 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 128 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 129 | 130 | 131 | if __name__ == "__main__": 132 | warnings.simplefilter(action='ignore', category=FutureWarning) 133 | torch.set_num_threads(8) 134 | main() 135 | -------------------------------------------------------------------------------- /scripts/example_pretrain.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import sys 3 | from types import SimpleNamespace 4 | from tqdm import tqdm 5 | import torch 6 | 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | 10 | from model.network import create_encoder_network 11 | from data_utils.CMapDataset import create_dataloader 12 | from utils.pretrain_utils import dist2weight, infonce_loss 13 | 14 | 15 | pretrain_ckpt = "pretrain_3robots" # name of pretrain model 16 | robot_names = ['barrett', 'allegro', 'shadowhand'] 17 | verbose = False 18 | data_num = 200 19 | 20 | 21 | def main(): 22 | encoder = create_encoder_network(emb_dim=512) 23 | 24 | encoder.load_state_dict( 25 | torch.load( 26 | os.path.join(ROOT_DIR, f'ckpt/pretrain/{pretrain_ckpt}.pth'), 27 | map_location=torch.device('cpu') 28 | ) 29 | ) 30 | 31 | for robot_name in robot_names: 32 | print(f"Robot: {robot_name}") 33 | dataloader = create_dataloader( 34 | SimpleNamespace(**{ 35 | 'batch_size': 1, 36 | 'robot_names': [robot_name], 37 | 'debug_object_names': None, 38 | 'object_pc_type': 'random', 39 | 'num_workers': 4 40 | }), 41 | is_train=True 42 | ) 43 | 44 | orders = [] 45 | for data_idx, data in enumerate(tqdm(dataloader, total=data_num)): 46 | if data_idx == data_num: 47 | break 48 | 49 | pc_1 = data['robot_pc_initial'] 50 | pc_2 = data['robot_pc_target'] 51 | 52 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 53 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 54 | 55 | emb_1 = encoder(pc_1).detach() 56 | emb_2 = encoder(pc_2).detach() 57 | 58 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 59 | loss, similarity = infonce_loss( 60 | emb_1, emb_2, weights=weight, temperature=0.1 61 | ) 62 | 63 | order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1).float().mean() 64 | orders.append(order) 65 | if verbose: 66 | print("\torder:", order) 67 | 68 | print(f"Robot: {robot_name}, Mean Order: {sum(orders) / len(orders)}\n") 69 | 70 | 71 | if __name__ == '__main__': 72 
| main() 73 | -------------------------------------------------------------------------------- /scripts/pretrain_order.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import argparse 6 | import warnings 7 | from types import SimpleNamespace 8 | from tqdm import tqdm 9 | import torch 10 | 11 | from model.network import create_encoder_network 12 | from data_utils.CMapDataset import create_dataloader 13 | from utils.pretrain_utils import dist2weight, infonce_loss 14 | 15 | 16 | def main(args): 17 | encoder = create_encoder_network(emb_dim=512) 18 | 19 | for epoch in args.epoch_list: 20 | print("****************************************************************") 21 | print(f"[Epoch {epoch}]") 22 | encoder.load_state_dict( 23 | torch.load( 24 | os.path.join(ROOT_DIR, f'output/{args.pretrain_ckpt}/state_dict/epoch_{epoch}.pth'), 25 | map_location=torch.device('cpu') 26 | ) 27 | ) 28 | 29 | for robot_name in args.robot_names: 30 | print(f"Robot: {robot_name}") 31 | dataloader = create_dataloader( 32 | SimpleNamespace(**{ 33 | 'batch_size': 1, 34 | 'robot_names': [robot_name], 35 | 'debug_object_names': None, 36 | 'object_pc_type': 'random', 37 | 'num_workers': 4 38 | }), 39 | is_train=True 40 | ) 41 | # print(len(dataloader)) 42 | 43 | orders = [] 44 | for data_idx, data in enumerate(tqdm(dataloader, total=args.data_num)): 45 | if data_idx == args.data_num: 46 | break 47 | 48 | pc_1 = data['robot_pc_initial'] 49 | pc_2 = data['robot_pc_target'] 50 | 51 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 52 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 53 | 54 | emb_1 = encoder(pc_1).detach() 55 | emb_2 = encoder(pc_2).detach() 56 | 57 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 58 | loss, similarity = infonce_loss( 59 | emb_1, emb_2, weights=weight, temperature=0.1 60 | ) 61 | 62 | 
order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1).float().mean() 63 | orders.append(order) 64 | if args.verbose: 65 | print("\torder:", order) 66 | 67 | print(f"Epoch: {epoch}, Robot: {robot_name}, Mean Order: {sum(orders) / len(orders)}\n") 68 | 69 | 70 | if __name__ == '__main__': 71 | warnings.simplefilter(action='ignore', category=FutureWarning) 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--pretrain_ckpt', type=str, default='pretrain_3robots') 74 | parser.add_argument('--data_num', type=int, default=200) 75 | parser.add_argument('--epoch_list', type=lambda string: string.split(','), 76 | default=['10', '20', '30', '40', '50', '60', '70', '80', '90', '100']) 77 | parser.add_argument('--robot_names', type=lambda string: string.split(','), 78 | default=['barrett', 'allegro', 'shadowhand']) 79 | parser.add_argument('--verbose', action='store_true') 80 | args = parser.parse_args() 81 | main(args) 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import hydra 4 | import warnings 5 | import torch 6 | import pytorch_lightning as pl 7 | from omegaconf import OmegaConf 8 | from pytorch_lightning.loggers import WandbLogger 9 | 10 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from data_utils.CMapDataset import create_dataloader 14 | from model.network import create_network 15 | from model.module import TrainingModule 16 | 17 | 18 | @hydra.main(version_base="1.2", config_path="configs", config_name="train") 19 | def main(cfg): 20 | print("******************************** [Config] ********************************") 21 | print(OmegaConf.to_yaml(cfg)) 22 | print("******************************** [Config] ********************************") 23 | 24 | pl.seed_everything(cfg.seed) 25 | 26 | last_run_id = None 27 | 
last_epoch = 0 28 | last_ckpt_file = None 29 | if cfg.load_from_checkpoint: 30 | wandb_dir = f'output/{cfg.name}/log/{cfg.wandb.project}' 31 | last_run_id = os.listdir(wandb_dir)[0] 32 | ckpt_dir = f'{wandb_dir}/{last_run_id}/checkpoints' 33 | ckpt_files = os.listdir(ckpt_dir) 34 | for ckpt_file in ckpt_files: 35 | epoch = int(ckpt_file.split('-')[0].split('=')[1]) 36 | if epoch > last_epoch: 37 | last_epoch = epoch 38 | last_ckpt_file = os.path.join(ckpt_dir, ckpt_file) 39 | print("***************************************************") 40 | print(f"Loading checkpoint from run_id({last_run_id}): epoch {last_epoch}") 41 | print("***************************************************") 42 | 43 | logger = WandbLogger( 44 | name=cfg.name, 45 | save_dir=cfg.wandb.save_dir, 46 | id=last_run_id, 47 | project=cfg.wandb.project 48 | ) 49 | trainer = pl.Trainer( 50 | logger=logger, 51 | accelerator='gpu', 52 | strategy='ddp_find_unused_parameters_true' if (cfg.model.pretrain is not None) else 'auto', 53 | devices=cfg.gpu, 54 | log_every_n_steps=cfg.log_every_n_steps, 55 | max_epochs=cfg.training.max_epochs, 56 | gradient_clip_val=0.1 57 | ) 58 | 59 | dataloader = create_dataloader(cfg.dataset, is_train=True) 60 | 61 | network = create_network(cfg.model, mode='train') 62 | model = TrainingModule( 63 | cfg=cfg.training, 64 | network=network, 65 | epoch_idx=last_epoch 66 | ) 67 | model.train() 68 | 69 | trainer.fit(model, dataloader, ckpt_path=last_ckpt_file) 70 | torch.save(model.network.state_dict(), f'{cfg.training.save_dir}/epoch_{cfg.training.max_epochs}.pth') 71 | 72 | 73 | if __name__ == "__main__": 74 | torch.set_float32_matmul_precision("high") 75 | torch.autograd.set_detect_anomaly(True) 76 | torch.cuda.empty_cache() 77 | torch.multiprocessing.set_sharing_strategy("file_system") 78 | warnings.simplefilter(action='ignore', category=FutureWarning) 79 | main() 80 | -------------------------------------------------------------------------------- /utils/controller.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import json 5 | import trimesh 6 | import torch 7 | import viser 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.hand_model import create_hand_model 13 | from utils.rotation import q_rot6d_to_q_euler 14 | 15 | 16 | def get_link_dir(robot_name, joint_name): 17 | if joint_name.startswith('virtual'): 18 | return None 19 | 20 | if robot_name == 'allegro': 21 | if joint_name in ['joint_0.0', 'joint_4.0', 'joint_8.0', 'joint_13.0']: 22 | return None 23 | link_dir = torch.tensor([0, 0, 1], dtype=torch.float32) 24 | elif robot_name == 'barrett': 25 | if joint_name in ['bh_j11_joint', 'bh_j21_joint']: 26 | return None 27 | link_dir = torch.tensor([-1, 0, 0], dtype=torch.float32) 28 | elif robot_name == 'ezgripper': 29 | link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 30 | elif robot_name == 'robotiq_3finger': 31 | if joint_name in ['gripper_fingerB_knuckle', 'gripper_fingerC_knuckle']: 32 | return None 33 | link_dir = torch.tensor([0, 0, -1], dtype=torch.float32) 34 | elif robot_name == 'shadowhand': 35 | if joint_name in ['WRJ2', 'WRJ1']: 36 | return None 37 | if joint_name != 'THJ5': 38 | link_dir = torch.tensor([0, 0, 1], dtype=torch.float32) 39 | else: 40 | link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 41 | elif robot_name == 'leaphand': 42 | if joint_name in ['13']: 43 | return None 44 | if joint_name in ['0', '4', '8']: 45 | link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 46 | elif joint_name in ['1', '5', '9', '12', '14']: 47 | link_dir = torch.tensor([0, 1, 0], dtype=torch.float32) 48 | else: 49 | link_dir = torch.tensor([0, -1, 0], dtype=torch.float32) 50 | else: 51 | raise NotImplementedError(f"Unknown robot name: {robot_name}!") 52 | 53 | return link_dir 54 | 55 | 56 | def controller(robot_name, q_para): 57 | q_batch = 
torch.atleast_2d(q_para) 58 | 59 | hand = create_hand_model(robot_name, device=q_batch.device) 60 | joint_orders = hand.get_joint_orders() 61 | pk_chain = hand.pk_chain 62 | if q_batch.shape[-1] != len(pk_chain.get_joint_parameter_names()): 63 | q_batch = q_rot6d_to_q_euler(q_batch) 64 | status = pk_chain.forward_kinematics(q_batch) 65 | 66 | outer_q_batch = [] 67 | inner_q_batch = [] 68 | for batch_idx in range(q_batch.shape[0]): 69 | joint_dots = {} 70 | for frame_name in pk_chain.get_frame_names(): 71 | frame = pk_chain.find_frame(frame_name) 72 | joint = frame.joint 73 | link_dir = get_link_dir(robot_name, joint.name) 74 | if link_dir is None: 75 | continue 76 | 77 | frame_transform = status[frame_name].get_matrix()[batch_idx] 78 | axis_dir = frame_transform[:3, :3] @ joint.axis 79 | link_dir = frame_transform[:3, :3] @ link_dir 80 | normal_dir = torch.cross(axis_dir, link_dir, dim=0) 81 | axis_origin = frame_transform[:3, 3] 82 | origin_dir = -axis_origin / torch.norm(axis_origin) 83 | joint_dots[joint.name] = torch.dot(normal_dir, origin_dir) 84 | 85 | q = q_batch[batch_idx] 86 | lower_q, upper_q = hand.pk_chain.get_joint_limits() 87 | outer_q, inner_q = q.clone(), q.clone() 88 | for joint_name, dot in joint_dots.items(): 89 | idx = joint_orders.index(joint_name) 90 | if robot_name == 'robotiq_3finger': # open -> upper, close -> lower 91 | outer_q[idx] += 0.25 * ((outer_q[idx] - lower_q[idx]) if dot <= 0 else (outer_q[idx] - upper_q[idx])) 92 | inner_q[idx] += 0.15 * ((inner_q[idx] - upper_q[idx]) if dot <= 0 else (inner_q[idx] - lower_q[idx])) 93 | else: # open -> lower, close -> upper 94 | outer_q[idx] += 0.25 * ((lower_q[idx] - outer_q[idx]) if dot >= 0 else (upper_q[idx] - outer_q[idx])) 95 | inner_q[idx] += 0.15 * ((upper_q[idx] - inner_q[idx]) if dot >= 0 else (lower_q[idx] - inner_q[idx])) 96 | outer_q_batch.append(outer_q) 97 | inner_q_batch.append(inner_q) 98 | 99 | outer_q_batch = torch.stack(outer_q_batch, dim=0) 100 | inner_q_batch = 
torch.stack(inner_q_batch, dim=0) 101 | 102 | if q_para.ndim == 2: # batch 103 | return outer_q_batch.to(q_para.device), inner_q_batch.to(q_para.device) 104 | else: 105 | return outer_q_batch[0].to(q_para.device), inner_q_batch[0].to(q_para.device) 106 | -------------------------------------------------------------------------------- /utils/func_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(ROOT_DIR) 7 | 8 | 9 | def calculate_depth(robot_pc, object_names): 10 | """ 11 | Calculate the average penetration depth of predicted pc into the object. 12 | 13 | :param robot_pc: (B, N, 3) 14 | :param object_name: list, len = B 15 | :return: calculated depth, (B,) 16 | """ 17 | object_pc_list = [] 18 | normals_list = [] 19 | for object_name in object_names: 20 | name = object_name.split('+') 21 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 22 | object_pc_normals = torch.load(object_path).to(robot_pc.device) 23 | object_pc_list.append(object_pc_normals[:, :3]) 24 | normals_list.append(object_pc_normals[:, 3:]) 25 | object_pc = torch.stack(object_pc_list, dim=0) 26 | normals = torch.stack(normals_list, dim=0) 27 | 28 | distance = torch.cdist(robot_pc, object_pc) 29 | distance, index = torch.min(distance, dim=-1) 30 | index = index.unsqueeze(-1).repeat(1, 1, 3) 31 | object_pc_indexed = torch.gather(object_pc, dim=1, index=index) 32 | normals_indexed = torch.gather(normals, dim=1, index=index) 33 | get_sign = torch.vmap(torch.vmap(lambda x, y: torch.where(torch.dot(x, y) >= 0, 1, -1))) 34 | signed_distance = distance * get_sign(robot_pc - object_pc_indexed, normals_indexed) 35 | signed_distance[signed_distance > 0] = 0 36 | return -torch.mean(signed_distance) 37 | 38 | 39 | def farthest_point_sampling(point_cloud, num_points=1024): 40 | """ 41 | :param 
point_cloud: (N, 3) or (N, 4), point cloud (with link index) 42 | :param num_points: int, number of sampled points 43 | :return: ((N, 3) or (N, 4), list), sampled point cloud (numpy) & index 44 | """ 45 | point_cloud_origin = point_cloud 46 | if point_cloud.shape[1] == 4: 47 | point_cloud = point_cloud[:, :3] 48 | 49 | selected_indices = [0] 50 | distances = torch.norm(point_cloud - point_cloud[selected_indices[-1]], dim=1) 51 | for _ in range(num_points - 1): 52 | farthest_point_idx = torch.argmax(distances) 53 | selected_indices.append(farthest_point_idx) 54 | new_distances = torch.norm(point_cloud - point_cloud[farthest_point_idx], dim=1) 55 | distances = torch.min(distances, new_distances) 56 | sampled_point_cloud = point_cloud_origin[selected_indices] 57 | 58 | return sampled_point_cloud, selected_indices 59 | -------------------------------------------------------------------------------- /utils/hand_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import math 5 | import random 6 | import numpy as np 7 | import torch 8 | import trimesh 9 | import pytorch_kinematics as pk 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(ROOT_DIR) 13 | 14 | from utils.func_utils import farthest_point_sampling 15 | from utils.mesh_utils import load_link_geometries 16 | from utils.rotation import * 17 | 18 | 19 | class HandModel: 20 | def __init__( 21 | self, 22 | robot_name, 23 | urdf_path, 24 | meshes_path, 25 | links_pc_path, 26 | device, 27 | link_num_points=512 28 | ): 29 | self.robot_name = robot_name 30 | self.urdf_path = urdf_path 31 | self.meshes_path = meshes_path 32 | self.device = device 33 | 34 | self.pk_chain = pk.build_chain_from_urdf(open(urdf_path).read()).to(dtype=torch.float32, device=device) 35 | self.dof = len(self.pk_chain.get_joint_parameter_names()) 36 | if os.path.exists(links_pc_path): # In case of generating 
robot links pc, the file doesn't exist. 37 | links_pc_data = torch.load(links_pc_path, map_location=device) 38 | self.links_pc = links_pc_data['filtered'] 39 | self.links_pc_original = links_pc_data['original'] 40 | else: 41 | self.links_pc = None 42 | self.links_pc_original = None 43 | 44 | self.meshes = load_link_geometries(robot_name, self.urdf_path, self.pk_chain.get_link_names()) 45 | 46 | self.vertices = {} 47 | removed_links = json.load(open(os.path.join(ROOT_DIR, 'data_utils/removed_links.json')))[robot_name] 48 | for link_name, link_mesh in self.meshes.items(): 49 | if link_name in removed_links: # remove links unrelated to contact 50 | continue 51 | v = link_mesh.sample(link_num_points) 52 | self.vertices[link_name] = v 53 | 54 | self.frame_status = None 55 | 56 | def get_joint_orders(self): 57 | return [joint.name for joint in self.pk_chain.get_joints()] 58 | 59 | def update_status(self, q): 60 | if q.shape[-1] != self.dof: 61 | q = q_rot6d_to_q_euler(q) 62 | self.frame_status = self.pk_chain.forward_kinematics(q.to(self.device)) 63 | 64 | def get_transformed_links_pc(self, q=None, links_pc=None): 65 | """ 66 | Use robot link pc & q value to get point cloud. 
67 | 68 | :param q: (6 + DOF,), joint values (euler representation) 69 | :param links_pc: {link_name: (N_link, 3)}, robot links pc dict, not None only for get_sampled_pc() 70 | :return: point cloud: (N, 4), with link index 71 | """ 72 | if q is None: 73 | q = torch.zeros(self.dof, dtype=torch.float32, device=self.device) 74 | self.update_status(q) 75 | if links_pc is None: 76 | links_pc = self.links_pc 77 | 78 | all_pc_se3 = [] 79 | for link_index, (link_name, link_pc) in enumerate(links_pc.items()): 80 | if not torch.is_tensor(link_pc): 81 | link_pc = torch.tensor(link_pc, dtype=torch.float32, device=q.device) 82 | n_link = link_pc.shape[0] 83 | se3 = self.frame_status[link_name].get_matrix()[0].to(q.device) 84 | homogeneous_tensor = torch.ones(n_link, 1, device=q.device) 85 | link_pc_homogeneous = torch.cat([link_pc.to(q.device), homogeneous_tensor], dim=1) 86 | link_pc_se3 = (link_pc_homogeneous @ se3.T)[:, :3] 87 | index_tensor = torch.full([n_link, 1], float(link_index), device=q.device) 88 | link_pc_se3_index = torch.cat([link_pc_se3, index_tensor], dim=1) 89 | all_pc_se3.append(link_pc_se3_index) 90 | all_pc_se3 = torch.cat(all_pc_se3, dim=0) 91 | 92 | return all_pc_se3 93 | 94 | def get_sampled_pc(self, q=None, num_points=512): 95 | """ 96 | :param q: (9 + DOF,), joint values (rot6d representation) 97 | :param num_points: int, number of sampled points 98 | :return: ((N, 3), list), sampled point cloud (numpy) & index 99 | """ 100 | if q is None: 101 | q = self.get_canonical_q() 102 | 103 | sampled_pc = self.get_transformed_links_pc(q, self.vertices) 104 | return farthest_point_sampling(sampled_pc, num_points) 105 | 106 | def get_canonical_q(self): 107 | """ For visualization purposes only. 
""" 108 | lower, upper = self.pk_chain.get_joint_limits() 109 | canonical_q = torch.tensor(lower) * 0.75 + torch.tensor(upper) * 0.25 110 | canonical_q[:6] = 0 111 | return canonical_q 112 | 113 | def get_initial_q(self, q=None, max_angle: float = math.pi / 6): 114 | """ 115 | Compute the robot initial joint value q based on the target grasp. 116 | Root translation is not considered since the point cloud will be normalized to zero-mean. 117 | 118 | :param q: (6 + DOF,) or (9 + DOF,), joint values (euler/rot6d representation) 119 | :param max_angle: float, maximum angle of the random rotation 120 | :return: initial q: (6 + DOF,), euler representation 121 | """ 122 | if q is None: # random sample root rotation and joint values 123 | q_initial = torch.zeros(self.dof, dtype=torch.float32, device=self.device) 124 | 125 | q_initial[3:6] = (torch.rand(3) * 2 - 1) * torch.pi 126 | q_initial[5] /= 2 127 | 128 | lower_joint_limits, upper_joint_limits = self.pk_chain.get_joint_limits() 129 | lower_joint_limits = torch.tensor(lower_joint_limits[6:], dtype=torch.float32) 130 | upper_joint_limits = torch.tensor(upper_joint_limits[6:], dtype=torch.float32) 131 | portion = random.uniform(0.65, 0.85) 132 | q_initial[6:] = lower_joint_limits * portion + upper_joint_limits * (1 - portion) 133 | else: 134 | if len(q) == self.dof: 135 | q = q_euler_to_q_rot6d(q) 136 | q_initial = q.clone() 137 | 138 | # compute random initial rotation 139 | direction = - q_initial[:3] / torch.norm(q_initial[:3]) 140 | angle = torch.tensor(random.uniform(0, max_angle), device=q.device) # sample rotation angle 141 | axis = torch.randn(3).to(q.device) # sample rotation axis 142 | axis -= torch.dot(axis, direction) * direction # ensure orthogonality 143 | axis = axis / torch.norm(axis) 144 | random_rotation = axisangle_to_matrix(axis, angle).to(q.device) 145 | rotation_matrix = random_rotation @ rot6d_to_matrix(q_initial[3:9]) 146 | q_initial[3:9] = matrix_to_rot6d(rotation_matrix) 147 | 148 | # compute 
random initial joint values 149 | lower_joint_limits, upper_joint_limits = self.pk_chain.get_joint_limits() 150 | lower_joint_limits = torch.tensor(lower_joint_limits[6:], dtype=torch.float32) 151 | upper_joint_limits = torch.tensor(upper_joint_limits[6:], dtype=torch.float32) 152 | portion = random.uniform(0.65, 0.85) 153 | q_initial[9:] = lower_joint_limits * portion + upper_joint_limits * (1 - portion) 154 | # q_initial[9:] = torch.zeros_like(q_initial[9:], dtype=q.dtype, device=q.device) 155 | 156 | q_initial = q_rot6d_to_q_euler(q_initial) 157 | 158 | return q_initial 159 | 160 | def get_trimesh_q(self, q): 161 | """ Return the hand trimesh object corresponding to the input joint value q. """ 162 | self.update_status(q) 163 | 164 | scene = trimesh.Scene() 165 | for link_name in self.vertices: 166 | mesh_transform_matrix = self.frame_status[link_name].get_matrix()[0].cpu().numpy() 167 | scene.add_geometry(self.meshes[link_name].copy().apply_transform(mesh_transform_matrix)) 168 | 169 | vertices = [] 170 | faces = [] 171 | vertex_offset = 0 172 | for geom in scene.geometry.values(): 173 | if isinstance(geom, trimesh.Trimesh): 174 | vertices.append(geom.vertices) 175 | faces.append(geom.faces + vertex_offset) 176 | vertex_offset += len(geom.vertices) 177 | all_vertices = np.vstack(vertices) 178 | all_faces = np.vstack(faces) 179 | 180 | parts = {} 181 | for link_name in self.meshes: 182 | mesh_transform_matrix = self.frame_status[link_name].get_matrix()[0].cpu().numpy() 183 | part_mesh = self.meshes[link_name].copy().apply_transform(mesh_transform_matrix) 184 | parts[link_name] = part_mesh 185 | 186 | return_dict = { 187 | 'visual': trimesh.Trimesh(vertices=all_vertices, faces=all_faces), 188 | 'parts': parts 189 | } 190 | return return_dict 191 | 192 | def get_trimesh_se3(self, transform, index): 193 | """ Return the hand trimesh object corresponding to the input transform. 
""" 194 | scene = trimesh.Scene() 195 | for link_name in transform: 196 | mesh_transform_matrix = transform[link_name][index].cpu().numpy() 197 | scene.add_geometry(self.meshes[link_name].copy().apply_transform(mesh_transform_matrix)) 198 | 199 | vertices = [] 200 | faces = [] 201 | vertex_offset = 0 202 | for geom in scene.geometry.values(): 203 | if isinstance(geom, trimesh.Trimesh): 204 | vertices.append(geom.vertices) 205 | faces.append(geom.faces + vertex_offset) 206 | vertex_offset += len(geom.vertices) 207 | all_vertices = np.vstack(vertices) 208 | all_faces = np.vstack(faces) 209 | 210 | return trimesh.Trimesh(vertices=all_vertices, faces=all_faces) 211 | 212 | 213 | def create_hand_model( 214 | robot_name, 215 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 216 | num_points=512 217 | ): 218 | json_path = os.path.join(ROOT_DIR, 'data/data_urdf/robot/urdf_assets_meta.json') 219 | urdf_assets_meta = json.load(open(json_path)) 220 | urdf_path = os.path.join(ROOT_DIR, urdf_assets_meta['urdf_path'][robot_name]) 221 | meshes_path = os.path.join(ROOT_DIR, urdf_assets_meta['meshes_path'][robot_name]) 222 | links_pc_path = os.path.join(ROOT_DIR, f'data/PointCloud/robot/{robot_name}.pt') 223 | hand_model = HandModel(robot_name, urdf_path, meshes_path, links_pc_path, device, num_points) 224 | return hand_model 225 | -------------------------------------------------------------------------------- /utils/mesh_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | !!! This code file is not organized, there may be relatively chaotic writing and inconsistent comment formats. !!! 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import trimesh 8 | import xml.etree.ElementTree as ET 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | 12 | def as_mesh(scene_or_mesh): 13 | """ 14 | Convert a possible scene to a mesh. 
15 | 16 | If conversion occurs, the returned mesh has only vertex and face data. 17 | """ 18 | if isinstance(scene_or_mesh, trimesh.Scene): 19 | if len(scene_or_mesh.geometry) == 0: 20 | mesh = None # empty scene 21 | else: 22 | # we lose texture information here 23 | mesh = trimesh.util.concatenate( 24 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)for g in scene_or_mesh.geometry.values()) 25 | ) 26 | else: 27 | assert isinstance(scene_or_mesh, trimesh.Trimesh) 28 | mesh = scene_or_mesh 29 | return mesh 30 | 31 | 32 | def extract_colors_from_urdf(urdf_path): 33 | tree = ET.parse(urdf_path) 34 | root = tree.getroot() 35 | 36 | global_materials = {} 37 | 38 | for material in root.findall("material"): 39 | name = material.attrib["name"] 40 | color_elem = material.find("color") 41 | if color_elem is not None and "rgba" in color_elem.attrib: 42 | rgba = [float(c) for c in color_elem.attrib["rgba"].split()] 43 | global_materials[name] = rgba 44 | 45 | link_colors = {} 46 | 47 | for link in root.iter("link"): 48 | link_name = link.attrib["name"] 49 | visual = link.find("./visual") 50 | if visual is not None: 51 | material = visual.find("./material") 52 | if material is not None: 53 | color = material.find("color") 54 | if color is not None and "rgba" in color.attrib: 55 | rgba = [float(c) for c in color.attrib["rgba"].split()] 56 | link_colors[link_name] = rgba 57 | elif "name" in material.attrib: 58 | material_name = material.attrib["name"] 59 | if material_name in global_materials: 60 | link_colors[link_name] = global_materials[material_name] 61 | 62 | return link_colors 63 | 64 | 65 | def parse_origin(element): 66 | """Parse the origin element for translation and rotation.""" 67 | origin = element.find("origin") 68 | xyz = np.zeros(3) 69 | rotation = np.eye(3) 70 | if origin is not None: 71 | xyz = np.fromstring(origin.attrib.get("xyz", "0 0 0"), sep=" ") 72 | rpy = np.fromstring(origin.attrib.get("rpy", "0 0 0"), sep=" ") 73 | rotation = 
R.from_euler("xyz", rpy).as_matrix() 74 | return xyz, rotation 75 | 76 | 77 | def apply_transform(mesh, translation, rotation): 78 | """Apply translation and rotation to a mesh.""" 79 | # mesh.apply_translation(-mesh.centroid) 80 | transform = np.eye(4) 81 | transform[:3, :3] = rotation 82 | transform[:3, 3] = translation 83 | mesh.apply_transform(transform) 84 | return mesh 85 | 86 | 87 | def create_primitive_mesh(geometry, translation, rotation): 88 | """Create a trimesh object from primitive geometry definitions with transformations.""" 89 | if geometry.tag.endswith("box"): 90 | size = np.fromstring(geometry.attrib["size"], sep=" ") 91 | mesh = trimesh.creation.box(extents=size) 92 | elif geometry.tag.endswith("sphere"): 93 | radius = float(geometry.attrib["radius"]) 94 | mesh = trimesh.creation.icosphere(radius=radius) 95 | elif geometry.tag.endswith("cylinder"): 96 | radius = float(geometry.attrib["radius"]) 97 | length = float(geometry.attrib["length"]) 98 | mesh = trimesh.creation.cylinder(radius=radius, height=length) 99 | else: 100 | raise ValueError(f"Unsupported geometry type: {geometry.tag}") 101 | return apply_transform(mesh, translation, rotation) 102 | 103 | 104 | def load_link_geometries(robot_name, urdf_path, link_names, collision=False): 105 | """Load geometries (trimesh objects) for specified links from a URDF file, considering origins.""" 106 | urdf_dir = os.path.dirname(urdf_path) 107 | tree = ET.parse(urdf_path) 108 | root = tree.getroot() 109 | 110 | link_geometries = {} 111 | link_colors_from_urdf = extract_colors_from_urdf(urdf_path) 112 | 113 | for link in root.findall("link"): 114 | link_name = link.attrib["name"] 115 | link_color = link_colors_from_urdf.get(link_name, None) 116 | if link_name in link_names: 117 | geom_index = "collision" if collision else "visual" 118 | link_mesh = [] 119 | for visual in link.findall(".//" + geom_index): 120 | geometry = visual.find("geometry") 121 | xyz, rotation = parse_origin(visual) 122 | try: 123 | 
if geometry[0].tag.endswith("mesh"): 124 | mesh_filename = geometry[0].attrib["filename"] 125 | full_mesh_path = os.path.join(urdf_dir, mesh_filename) 126 | mesh = as_mesh(trimesh.load(full_mesh_path)) 127 | scale = np.fromstring(geometry[0].attrib.get("scale", "1 1 1"), sep=" ") 128 | mesh.apply_scale(scale) 129 | mesh = apply_transform(mesh, xyz, rotation) 130 | link_mesh.append(mesh) 131 | else: # Handle primitive shapes 132 | mesh = create_primitive_mesh(geometry[0], xyz, rotation) 133 | scale = np.fromstring(geometry[0].attrib.get("scale", "1 1 1"), sep=" ") 134 | mesh.apply_scale(scale) 135 | link_mesh.append(mesh) 136 | except Exception as e: 137 | print(f"Failed to load geometry for {link_name}: {e}") 138 | if len(link_mesh) == 0: 139 | continue 140 | elif len(link_mesh) > 1: 141 | link_trimesh = as_mesh(trimesh.Scene(link_mesh)) 142 | elif len(link_mesh) == 1: 143 | link_trimesh = link_mesh[0] 144 | 145 | if link_color is not None: 146 | link_trimesh.visual.face_colors = np.array(link_color) 147 | link_geometries[link_name] = link_trimesh 148 | 149 | return link_geometries 150 | -------------------------------------------------------------------------------- /utils/multilateration.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | @typing.no_type_check 8 | def estimate_p( 9 | P: torch.FloatTensor, R: torch.FloatTensor, W: Optional[torch.FloatTensor] = None 10 | ) -> torch.FloatStorage: 11 | assert P.ndim == 3 # N x D x 1 12 | assert R.ndim == 1 # N 13 | assert P.shape[0] == R.shape[0] 14 | assert P.shape[1] in {2, 3} 15 | 16 | N, D, _ = P.shape 17 | 18 | if W is None: 19 | W = torch.ones(N, device=P.device) 20 | assert W.ndim == 1 # N 21 | W = W[:, None, None] 22 | 23 | # Shared stuff. 
24 | Pt = P.permute(0, 2, 1) 25 | PPt = P @ Pt 26 | PtP = (Pt @ P).squeeze() 27 | I = torch.eye(D, device=P.device) 28 | NI = I[None].repeat(N, 1, 1) 29 | PtP_minus_r2 = (PtP - R**2)[:, None, None] 30 | 31 | # These are ripped straight from the paper, with weighting passed through. 32 | a = (W * (PtP_minus_r2 * P)).mean(dim=0) 33 | B = (W * (-2 * PPt - PtP_minus_r2 * NI)).mean(dim=0) 34 | c = (W * P).mean(dim=0) 35 | f = a + B @ c + 2 * c @ c.T @ c 36 | H = -2 * PPt.mean(dim=0) + 2 * c @ c.T 37 | q = -torch.linalg.inv(H) @ f 38 | p = q + c 39 | 40 | return p 41 | 42 | 43 | def multilateration(dro, fixed_pc): 44 | """ 45 | Compute the target point cloud described by D(R,O) matrix & fixed_pc 46 | 47 | :param dro: (B, N, N), point-wise relative distance matrix between target point cloud & fixed point cloud 48 | :param fixed_pc: (B, N, 3), point cloud as a reference for relative distance 49 | :return: (B, N, 3), the target point cloud 50 | """ 51 | assert dro.ndim == 3 and fixed_pc.ndim == 3, "multilateration() requires batch data." 52 | v_est_p = torch.vmap(torch.vmap(estimate_p, in_dims=(None, 0))) 53 | target_pc = v_est_p(fixed_pc.unsqueeze(-1), dro)[..., 0] 54 | return target_pc -------------------------------------------------------------------------------- /utils/optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cvxpy as cp 4 | from cvxpylayers.torch import CvxpyLayer 5 | 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(ROOT_DIR) 8 | 9 | from utils.rotation import * 10 | 11 | 12 | def process_transform(pk_chain, transform, device=None): 13 | """ Compute extra link transform, and convert SE3 transform to only translation. 
""" 14 | new_transform = transform.copy() 15 | for name in pk_chain.get_frame_names(exclude_fixed=False): 16 | if name.startswith('extra'): 17 | frame = pk_chain.find_frame(name) 18 | parent_name = pk_chain.idx_to_frame[pk_chain.parents_indices[pk_chain.frame_to_idx[name]][-2].item()] 19 | new_transform[name] = new_transform[parent_name] @ frame.joint.offset.get_matrix()[0] 20 | for name, se3 in new_transform.items(): 21 | new_transform[name] = se3[:, :3, 3] 22 | if device is not None: 23 | new_transform[name] = new_transform[name].to(device) 24 | 25 | return new_transform 26 | 27 | 28 | def jacobian(pk_chain, q, frame_X_dict, frame_names): 29 | """ 30 | Calculate Jacobian (dX/dq) of all frames 31 | 32 | Notation: (similar as https://manipulation.csail.mit.edu/pick.html#monogram) 33 | J: jacobian, X: transform, R: rotation, p: position, v: velocity, w: angular velocity 34 | <>_BA_C: of frame A measured from frame B expressed in frame C 35 | W: world frame, J: joint frame, F: link frame 36 | 37 | :param pk_chain: get from pk.build_chain_from_urdf() 38 | :param q: (6 + DOF,) or (B, 6 + DOF), joint values (euler representation) 39 | :return: Jacobian: {frame_name: (B, 6, num_joints)} 40 | """ 41 | jacobian_dict = {} 42 | 43 | q = torch.atleast_2d(q) 44 | batch_size = q.shape[0] 45 | joint_names = pk_chain.get_joint_parameter_names() 46 | num_joints = len(joint_names) 47 | joint_name2idx = {name: idx for idx, name in enumerate(joint_names)} 48 | 49 | frames = [pk_chain.find_frame(name) for name in pk_chain.get_joint_parent_frame_names()] 50 | idx = lambda frame: joint_name2idx[frame.joint.name] 51 | 52 | transfer_X = {} 53 | for frame in frames: 54 | q_frame = q[:, idx(frame)] 55 | if frame.joint.joint_type == 'prismatic': 56 | q_frame = q_frame.unsqueeze(-1) 57 | transfer_X[idx(frame)] = frame.get_transform(q_frame).get_matrix() 58 | 59 | frame_X_dict = {f: frame_X_dict[f] for f in frame_X_dict if f in frame_names} 60 | 61 | for frame_name, frame_X in 
frame_X_dict.items(): 62 | jacobian = torch.zeros((batch_size, 6, num_joints), dtype=pk_chain.dtype, device=pk_chain.device) 63 | 64 | R_WF = frame_X.get_matrix()[:, :3, :3] 65 | X_JF = torch.eye(4, dtype=pk_chain.dtype, device=pk_chain.device).repeat(batch_size, 1, 1) 66 | for frame_idx in reversed(pk_chain.parents_indices[pk_chain.frame_to_idx[frame_name]].tolist()): 67 | frame = pk_chain.find_frame(pk_chain.idx_to_frame[frame_idx]) 68 | joint = frame.joint 69 | if joint.joint_type == 'fixed': 70 | if joint.offset is not None: 71 | X_JF = joint.offset.get_matrix() @ X_JF 72 | continue 73 | 74 | R_FJ = X_JF[:, :3, :3].mT 75 | R_WJ = R_WF @ R_FJ 76 | p_JF_J = X_JF[:, :3, 3][:, :, None] 77 | w_WJ_J = joint.axis[None, :, None].repeat(batch_size, 1, 1) 78 | if joint.joint_type == 'revolute': 79 | jacobian_v = R_WJ @ torch.cross(w_WJ_J, p_JF_J, dim=1) 80 | jacobian_w = R_WJ @ w_WJ_J 81 | elif joint.joint_type == 'prismatic': 82 | jacobian_v = R_WJ @ w_WJ_J 83 | jacobian_w = torch.zeros([batch_size, 3, 1], dtype=jacobian_v.dtype, device=jacobian_v.device) 84 | else: 85 | raise NotImplementedError(f"Unknown joint_type: {joint.joint_type}") 86 | 87 | joint_idx = joint_name2idx[joint.name] 88 | X_JF = transfer_X[joint_idx] @ X_JF 89 | jacobian[:, :, joint_idx] = torch.cat([jacobian_v[..., 0], jacobian_w[..., 0]], dim=1) 90 | 91 | jacobian_dict[frame_name] = jacobian 92 | return jacobian_dict 93 | 94 | 95 | def create_problem(pk_chain, frame_names): 96 | """ 97 | Only use all frame positions (ignore rotation) to optimize joint values. 
98 | 99 | :param pk_chain: get from pk.build_chain_from_urdf() 100 | :param frame_names: list of frame names to optimize 101 | :return: CvxpyLayer() 102 | """ 103 | n_joint = len(pk_chain.get_joint_parameter_names()) 104 | 105 | delta_q = cp.Variable(n_joint) 106 | 107 | q = cp.Parameter(n_joint) 108 | jacobian = {} 109 | frame_xyz = {} 110 | target_frame_xyz = {} 111 | 112 | objective_expr = 0 113 | for link_name in frame_names: 114 | frame_xyz[link_name] = cp.Parameter(3) 115 | target_frame_xyz[link_name] = cp.Parameter(3) 116 | 117 | jacobian[link_name] = cp.Parameter((3, n_joint)) 118 | delta_frame_xyz = jacobian[link_name] @ delta_q 119 | 120 | predict_frame_xyz = frame_xyz[link_name] + delta_frame_xyz 121 | objective_expr += cp.norm2(predict_frame_xyz - target_frame_xyz[link_name]) 122 | objective = cp.Minimize(objective_expr) 123 | 124 | lower_joint_limits, upper_joint_limits = pk_chain.get_joint_limits() 125 | upper_limit = cp.minimum(0.5, upper_joint_limits - q) 126 | lower_limit = cp.maximum(-0.5, lower_joint_limits - q) 127 | constraints = [delta_q <= upper_limit, delta_q >= lower_limit] 128 | problem = cp.Problem(objective, constraints) 129 | 130 | layer = CvxpyLayer( 131 | problem, 132 | parameters=[ 133 | q, 134 | *frame_xyz.values(), 135 | *target_frame_xyz.values(), 136 | *jacobian.values() 137 | ], 138 | variables=[delta_q] 139 | ) 140 | return layer 141 | 142 | 143 | def optimization(pk_chain, layer, initial_q, transform, n_iter=64, verbose=False): 144 | if initial_q.shape[-1] != len(pk_chain.get_joint_parameter_names()): 145 | initial_q = q_rot6d_to_q_euler(initial_q) 146 | q = initial_q.clone() 147 | 148 | for i in range(n_iter): 149 | status = pk_chain.forward_kinematics(q) 150 | jacobians = jacobian(pk_chain, q, status, transform.keys()) 151 | 152 | frame_xyz = {} 153 | target_frame_xyz = {} 154 | jacobians_xyz = {} 155 | for link_name, link_jacobian in jacobians.items(): 156 | frame_xyz[link_name] = status[link_name].get_matrix()[:, :3, 3] 
157 | target_frame_xyz[link_name] = transform[link_name] 158 | jacobians_xyz[link_name] = link_jacobian[:, :3, :] 159 | 160 | delta_q = layer( 161 | q, 162 | *list(frame_xyz.values()), 163 | *list(target_frame_xyz.values()), 164 | *list(jacobians_xyz.values()), 165 | ) 166 | q += delta_q[0] 167 | if verbose: 168 | print(f'[Step {i}], delta_q norm: {delta_q[0].norm()}') 169 | if delta_q[0].norm() < 0.3: 170 | if verbose: 171 | print("Converged at iteration:", i) 172 | break 173 | return q 174 | -------------------------------------------------------------------------------- /utils/pretrain_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced with some modifications made, from 3 | https://github.com/r-pad/taxpose/blob/0c4298fa0486fd09e63bf24d618a579b66ba0f18/taxpose/utils/emb_losses.py. 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def dist2weight(xyz, func=None): 11 | d = (xyz.unsqueeze(1) - xyz.unsqueeze(2)).norm(dim=-1) 12 | if func is not None: 13 | d = func(d) 14 | w = d / d.max(dim=-1, keepdims=True)[0] 15 | w = w + torch.eye(d.shape[-1], device=d.device).unsqueeze(0).tile( 16 | [d.shape[0], 1, 1] 17 | ) 18 | return w 19 | 20 | 21 | def infonce_loss(phi_1, phi_2, weights=None, temperature=0.1): 22 | B, N, D = phi_1.shape 23 | 24 | # cosine similarity 25 | phi_1 = F.normalize(phi_1, dim=2) 26 | phi_2 = F.normalize(phi_2, dim=2) 27 | similarity = phi_1 @ phi_2.mT 28 | 29 | target = torch.arange(N, device=similarity.device).tile([B, 1]) 30 | if weights is None: 31 | weights = 1.0 32 | loss = F.cross_entropy(torch.log(weights) + (similarity / temperature), target) 33 | 34 | return loss, similarity 35 | 36 | 37 | def mean_order(similarity): 38 | order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1) 39 | return order.float().mean() / similarity.shape[-1] 40 | -------------------------------------------------------------------------------- 
/utils/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial.transform import Rotation 3 | 4 | def matrix_to_euler(matrix): 5 | device = matrix.device 6 | # forward_kinematics() requires intrinsic euler ('XYZ') 7 | euler = Rotation.from_matrix(matrix.cpu().numpy()).as_euler('XYZ') 8 | return torch.tensor(euler, dtype=torch.float32, device=device) 9 | 10 | def euler_to_matrix(euler): 11 | device = euler.device 12 | matrix = Rotation.from_euler('XYZ', euler.cpu().numpy()).as_matrix() 13 | return torch.tensor(matrix, dtype=torch.float32, device=device) 14 | 15 | def matrix_to_rot6d(matrix): 16 | return matrix.T.reshape(9)[:6] 17 | 18 | def rot6d_to_matrix(rot6d): 19 | x = normalize(rot6d[..., 0:3]) 20 | y = normalize(rot6d[..., 3:6]) 21 | a = normalize(x + y) 22 | b = normalize(x - y) 23 | x = normalize(a + b) 24 | y = normalize(a - b) 25 | z = normalize(torch.cross(x, y, dim=-1)) 26 | matrix = torch.stack([x, y, z], dim=-2).mT 27 | return matrix 28 | 29 | def euler_to_rot6d(euler): 30 | matrix = euler_to_matrix(euler) 31 | return matrix_to_rot6d(matrix) 32 | 33 | def rot6d_to_euler(rot6d): 34 | matrix = rot6d_to_matrix(rot6d) 35 | return matrix_to_euler(matrix) 36 | 37 | def axisangle_to_matrix(axis, angle): 38 | (x, y, z), c, s = axis, torch.cos(angle), torch.sin(angle) 39 | return torch.tensor([ 40 | [(1 - c) * x * x + c, (1 - c) * x * y - s * z, (1 - c) * x * z + s * y], 41 | [(1 - c) * x * y + s * z, (1 - c) * y * y + c, (1 - c) * y * z - s * x], 42 | [(1 - c) * x * z - s * y, (1 - c) * y * z + s * x, (1 - c) * z * z + c] 43 | ]) 44 | 45 | def euler_to_quaternion(euler): 46 | device = euler.device 47 | quaternion = Rotation.from_euler('XYZ', euler.cpu().numpy()).as_quat() 48 | return torch.tensor(quaternion, dtype=torch.float32, device=device) 49 | 50 | def normalize(v): 51 | return v / torch.norm(v, dim=-1, keepdim=True) 52 | 53 | def q_euler_to_q_rot6d(q_euler): 54 | return 
torch.cat([q_euler[..., :3], euler_to_rot6d(q_euler[..., 3:6]), q_euler[..., 6:]], dim=-1) 55 | 56 | def q_rot6d_to_q_euler(q_rot6d): 57 | return torch.cat([q_rot6d[..., :3], rot6d_to_euler(q_rot6d[..., 3:9]), q_rot6d[..., 9:]], dim=-1) 58 | 59 | 60 | if __name__ == '__main__': 61 | """ Test correctness of above functions, no need to compare euler angle due to singularity. """ 62 | test_euler = torch.rand(3) * 2 * torch.pi 63 | 64 | test_matrix = euler_to_matrix(test_euler) 65 | test_euler_prime = matrix_to_euler(test_matrix) 66 | test_matrix_prime = euler_to_matrix(test_euler_prime) 67 | assert torch.allclose(test_matrix, test_matrix_prime), \ 68 | f"Original Matrix: {test_matrix}, Converted Matrix: {test_matrix_prime}" 69 | 70 | test_rot6d = matrix_to_rot6d(test_matrix) 71 | test_matrix_prime = rot6d_to_matrix(test_rot6d) 72 | assert torch.allclose(test_matrix, test_matrix_prime),\ 73 | f"Original Matrix: {test_matrix}, Converted Matrix: {test_matrix_prime}" 74 | 75 | test_rot6d_prime = matrix_to_rot6d(test_matrix_prime) 76 | assert torch.allclose(test_rot6d, test_rot6d_prime), \ 77 | f"Original Rot6D: {test_rot6d}, Converted Rot6D: {test_rot6d_prime}" 78 | 79 | test_euler_prime = rot6d_to_euler(test_rot6d) 80 | test_rot6d_prime = euler_to_rot6d(test_euler_prime) 81 | assert torch.allclose(test_rot6d, test_rot6d_prime), \ 82 | f"Original Rot6D: {test_rot6d}, Converted Rot6D: {test_rot6d_prime}" 83 | 84 | print("All Tests Passed!") 85 | -------------------------------------------------------------------------------- /utils/se3_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_se3_transform(P, Q): 5 | """ 6 | Compute SE3 transform between two point clouds. 
7 | 8 | :param P: (N, 3) or (B, N, 3), point cloud (w/ or w/o batch) 9 | :param Q: same as P 10 | :return: SE3 transform between P and Q, (4, 4) or (B, 4, 4) 11 | """ 12 | assert P.shape == Q.shape 13 | 14 | if P.ndim == 2: # (N, 3) 15 | P_mean = torch.mean(P, dim=0) 16 | Q_mean = torch.mean(Q, dim=0) 17 | P_prime = P - P_mean 18 | Q_prime = Q - Q_mean 19 | H = P_prime.T @ Q_prime 20 | U, _, Vt = torch.linalg.svd(H) 21 | V = Vt.T 22 | R = V @ U.T 23 | if torch.linalg.det(R) < 0: 24 | V[:, -1] *= -1 25 | R = V @ U.T 26 | t = Q_mean - R @ P_mean 27 | 28 | T = torch.eye(4).to(P.device) 29 | T[:3, :3] = R 30 | T[:3, 3] = t 31 | elif P.ndim == 3: # (B, N, 3) 32 | P_mean = torch.mean(P, dim=1, keepdim=True) 33 | Q_mean = torch.mean(Q, dim=1, keepdim=True) 34 | P_prime = P - P_mean 35 | Q_prime = Q - Q_mean 36 | H = P_prime.transpose(1, 2) @ Q_prime 37 | U, _, Vt = torch.linalg.svd(H) 38 | V = Vt.transpose(1, 2) 39 | R = V @ U.transpose(1, 2) 40 | det_R = torch.linalg.det(R) 41 | VV = V.clone() 42 | VV[:, :, -1] *= torch.where(det_R < 0, -1.0, 1.0).unsqueeze(-1) 43 | RR = VV @ U.transpose(1, 2) 44 | t = Q_mean.squeeze(1) - (RR @ P_mean.transpose(1, 2)).squeeze(-1) 45 | 46 | T = torch.eye(4).repeat(P.shape[0], 1, 1).to(P.device) 47 | T[:, :3, :3] = RR 48 | T[:, :3, 3] = t 49 | else: 50 | raise RuntimeError('Unexpected point cloud shape!') 51 | 52 | return T 53 | 54 | 55 | def se3_transform_point_cloud(P, Transform): 56 | """ 57 | Apply SE3 transform on point cloud. 
58 | 59 | :param P: (N, 3) or (B, N, 3), point cloud (w/ or w/o batch) 60 | :param Transform: SE3 transform (w/ or w/o batch) 61 | :return: Point Cloud after SE3 transform, (N, 3) or (B, N, 3) 62 | """ 63 | P_prime = torch.cat((P, torch.ones([*P.shape[:-1], 1], dtype=torch.float32, device=P.device)), dim=-1) 64 | P_transformed = P_prime @ Transform.mT 65 | return P_transformed[..., :3] 66 | 67 | 68 | def compute_link_pose(robot_links_pc, predict_pcs, is_train=True): 69 | """ 70 | Calculate link poses of the predicted pc. 71 | 72 | :param robot_links_pc: (train) [{link_name: (N_i, 3), ...}, ...], per link sampled points of batch robots 73 | (validate) {link_name: (N_i, 3), ...}, per link sampled points of the same robot 74 | :param predict_pcs: (B, N, 3), point cloud to calculate SE3 75 | :return: link transforms, [{link_name: (4, 4)}, ...]; 76 | transformed_pc, (B, N, 3) 77 | """ 78 | if is_train: 79 | assert predict_pcs.ndim == 3, "compute_link_pose() requires batch data during training." 
80 | batch_transform = [] 81 | batch_transformed_pc = [] 82 | for idx in range(len(robot_links_pc)): 83 | links_pc = robot_links_pc[idx] 84 | predict_pc = predict_pcs[idx] 85 | 86 | global_index = 0 87 | transform = {} 88 | transformed_pc = [] 89 | for link_index, (link_name, link_pc) in enumerate(links_pc.items()): 90 | predict_pc_link = predict_pc[global_index: global_index + link_pc.shape[-2], :3] 91 | global_index += link_pc.shape[0] 92 | 93 | link_se3 = compute_se3_transform(link_pc.unsqueeze(0), predict_pc_link.unsqueeze(0))[0] # (4, 4) 94 | link_transformed_pc = se3_transform_point_cloud(link_pc, link_se3) # (N_link, 3) 95 | transform[link_name] = link_se3 96 | transformed_pc.append(link_transformed_pc) 97 | 98 | batch_transform.append(transform) 99 | batch_transformed_pc.append(torch.cat(transformed_pc, dim=0)) 100 | 101 | return batch_transform, torch.stack(batch_transformed_pc, dim=0) 102 | else: 103 | batch_transform = {} 104 | batch_transformed_pc = [] 105 | global_index = 0 106 | for link_index, (link_name, link_pc) in enumerate(robot_links_pc.items()): 107 | if predict_pcs.ndim == 3 and link_pc.ndim != 3: 108 | link_pc = link_pc.unsqueeze(0).repeat(predict_pcs.shape[0], 1, 1) 109 | predict_pc_link = predict_pcs[..., global_index: global_index + link_pc.shape[-2], :3] 110 | global_index += link_pc.shape[-2] 111 | batch_transform[link_name] = compute_se3_transform(link_pc, predict_pc_link) 112 | batch_transformed_pc.append(se3_transform_point_cloud(link_pc, batch_transform[link_name])) 113 | batch_transformed_pc = torch.cat(batch_transformed_pc, dim=-2) 114 | 115 | return batch_transform, batch_transformed_pc 116 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of the visualization code has not been encapsulated into functions; 3 | only the part for visualizing vectors is kept in this file, and the 
comment format is not consistent. 4 | """ 5 | 6 | import trimesh 7 | import numpy as np 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | 11 | def normalize(x): 12 | """ 13 | Normalize the input vector. If the magnitude of the vector is zero, a small value is added to prevent division by zero. 14 | 15 | Parameters: 16 | - x (np.ndarray): Input vector to be normalized. 17 | 18 | Returns: 19 | - np.ndarray: Normalized vector. 20 | """ 21 | if len(x.shape) == 1: 22 | mag = np.linalg.norm(x) 23 | if mag == 0: 24 | mag = mag + 1e-10 25 | return x / mag 26 | else: 27 | norms = np.linalg.norm(x, axis=1, keepdims=True) 28 | norms = np.where(norms == 0, 1e-10, norms) 29 | return x / norms 30 | 31 | 32 | def sample_transform_w_normals( 33 | new_palm_center, 34 | new_face_vector, 35 | sample_roll, 36 | ori_face_vector=np.array([1.0, 0.0, 0.0]), 37 | ): 38 | """ 39 | Compute the transformation matrix from the original palm pose to a new palm pose. 40 | 41 | Parameters: 42 | - new_palm_center (np.ndarray): The point of the palm center [x, y, z]. 43 | - new_face_vector (np.ndarray): The direction vector representing the new palm facing direction. 44 | - sample_roll (float): The roll angle in range [0, 2*pi). 45 | - ori_face_vector (np.ndarray): The original direction vector representing the palm facing direction. Default is [1.0, 0.0, 0.0]. 46 | 47 | Returns: 48 | - rst_transform (np.ndarray): A 4x4 transformation matrix. 
49 | """ 50 | 51 | rot_axis = np.cross(ori_face_vector, normalize(new_face_vector)) 52 | rot_axis = rot_axis / (np.linalg.norm(rot_axis) + 1e-16) 53 | rot_ang = np.arccos(np.clip(np.dot(ori_face_vector, new_face_vector), -1.0, 1.0)) 54 | 55 | if rot_ang > 3.1415 or rot_ang < -3.1415: 56 | rot_axis = ( 57 | np.array([1.0, 0.0, 0.0]) 58 | if not np.isclose(ori_face_vector, np.array([1.0, 0.0, 0.0])).all() 59 | else np.array([0.0, 1.0, 0.0]) 60 | ) 61 | 62 | rot = R.from_rotvec(rot_ang * rot_axis).as_matrix() 63 | roll_rot = R.from_rotvec(sample_roll * new_face_vector).as_matrix() 64 | 65 | final_rot = roll_rot @ rot 66 | rst_transform = np.eye(4) 67 | rst_transform[:3, :3] = final_rot 68 | rst_transform[:3, 3] = new_palm_center 69 | return rst_transform 70 | 71 | 72 | def vis_vector( 73 | start_point, 74 | vector, 75 | length=0.1, 76 | cyliner_r=0.003, 77 | color=[255, 255, 100, 245], 78 | no_arrow=False, 79 | ): 80 | """ 81 | start_points: np.ndarray, shape=(3,) 82 | vectors: np.ndarray, shape=(3,) 83 | length: cylinder length 84 | """ 85 | normalized_vector = normalize(vector) 86 | end_point = start_point + length * normalized_vector 87 | 88 | # create a mesh for the force 89 | force_cylinder = trimesh.creation.cylinder( 90 | radius=cyliner_r, segment=np.array([start_point, end_point]) 91 | ) 92 | 93 | # create a mesh for the arrowhead 94 | cone_transform = sample_transform_w_normals( 95 | end_point, normalized_vector, 0, ori_face_vector=np.array([0.0, 0.0, 1.0]) 96 | ) 97 | arrowhead_cone = trimesh.creation.cone( 98 | radius=2 * cyliner_r, height=4 * cyliner_r, transform=cone_transform 99 | ) 100 | # combine the two meshes into one 101 | if not no_arrow: 102 | force_mesh = force_cylinder + arrowhead_cone 103 | else: 104 | force_mesh = force_cylinder 105 | force_mesh.visual.face_colors = color 106 | 107 | return force_mesh 108 | -------------------------------------------------------------------------------- /validate.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import warnings 5 | from termcolor import cprint 6 | import hydra 7 | import numpy as np 8 | 9 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from model.network import create_network 13 | from data_utils.CMapDataset import create_dataloader 14 | from utils.multilateration import multilateration 15 | from utils.se3_transform import compute_link_pose 16 | from utils.optimization import * 17 | from utils.hand_model import create_hand_model 18 | from validation.validate_utils import validate_isaac 19 | 20 | 21 | @hydra.main(version_base="1.2", config_path="configs", config_name="validate") 22 | def main(cfg): 23 | device = torch.device(f'cuda:{cfg.gpu}') 24 | batch_size = cfg.dataset.batch_size 25 | print(f"Device: {device}") 26 | print('Name:', cfg.name) 27 | 28 | os.makedirs(os.path.join(ROOT_DIR, 'validate_output'), exist_ok=True) 29 | log_file_name = os.path.join(ROOT_DIR, f'validate_output/{cfg.name}.log') 30 | print('Log file:', log_file_name) 31 | for validate_epoch in cfg.validate_epochs: 32 | print(f"************************ Validating epoch {validate_epoch} ************************") 33 | with open(log_file_name, 'a') as f: 34 | print(f"************************ Validating epoch {validate_epoch} ************************", file=f) 35 | 36 | network = create_network(cfg.model, mode='validate').to(device) 37 | network.load_state_dict(torch.load(f"output/{cfg.name}/state_dict/epoch_{validate_epoch}.pth", map_location=device)) 38 | network.eval() 39 | 40 | dataloader = create_dataloader(cfg.dataset, is_train=False) 41 | 42 | global_robot_name = None 43 | hand = None 44 | all_success_q = [] 45 | time_list = [] 46 | success_num = 0 47 | total_num = 0 48 | vis_info = [] 49 | for i, data in enumerate(dataloader): 50 | robot_name = data['robot_name'] 51 | object_name = data['object_name'] 52 | 
53 | if robot_name != global_robot_name: 54 | if global_robot_name is not None: 55 | all_success_q = torch.cat(all_success_q, dim=0) 56 | diversity_std = torch.std(all_success_q, dim=0).mean() 57 | times = np.array(time_list) 58 | time_mean = np.mean(times) 59 | time_std = np.std(times) 60 | 61 | success_rate = success_num / total_num * 100 62 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 63 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 64 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 65 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 66 | with open(log_file_name, 'a') as f: 67 | cprint(f"[{global_robot_name}]", 'magenta', end=' ', file=f) 68 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ', file=f) 69 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ', file=f) 70 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue', file=f) 71 | 72 | all_success_q = [] 73 | time_list = [] 74 | success_num = 0 75 | total_num = 0 76 | hand = create_hand_model(robot_name, device) 77 | global_robot_name = robot_name 78 | 79 | initial_q_list = [] 80 | predict_q_list = [] 81 | object_pc_list = [] 82 | mlat_pc_list = [] 83 | transform_list = [] 84 | data_count = 0 85 | while data_count != batch_size: 86 | split_num = min(batch_size - data_count, cfg.split_batch_size) 87 | 88 | initial_q = data['initial_q'][data_count : data_count + split_num].to(device) 89 | robot_pc = data['robot_pc'][data_count : data_count + split_num].to(device) 90 | object_pc = data['object_pc'][data_count : data_count + split_num].to(device) 91 | 92 | data_count += split_num 93 | 94 | with torch.no_grad(): 95 | dro = network( 96 | robot_pc, 97 | object_pc 98 | )['dro'].detach() 99 | 100 | mlat_pc = multilateration(dro, object_pc) 101 | transform, _ = compute_link_pose(hand.links_pc, mlat_pc, is_train=False) 102 | optim_transform = process_transform(hand.pk_chain, 
transform) 103 | 104 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 105 | start_time = time.time() 106 | predict_q = optimization(hand.pk_chain, layer, initial_q, optim_transform) 107 | end_time = time.time() 108 | print(f"[{data_count}/{batch_size}] Optimization time: {end_time - start_time:.4f} s") 109 | time_list.append(end_time - start_time) 110 | 111 | initial_q_list.append(initial_q) 112 | predict_q_list.append(predict_q) 113 | object_pc_list.append(object_pc) 114 | mlat_pc_list.append(mlat_pc) 115 | transform_list.append(transform) 116 | 117 | initial_q_batch = torch.cat(initial_q_list, dim=0) 118 | predict_q_batch = torch.cat(predict_q_list, dim=0) 119 | object_pc_batch = torch.cat(object_pc_list, dim=0) 120 | mlat_pc_batch = torch.cat(mlat_pc_list, dim=0) 121 | transform_batch = {} 122 | for transform in transform_list: 123 | for k, v in transform.items(): 124 | transform_batch[k] = v if k not in transform_batch else torch.cat((transform_batch[k], v), dim=0) 125 | 126 | success, isaac_q = validate_isaac(robot_name, object_name, predict_q_batch, gpu=cfg.gpu) 127 | succ_num = success.sum().item() if success is not None else -1 128 | success_q = predict_q_batch[success] 129 | all_success_q.append(success_q) 130 | 131 | vis_info.append({ 132 | 'robot_name': robot_name, 133 | 'object_name': object_name, 134 | 'initial_q': initial_q_batch, 135 | 'predict_q': predict_q_batch, 136 | 'object_pc': object_pc_batch, 137 | 'mlat_pc': mlat_pc_batch, 138 | 'predict_transform': transform_batch, 139 | 'success': success, 140 | 'isaac_q': isaac_q 141 | }) 142 | 143 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ') 144 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green') 145 | with open(log_file_name, 'a') as f: 146 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ', file=f) 147 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green', file=f) 148 | success_num 
+= succ_num 149 | total_num += batch_size 150 | 151 | all_success_q = torch.cat(all_success_q, dim=0) 152 | diversity_std = torch.std(all_success_q, dim=0).mean() 153 | 154 | times = np.array(time_list) 155 | time_mean = np.mean(times) 156 | time_std = np.std(times) 157 | 158 | success_rate = success_num / total_num * 100 159 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 160 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 161 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 162 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 163 | with open(log_file_name, 'a') as f: 164 | cprint(f"[{global_robot_name}]", 'magenta', end=' ', file=f) 165 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ', file=f) 166 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ', file=f) 167 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue', file=f) 168 | 169 | vis_info_file = f'{cfg.name}_epoch{validate_epoch}' 170 | os.makedirs(os.path.join(ROOT_DIR, 'vis_info'), exist_ok=True) 171 | torch.save(vis_info, os.path.join(ROOT_DIR, f'vis_info/{vis_info_file}.pt')) 172 | 173 | 174 | if __name__ == "__main__": 175 | warnings.simplefilter(action='ignore', category=FutureWarning) 176 | torch.set_num_threads(8) 177 | main() 178 | -------------------------------------------------------------------------------- /validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenyuwei2003/DRO-Grasp/b312055b4a20f73ddfeb3ffc8a1a6c80d48bbe31/validation/__init__.py -------------------------------------------------------------------------------- /validation/asset_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | 4 | NVIDIA CORPORATION and its licensors retain all intellectual property 5 | and proprietary rights in and to this software, related documentation 6 | and any modifications thereto. Any use, reproduction, disclosure or 7 | distribution of this software and related documentation without an express 8 | license agreement from NVIDIA CORPORATION is strictly prohibited. 9 | 10 | 11 | Asset and Environment Information 12 | --------------------------------- 13 | Demonstrates introspection capabilities of the gym api at the asset and environment levels 14 | - Once an asset is loaded its properties can be queried 15 | - Assets in environments can be queried and their current states be retrieved 16 | """ 17 | 18 | import os 19 | from isaacgym import gymapi 20 | from isaacgym import gymutil 21 | 22 | 23 | def print_asset_info(gym, asset, name): 24 | print("======== Asset info %s: ========" % (name)) 25 | num_bodies = gym.get_asset_rigid_body_count(asset) 26 | num_joints = gym.get_asset_joint_count(asset) 27 | num_dofs = gym.get_asset_dof_count(asset) 28 | print("Got %d bodies, %d joints, and %d DOFs" % 29 | (num_bodies, num_joints, num_dofs)) 30 | 31 | # Iterate through bodies 32 | print("Bodies:") 33 | for i in range(num_bodies): 34 | name = gym.get_asset_rigid_body_name(asset, i) 35 | print(" %2d: '%s'" % (i, name)) 36 | 37 | # Iterate through joints 38 | print("Joints:") 39 | for i in range(num_joints): 40 | name = gym.get_asset_joint_name(asset, i) 41 | type = gym.get_asset_joint_type(asset, i) 42 | type_name = gym.get_joint_type_string(type) 43 | print(" %2d: '%s' (%s)" % (i, name, type_name)) 44 | 45 | # iterate through degrees of freedom (DOFs) 46 | print("DOFs:") 47 | for i in range(num_dofs): 48 | name = gym.get_asset_dof_name(asset, i) 49 | type = gym.get_asset_dof_type(asset, i) 50 | type_name = gym.get_dof_type_string(type) 51 | print(" %2d: '%s' (%s)" % (i, name, type_name)) 52 | 53 | 54 | def print_actor_info(gym, env, actor_handle): 55 | 56 | name = 
gym.get_actor_name(env, actor_handle) 57 | 58 | body_names = gym.get_actor_rigid_body_names(env, actor_handle) 59 | body_dict = gym.get_actor_rigid_body_dict(env, actor_handle) 60 | 61 | joint_names = gym.get_actor_joint_names(env, actor_handle) 62 | joint_dict = gym.get_actor_joint_dict(env, actor_handle) 63 | 64 | dof_names = gym.get_actor_dof_names(env, actor_handle) 65 | dof_dict = gym.get_actor_dof_dict(env, actor_handle) 66 | 67 | print() 68 | print("===== Actor: %s =======================================" % name) 69 | 70 | print("\nBodies") 71 | print(body_names) 72 | print(body_dict) 73 | 74 | print("\nJoints") 75 | print(joint_names) 76 | print(joint_dict) 77 | 78 | print("\n Degrees Of Freedom (DOFs)") 79 | print(dof_names) 80 | print(dof_dict) 81 | print() 82 | 83 | # Get body state information 84 | body_states = gym.get_actor_rigid_body_states( 85 | env, actor_handle, gymapi.STATE_ALL) 86 | 87 | # Print some state slices 88 | print("Poses from Body State:") 89 | print(body_states['pose']) # print just the poses 90 | 91 | print("\nVelocities from Body State:") 92 | print(body_states['vel']) # print just the velocities 93 | print() 94 | 95 | # iterate through bodies and print name and position 96 | body_positions = body_states['pose']['p'] 97 | for i in range(len(body_names)): 98 | print("Body '%s' has position" % body_names[i], body_positions[i]) 99 | 100 | print("\nDOF states:") 101 | 102 | # get DOF states 103 | dof_states = gym.get_actor_dof_states(env, actor_handle, gymapi.STATE_ALL) 104 | 105 | # print some state slices 106 | # Print all states for each degree of freedom 107 | print(dof_states) 108 | print() 109 | 110 | # iterate through DOFs and print name and position 111 | dof_positions = dof_states['pos'] 112 | for i in range(len(dof_names)): 113 | print("DOF '%s' has position" % dof_names[i], dof_positions[i]) 114 | -------------------------------------------------------------------------------- /validation/isaac_main.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import warnings 6 | from termcolor import cprint 7 | 8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(ROOT_DIR) 10 | 11 | from validation.isaac_validator import IsaacValidator # IsaacGym must be imported before PyTorch 12 | from utils.hand_model import create_hand_model 13 | from utils.rotation import q_rot6d_to_q_euler 14 | 15 | import torch 16 | 17 | 18 | def isaac_main( 19 | mode: str, 20 | robot_name: str, 21 | object_name: str, 22 | batch_size: int, 23 | q_batch: torch.Tensor = None, 24 | gpu: int = 0, 25 | use_gui: bool = False 26 | ): 27 | """ 28 | For filtering dataset and validating grasps. 29 | 30 | :param mode: str, 'filter' or 'validation' 31 | :param robot_name: str 32 | :param object_name: str 33 | :param batch_size: int, number of grasps in Isaac Gym simultaneously 34 | :param q_batch: torch.Tensor (validation only) 35 | :param gpu: int, specify the GPU device used by Isaac Gym 36 | :param use_gui: bool, whether to visualize Isaac Gym simulation process 37 | :return: success: (batch_size,), bool, whether each grasp is successful in Isaac Gym; 38 | q_isaac: (success_num, DOF), torch.float32, successful joint values after the grasp phase 39 | """ 40 | if mode == 'filter' and batch_size == 0: # special judge for num_per_object = 0 in dataset 41 | return 0, None 42 | if use_gui: # for unknown reason otherwise will segmentation fault :( 43 | gpu = 0 44 | 45 | data_urdf_path = os.path.join(ROOT_DIR, 'data/data_urdf') 46 | urdf_assets_meta = json.load(open(os.path.join(data_urdf_path, 'robot/urdf_assets_meta.json'))) 47 | robot_urdf_path = urdf_assets_meta['urdf_path'][robot_name] 48 | object_name_split = object_name.split('+') if object_name is not None else None 49 | # object_urdf_path = f'{object_name_split[0]}/{object_name_split[1]}/{object_name_split[1]}.urdf' 
50 | object_urdf_path = f'{object_name_split[0]}/{object_name_split[1]}/coacd_decomposed_object_one_link.urdf' 51 | 52 | hand = create_hand_model(robot_name) 53 | joint_orders = hand.get_joint_orders() 54 | 55 | simulator = IsaacValidator( 56 | robot_name=robot_name, 57 | joint_orders=joint_orders, 58 | batch_size=batch_size, 59 | gpu=gpu, 60 | is_filter=(mode == 'filter'), 61 | use_gui=use_gui 62 | ) 63 | print("[Isaac] IsaacValidator is created.") 64 | 65 | simulator.set_asset( 66 | robot_path=os.path.join(data_urdf_path, 'robot'), 67 | robot_file=robot_urdf_path[21:], # ignore 'data/data_urdf/robot/' 68 | object_path=os.path.join(data_urdf_path, 'object'), 69 | object_file=object_urdf_path 70 | ) 71 | simulator.create_envs() 72 | print("[Isaac] IsaacValidator preparation is done.") 73 | 74 | if mode == 'filter': 75 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 76 | metadata = torch.load(dataset_path)['metadata'] 77 | q_batch = [m[1] for m in metadata if m[2] == object_name and m[3] == robot_name] 78 | q_batch = torch.stack(q_batch, dim=0).to(torch.device('cpu')) 79 | if q_batch.shape[-1] != len(joint_orders): 80 | q_batch = q_rot6d_to_q_euler(q_batch) 81 | 82 | simulator.set_actor_pose_dof(q_batch.to(torch.device('cpu'))) 83 | success, q_isaac = simulator.run_sim() 84 | simulator.destroy() 85 | 86 | return success, q_isaac 87 | 88 | 89 | # for Python scripts subprocess call to avoid Isaac Gym GPU memory leak problem 90 | if __name__ == '__main__': 91 | warnings.simplefilter(action='ignore', category=FutureWarning) 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--mode', type=str, required=True) 94 | parser.add_argument('--robot_name', type=str, required=True) 95 | parser.add_argument('--object_name', type=str, required=True) 96 | parser.add_argument('--batch_size', type=int, required=True) 97 | parser.add_argument('--q_file', type=str) 98 | parser.add_argument('--gpu', default=0, type=int) 99 | 
parser.add_argument('--use_gui', action='store_true') 100 | args = parser.parse_args() 101 | 102 | print(f'GPU: {args.gpu}') 103 | assert args.mode in ['filter', 'validation'], f"Unknown mode: {args.mode}!" 104 | q_batch = torch.load(args.q_file, map_location=f'cpu') if args.q_file is not None else None 105 | success, q_isaac = isaac_main( 106 | mode=args.mode, 107 | robot_name=args.robot_name, 108 | object_name=args.object_name, 109 | batch_size=args.batch_size, 110 | q_batch=q_batch, 111 | gpu=args.gpu, 112 | use_gui=args.use_gui 113 | ) 114 | 115 | success_num = success.sum().item() 116 | if args.mode == 'filter': 117 | print(f"<{args.robot_name}/{args.object_name}> before: {args.batch_size}, after: {success_num}") 118 | if success_num > 0: 119 | q_filtered = q_isaac[success] 120 | save_dir = str(os.path.join(ROOT_DIR, 'data/CMapDataset_filtered', args.robot_name)) 121 | os.makedirs(save_dir, exist_ok=True) 122 | torch.save(q_filtered, os.path.join(save_dir, f'{args.object_name}_{success_num}.pt')) 123 | elif args.mode == 'validation': 124 | cprint(f"[{args.robot_name}/{args.object_name}] Result: {success_num}/{args.batch_size}", 'green') 125 | save_data = { 126 | 'success': success, 127 | 'q_isaac': q_isaac 128 | } 129 | os.makedirs(os.path.join(ROOT_DIR, 'tmp'), exist_ok=True) 130 | torch.save(save_data, os.path.join(ROOT_DIR, f'tmp/isaac_main_ret_{args.gpu}.pt')) 131 | -------------------------------------------------------------------------------- /validation/isaac_validator.py: -------------------------------------------------------------------------------- 1 | from isaacgym import gymapi 2 | from isaacgym import gymtorch 3 | 4 | import os 5 | import sys 6 | import time 7 | import numpy as np 8 | import torch 9 | from scipy.spatial.transform import Rotation 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(ROOT_DIR) 13 | 14 | from utils.controller import controller 15 | 16 | 17 | class IsaacValidator: 18 | 
def __init__( 19 | self, 20 | robot_name, 21 | joint_orders, 22 | batch_size, 23 | gpu=0, 24 | is_filter=False, 25 | use_gui=False, 26 | robot_friction=3., 27 | object_friction=3., 28 | steps_per_sec=100, 29 | grasp_step=100, 30 | debug_interval=0.01 31 | ): 32 | self.gym = gymapi.acquire_gym() 33 | 34 | self.robot_name = robot_name 35 | self.joint_orders = joint_orders 36 | self.batch_size = batch_size 37 | self.gpu = gpu 38 | self.is_filter = is_filter 39 | self.robot_friction = robot_friction 40 | self.object_friction = object_friction 41 | self.steps_per_sec = steps_per_sec 42 | self.grasp_step = grasp_step 43 | self.debug_interval = debug_interval 44 | 45 | self.envs = [] 46 | self.robot_handles = [] 47 | self.object_handles = [] 48 | self.robot_asset = None 49 | self.object_asset = None 50 | self.rigid_body_num = None 51 | self.object_force = None 52 | self.urdf2isaac_order = None 53 | self.isaac2urdf_order = None 54 | 55 | self.sim_params = gymapi.SimParams() 56 | # set common parameters 57 | self.sim_params.dt = 1 / steps_per_sec 58 | self.sim_params.substeps = 2 59 | self.sim_params.gravity = gymapi.Vec3(0.0, 0.0, 0.0) 60 | #self.sim_params.use_gpu_pipeline = True 61 | # set PhysX-specific parameters 62 | self.sim_params.physx.use_gpu = True 63 | self.sim_params.physx.solver_type = 1 64 | self.sim_params.physx.num_position_iterations = 8 65 | self.sim_params.physx.num_velocity_iterations = 0 66 | self.sim_params.physx.contact_offset = 0.01 67 | self.sim_params.physx.rest_offset = 0.0 68 | 69 | self.sim = self.gym.create_sim(self.gpu, self.gpu, gymapi.SIM_PHYSX, self.sim_params) 70 | self._rigid_body_states = self.gym.acquire_rigid_body_state_tensor(self.sim) 71 | self._dof_states = self.gym.acquire_dof_state_tensor(self.sim) 72 | 73 | self.viewer = None 74 | if use_gui: 75 | self.has_viewer = True 76 | self.camera_props = gymapi.CameraProperties() 77 | self.camera_props.width = 1920 78 | self.camera_props.height = 1080 79 | 
self.camera_props.use_collision_geometry = True 80 | self.viewer = self.gym.create_viewer(self.sim, self.camera_props) 81 | self.gym.viewer_camera_look_at(self.viewer, None, gymapi.Vec3(1, 0, 0), gymapi.Vec3(0, 0, 0)) 82 | else: 83 | self.has_viewer = False 84 | 85 | self.robot_asset_options = gymapi.AssetOptions() 86 | self.robot_asset_options.disable_gravity = True 87 | self.robot_asset_options.fix_base_link = True 88 | self.robot_asset_options.collapse_fixed_joints = True 89 | 90 | self.object_asset_options = gymapi.AssetOptions() 91 | self.object_asset_options.override_com = True 92 | self.object_asset_options.override_inertia = True 93 | self.object_asset_options.density = 500 94 | 95 | def set_asset(self, robot_path, robot_file, object_path, object_file): 96 | self.robot_asset = self.gym.load_asset(self.sim, robot_path, robot_file, self.robot_asset_options) 97 | self.object_asset = self.gym.load_asset(self.sim, object_path, object_file, self.object_asset_options) 98 | self.rigid_body_num = (self.gym.get_asset_rigid_body_count(self.robot_asset) 99 | + self.gym.get_asset_rigid_body_count(self.object_asset)) 100 | # print_asset_info(gym, self.robot_asset, 'robot') 101 | # print_asset_info(gym, self.object_asset, 'object') 102 | 103 | def create_envs(self): 104 | for env_idx in range(self.batch_size): 105 | env = self.gym.create_env( 106 | self.sim, 107 | gymapi.Vec3(-1, -1, -1), 108 | gymapi.Vec3(1, 1, 1), 109 | int(self.batch_size ** 0.5) 110 | ) 111 | self.envs.append(env) 112 | 113 | # draw world frame 114 | if self.has_viewer: 115 | x_axis_dir = np.array([0, 0, 0, 1, 0, 0], dtype=np.float32) 116 | x_axis_color = np.array([1, 0, 0], dtype=np.float32) 117 | self.gym.add_lines(self.viewer, env, 1, x_axis_dir, x_axis_color) 118 | y_axis_dir = np.array([0, 0, 0, 0, 1, 0], dtype=np.float32) 119 | y_axis_color = np.array([0, 1, 0], dtype=np.float32) 120 | self.gym.add_lines(self.viewer, env, 1, y_axis_dir, y_axis_color) 121 | z_axis_dir = np.array([0, 0, 0, 0, 0, 
1], dtype=np.float32) 122 | z_axis_color = np.array([0, 0, 1], dtype=np.float32) 123 | self.gym.add_lines(self.viewer, env, 1, z_axis_dir, z_axis_color) 124 | 125 | # object actor setting 126 | object_handle = self.gym.create_actor( 127 | env, 128 | self.object_asset, 129 | gymapi.Transform(), 130 | f'object_{env_idx}', 131 | env_idx 132 | ) 133 | self.object_handles.append(object_handle) 134 | 135 | object_shape_properties = self.gym.get_actor_rigid_shape_properties(env, object_handle) 136 | for i in range(len(object_shape_properties)): 137 | object_shape_properties[i].friction = self.object_friction 138 | self.gym.set_actor_rigid_shape_properties(env, object_handle, object_shape_properties) 139 | 140 | # robot actor setting 141 | robot_handle = self.gym.create_actor( 142 | env, 143 | self.robot_asset, 144 | gymapi.Transform(), 145 | f'robot_{env_idx}', 146 | env_idx 147 | ) 148 | self.robot_handles.append(robot_handle) 149 | 150 | robot_properties = self.gym.get_actor_dof_properties(env, robot_handle) 151 | robot_properties["driveMode"].fill(gymapi.DOF_MODE_POS) 152 | robot_properties["stiffness"].fill(1000) 153 | robot_properties["damping"].fill(200) 154 | self.gym.set_actor_dof_properties(env, robot_handle, robot_properties) 155 | 156 | object_shape_properties = self.gym.get_actor_rigid_shape_properties(env, robot_handle) 157 | for i in range(len(object_shape_properties)): 158 | object_shape_properties[i].friction = self.robot_friction 159 | self.gym.set_actor_rigid_shape_properties(env, robot_handle, object_shape_properties) 160 | 161 | # print_actor_info(self.gym, env, robot_handle) 162 | # print_actor_info(self.gym, env, object_handle) 163 | 164 | # assume robots & objects in the same batch are the same 165 | obj_property = self.gym.get_actor_rigid_body_properties(self.envs[0], self.object_handles[0]) 166 | object_mass = [obj_property[i].mass for i in range(len(obj_property))] 167 | object_mass = torch.tensor(object_mass) 168 | self.object_force = 0.5 * 
object_mass 169 | 170 | self.urdf2isaac_order = np.zeros(len(self.joint_orders), dtype=np.int32) 171 | self.isaac2urdf_order = np.zeros(len(self.joint_orders), dtype=np.int32) 172 | for urdf_idx, joint_name in enumerate(self.joint_orders): 173 | isaac_idx = self.gym.find_actor_dof_index(self.envs[0], self.robot_handles[0], joint_name, gymapi.DOMAIN_ACTOR) 174 | self.urdf2isaac_order[isaac_idx] = urdf_idx 175 | self.isaac2urdf_order[urdf_idx] = isaac_idx 176 | 177 | def set_actor_pose_dof(self, q): 178 | self.gym.prepare_sim(self.sim) 179 | 180 | # set all actors to origin 181 | _root_state = self.gym.acquire_actor_root_state_tensor(self.sim) 182 | root_state = gymtorch.wrap_tensor(_root_state) 183 | root_state[:] = torch.tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=torch.float32) 184 | self.gym.set_actor_root_state_tensor(self.sim, _root_state) 185 | 186 | outer_q, inner_q = controller(self.robot_name, q) 187 | 188 | for env_idx in range(len(self.envs)): 189 | env = self.envs[env_idx] 190 | robot_handle = self.robot_handles[env_idx] 191 | 192 | dof_states_initial = self.gym.get_actor_dof_states(env, robot_handle, gymapi.STATE_ALL).copy() 193 | dof_states_initial['pos'] = outer_q[env_idx, self.urdf2isaac_order] 194 | self.gym.set_actor_dof_states(env, robot_handle, dof_states_initial, gymapi.STATE_ALL) 195 | 196 | dof_states_target = self.gym.get_actor_dof_states(env, robot_handle, gymapi.STATE_ALL).copy() 197 | dof_states_target['pos'] = inner_q[env_idx, self.urdf2isaac_order] 198 | self.gym.set_actor_dof_position_targets(env, robot_handle, dof_states_target["pos"]) 199 | 200 | def run_sim(self): 201 | # controller phase 202 | for step in range(self.grasp_step): 203 | self.gym.simulate(self.sim) 204 | 205 | if self.has_viewer: 206 | if self.gym.query_viewer_has_closed(self.viewer): 207 | break 208 | t = time.time() 209 | while time.time() - t < self.debug_interval: 210 | self.gym.step_graphics(self.sim) 211 | self.gym.draw_viewer(self.viewer, self.sim, 
render_collision=True) 212 | 213 | self.gym.refresh_rigid_body_state_tensor(self.sim) 214 | start_pos = gymtorch.wrap_tensor(self._rigid_body_states)[::self.rigid_body_num, :3].clone() 215 | 216 | force_tensor = torch.zeros([len(self.envs), self.rigid_body_num, 3]) # env, rigid_body, xyz 217 | x_pos_force = force_tensor.clone() 218 | x_pos_force[:, 0, 0] = self.object_force 219 | x_neg_force = force_tensor.clone() 220 | x_neg_force[:, 0, 0] = -self.object_force 221 | y_pos_force = force_tensor.clone() 222 | y_pos_force[:, 0, 1] = self.object_force 223 | y_neg_force = force_tensor.clone() 224 | y_neg_force[:, 0, 1] = -self.object_force 225 | z_pos_force = force_tensor.clone() 226 | z_pos_force[:, 0, 2] = self.object_force 227 | z_neg_force = force_tensor.clone() 228 | z_neg_force[:, 0, 2] = -self.object_force 229 | force_list = [x_pos_force, y_pos_force, z_pos_force, x_neg_force, y_neg_force, z_neg_force] 230 | 231 | # force phase 232 | for step in range(self.steps_per_sec * 6): 233 | self.gym.apply_rigid_body_force_tensors(self.sim, 234 | gymtorch.unwrap_tensor(force_list[step // self.steps_per_sec]), 235 | None, 236 | gymapi.ENV_SPACE) 237 | self.gym.simulate(self.sim) 238 | self.gym.fetch_results(self.sim, True) 239 | 240 | if self.has_viewer: 241 | if self.gym.query_viewer_has_closed(self.viewer): 242 | break 243 | t = time.time() 244 | while time.time() - t < self.debug_interval: 245 | self.gym.step_graphics(self.sim) 246 | self.gym.draw_viewer(self.viewer, self.sim, render_collision=True) 247 | 248 | self.gym.refresh_rigid_body_state_tensor(self.sim) 249 | end_pos = gymtorch.wrap_tensor(self._rigid_body_states)[::self.rigid_body_num, :3].clone() 250 | 251 | distance = (end_pos - start_pos).norm(dim=-1) 252 | 253 | if self.is_filter: 254 | success = (distance <= 0.02) & (end_pos.norm(dim=-1) <= 0.05) 255 | else: 256 | success = (distance <= 0.02) 257 | 258 | # apply inverse object transform to robot to get new joint value 259 | 
self.gym.refresh_rigid_body_state_tensor(self.sim) 260 | object_pose = gymtorch.wrap_tensor(self._rigid_body_states).clone()[::self.rigid_body_num, :7] # batch_size, 7 (xyz + quat) 261 | object_transform = np.eye(4)[np.newaxis].repeat(self.batch_size, axis=0) 262 | object_transform[:, :3, 3] = object_pose[:, :3] 263 | object_transform[:, :3, :3] = Rotation.from_quat(object_pose[:, 3:7]).as_matrix() 264 | 265 | self.gym.refresh_dof_state_tensor(self.sim) 266 | dof_states = gymtorch.wrap_tensor(self._dof_states).clone().reshape(len(self.envs), -1, 2)[:, :, 0] # batch_size, DOF (xyz + euler + joint) 267 | robot_transform = np.eye(4)[np.newaxis].repeat(self.batch_size, axis=0) 268 | robot_transform[:, :3, 3] = dof_states[:, :3] 269 | robot_transform[:, :3, :3] = Rotation.from_euler('XYZ', dof_states[:, 3:6]).as_matrix() 270 | 271 | robot_transform = np.linalg.inv(object_transform) @ robot_transform 272 | dof_states[:, :3] = torch.tensor(robot_transform[:, :3, 3]) 273 | dof_states[:, 3:6] = torch.tensor(Rotation.from_matrix(robot_transform[:, :3, :3]).as_euler('XYZ')) 274 | q_isaac = dof_states[:, self.isaac2urdf_order].to(torch.device('cpu')) 275 | 276 | return success, q_isaac 277 | 278 | def reset_simulator(self): 279 | self.gym.destroy_sim(self.sim) 280 | if self.has_viewer: 281 | self.gym.destroy_viewer(self.viewer) 282 | self.viewer = self.gym.create_viewer(self.sim, self.camera_props) 283 | self.sim = self.gym.create_sim(self.gpu, self.gpu, gymapi.SIM_PHYSX, self.sim_params) 284 | for env in self.envs: 285 | self.gym.destroy_env(env) 286 | self.envs = [] 287 | self.robot_handles = [] 288 | self.object_handles = [] 289 | self.robot_asset = None 290 | self.object_asset = None 291 | 292 | def destroy(self): 293 | for env in self.envs: 294 | self.gym.destroy_env(env) 295 | self.gym.destroy_sim(self.sim) 296 | if self.has_viewer: 297 | self.gym.destroy_viewer(self.viewer) 298 | del self.gym 299 | 
-------------------------------------------------------------------------------- /validation/validate_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | from tqdm import tqdm 5 | from termcolor import cprint 6 | import torch 7 | import trimesh 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.controller import controller 13 | 14 | 15 | def validate_depth(hand, object_name, q_list_validate, threshold=0.005, exact=True): 16 | """ 17 | Calculate the penetration depth of predicted grasps into the object. 18 | 19 | :param hand: HandModel() 20 | :param object_name: str 21 | :param q_list_validate: list, joint values to validate 22 | :param threshold: float, criteria for determining success in depth 23 | :param exact: bool, if false, use point cloud instead of mesh to compute (much faster with minor error) 24 | :param print_info: bool, whether to print progress information 25 | :return: (list, list), success list & depth list 26 | """ 27 | name = object_name.split('+') 28 | if exact: 29 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 30 | object_mesh = trimesh.load_mesh(object_path) 31 | else: 32 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 33 | object_pc_normals = torch.load(object_path).to(hand.device) 34 | object_pc = object_pc_normals[:, :3] 35 | normals = object_pc_normals[:, 3:] 36 | 37 | result_list = [] 38 | depth_list = [] 39 | q_list_initial = [] 40 | for q in q_list_validate: 41 | initial_q, _ = controller(hand.robot_name, q) 42 | q_list_initial.append(initial_q) 43 | for q in tqdm(q_list_initial): 44 | robot_pc = hand.get_transformed_links_pc(q)[:, :3] 45 | if exact: 46 | robot_pc = robot_pc.cpu() 47 | _, distance, _ = trimesh.proximity.ProximityQuery(object_mesh).on_surface(robot_pc) 48 | 
distance = distance[object_mesh.contains(robot_pc)] 49 | depth = distance.max() if distance.size else 0. 50 | else: 51 | distance = torch.cdist(robot_pc, object_pc) 52 | distance, index = torch.min(distance, dim=-1) 53 | object_pc_indexed, normals_indexed = object_pc[index], normals[index] 54 | get_sign = torch.vmap(lambda x, y: torch.where(torch.dot(x, y) >= 0, 1, -1)) 55 | signed_distance = distance * get_sign(robot_pc - object_pc_indexed, normals_indexed) 56 | depth, _ = torch.min(signed_distance, dim=-1) 57 | depth = -depth.item() if depth.item() < 0 else 0. 58 | 59 | result_list.append(depth <= threshold) 60 | depth_list.append(round(depth * 1000, 2)) 61 | 62 | return result_list, depth_list 63 | 64 | 65 | def validate_isaac(robot_name, object_name, q_batch, gpu: int = 0): 66 | """ 67 | Wrap function for subprocess call (isaac_main.py) to avoid Isaac Gym GPU memory leak problem. 68 | 69 | :param robot_name: str 70 | :param object_name: str 71 | :param q_batch: torch.Tensor, joint values to validate 72 | :param gpu: int 73 | :return: (list, list), success list & info list 74 | """ 75 | os.makedirs(os.path.join(ROOT_DIR, 'tmp'), exist_ok=True) 76 | q_file_path = str(os.path.join(ROOT_DIR, f'tmp/q_list_validate_{gpu}.pt')) 77 | torch.save(q_batch, q_file_path) 78 | batch_size = q_batch.shape[0] 79 | args = [ 80 | 'python', 81 | os.path.join(ROOT_DIR, 'validation/isaac_main.py'), 82 | '--mode', 'validation', 83 | '--robot_name', robot_name, 84 | '--object_name', object_name, 85 | '--batch_size', str(batch_size), 86 | '--q_file', q_file_path, 87 | '--gpu', str(gpu), 88 | # '--use_gui' 89 | ] 90 | ret = subprocess.run(args, capture_output=True, text=True) 91 | try: 92 | ret_file_path = os.path.join(ROOT_DIR, f'tmp/isaac_main_ret_{gpu}.pt') 93 | save_data = torch.load(ret_file_path) 94 | success = save_data['success'] 95 | q_isaac = save_data['q_isaac'] 96 | os.remove(q_file_path) 97 | os.remove(ret_file_path) 98 | except FileNotFoundError as e: 99 | cprint(f"Caught 
a ValueError: {e}", 'yellow') 100 | cprint(ret.stdout.strip(), 'blue') 101 | cprint(ret.stderr.strip(), 'red') 102 | exit() 103 | return success, q_isaac 104 | -------------------------------------------------------------------------------- /visualization/vis_controller.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import time 6 | import argparse 7 | import trimesh 8 | import torch 9 | import viser 10 | 11 | from utils.hand_model import create_hand_model 12 | from utils.controller import controller, get_link_dir 13 | from utils.vis_utils import vis_vector 14 | 15 | 16 | def vis_controller_result(robot_name='shadowhand', object_name=None): 17 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 18 | metadata = torch.load(dataset_path)['metadata'] 19 | metadata = [m for m in metadata if (object_name is None or m[1] == object_name) and m[2] == robot_name] 20 | 21 | server = viser.ViserServer(host='127.0.0.1', port=8080) 22 | 23 | slider = server.gui.add_slider( 24 | label='robot', 25 | min=0, 26 | max=len(metadata) - 1, 27 | step=1, 28 | initial_value=0 29 | ) 30 | slider.on_update(lambda gui: on_update(gui.target.value)) 31 | 32 | hand = create_hand_model(robot_name) 33 | 34 | def on_update(idx): 35 | q, object_name, _ = metadata[idx] 36 | outer_q, inner_q = controller(robot_name, q) 37 | 38 | name = object_name.split('+') 39 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 40 | object_trimesh = trimesh.load_mesh(object_path) 41 | server.scene.add_mesh_simple( 42 | 'object', 43 | object_trimesh.vertices, 44 | object_trimesh.faces, 45 | color=(239, 132, 167), 46 | opacity=0.75 47 | ) 48 | 49 | robot_trimesh = hand.get_trimesh_q(q)["visual"] 50 | server.scene.add_mesh_simple( 51 | 'origin', 52 | 
robot_trimesh.vertices, 53 | robot_trimesh.faces, 54 | color=(102, 192, 255), 55 | opacity=0.75 56 | ) 57 | robot_trimesh = hand.get_trimesh_q(outer_q)["visual"] 58 | server.scene.add_mesh_simple( 59 | 'outer', 60 | robot_trimesh.vertices, 61 | robot_trimesh.faces, 62 | color=(255, 149, 71), 63 | opacity=0.75 64 | ) 65 | robot_trimesh = hand.get_trimesh_q(inner_q)["visual"] 66 | server.scene.add_mesh_simple( 67 | 'inner', 68 | robot_trimesh.vertices, 69 | robot_trimesh.faces, 70 | color=(255, 111, 190), 71 | opacity=0.75 72 | ) 73 | 74 | while True: 75 | time.sleep(1) 76 | 77 | 78 | def vis_hand_direction(robot_name='shadowhand'): 79 | server = viser.ViserServer(host='127.0.0.1', port=8080) 80 | 81 | hand = create_hand_model(robot_name, device='cpu') 82 | q = hand.get_canonical_q() 83 | joint_orders = hand.get_joint_orders() 84 | lower, upper = hand.pk_chain.get_joint_limits() 85 | 86 | canonical_trimesh = hand.get_trimesh_q(q)["visual"] 87 | server.scene.add_mesh_simple( 88 | robot_name, 89 | canonical_trimesh.vertices, 90 | canonical_trimesh.faces, 91 | color=(102, 192, 255), 92 | opacity=0.8 93 | ) 94 | 95 | pk_chain = hand.pk_chain 96 | status = pk_chain.forward_kinematics(q) 97 | joint_dots = {} 98 | for frame_name in pk_chain.get_frame_names(): 99 | frame = pk_chain.find_frame(frame_name) 100 | joint = frame.joint 101 | link_dir = get_link_dir(robot_name, joint.name) 102 | if link_dir is None: 103 | continue 104 | 105 | frame_transform = status[frame_name].get_matrix()[0] 106 | axis_dir = frame_transform[:3, :3] @ joint.axis 107 | link_dir = frame_transform[:3, :3] @ link_dir 108 | normal_dir = torch.cross(axis_dir, link_dir, dim=0) 109 | axis_origin = frame_transform[:3, 3] 110 | origin_dir = -axis_origin / torch.norm(axis_origin) 111 | joint_dots[joint.name] = float(torch.dot(normal_dir, origin_dir)) 112 | 113 | print(joint.name, joint_orders.index(joint.name), joint_dots[joint.name]) 114 | vec_mesh = vis_vector( 115 | axis_origin.numpy(), 116 | 
vector=normal_dir.numpy(), 117 | length=0.03, 118 | cyliner_r=0.001, 119 | color=(0, 255, 0) 120 | ) 121 | server.scene.add_mesh_trimesh(joint.name, vec_mesh, visible=True) 122 | 123 | current_q = [0 if i < 6 else lower[i] * 0.75 + upper[i] * 0.25 for i in range(hand.dof)] 124 | 125 | def update(joint_idx, joint_q): 126 | current_q[joint_idx] = joint_q 127 | trimesh = hand.get_trimesh_q(torch.tensor(current_q))["visual"] 128 | server.scene.add_mesh_simple( 129 | robot_name, 130 | trimesh.vertices, 131 | trimesh.faces, 132 | color=(102, 192, 255), 133 | opacity=0.8 134 | ) 135 | 136 | for i, joint_name in enumerate(joint_orders): 137 | if joint_name in joint_dots.keys(): 138 | slider = server.gui.add_slider( 139 | label=joint_name, 140 | min=round(lower[i], 2), 141 | max=round(upper[i], 2), 142 | step=(upper[i] - lower[i]) / 100, 143 | initial_value=current_q[i], 144 | ) 145 | slider.on_update(lambda gui: update(gui.target.order - 1, gui.target.value)) 146 | else: 147 | slider = server.gui.add_slider(label='', min=0, max=1, step=1, initial_value=0) 148 | 149 | while True: 150 | time.sleep(1) 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--robot_name', default='shadowhand', type=str) 156 | parser.add_argument('--controller', action='store_true') 157 | args = parser.parse_args() 158 | 159 | if args.controller: 160 | vis_controller_result(args.robot_name) 161 | else: 162 | vis_hand_direction(args.robot_name) 163 | -------------------------------------------------------------------------------- /visualization/vis_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import time 6 | import trimesh 7 | import torch 8 | import viser 9 | from utils.hand_model import create_hand_model 10 | 11 | filtered = True 12 | 13 | robot_names = 
['allegro', 'barrett', 'ezgripper', 'robotiq_3finger', 'shadowhand'] 14 | object_names = [ 15 | 'contactdb+alarm_clock', 'contactdb+apple', 'contactdb+banana', 'contactdb+binoculars', 'contactdb+camera', 16 | 'contactdb+cell_phone', 'contactdb+cube_large', 'contactdb+cube_medium', 'contactdb+cube_small', 17 | 'contactdb+cylinder_large', 'contactdb+cylinder_medium', 'contactdb+cylinder_small', 'contactdb+door_knob', 18 | 'contactdb+elephant', 'contactdb+flashlight', 'contactdb+hammer', 'contactdb+light_bulb', 'contactdb+mouse', 19 | 'contactdb+piggy_bank', 'contactdb+ps_controller', 'contactdb+pyramid_large', 'contactdb+pyramid_medium', 20 | 'contactdb+pyramid_small', 'contactdb+rubber_duck', 'contactdb+stanford_bunny', 'contactdb+stapler', 21 | 'contactdb+toothpaste', 'contactdb+torus_large', 'contactdb+torus_medium', 'contactdb+torus_small', 22 | 'contactdb+train', 'contactdb+water_bottle', 'ycb+baseball', 'ycb+bleach_cleanser', 'ycb+cracker_box', 23 | 'ycb+foam_brick', 'ycb+gelatin_box', 'ycb+hammer', 'ycb+lemon', 'ycb+master_chef_can', 'ycb+mini_soccer_ball', 24 | 'ycb+mustard_bottle', 'ycb+orange', 'ycb+peach', 'ycb+pear', 'ycb+pitcher_base', 'ycb+plum', 'ycb+potted_meat_can', 25 | 'ycb+power_drill', 'ycb+pudding_box', 'ycb+rubiks_cube', 'ycb+sponge', 'ycb+strawberry', 'ycb+sugar_box', 26 | 'ycb+tomato_soup_can', 'ycb+toy_airplane', 'ycb+tuna_fish_can', 'ycb+wood_block' 27 | ] 28 | 29 | if filtered: 30 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 31 | else: 32 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 33 | metadata = torch.load(dataset_path, map_location=torch.device('cpu'))['metadata'] 34 | 35 | def on_update(robot_idx, object_idx, grasp_idx): 36 | robot_name = robot_names[robot_idx] 37 | object_name = object_names[object_idx] 38 | if filtered: 39 | metadata_curr = [m[0] for m in metadata if m[1] == object_name and m[2] == robot_name] 40 | else: 41 | metadata_curr = [m[1] for m in 
metadata if m[2] == object_name and m[3] == robot_name] 42 | if len(metadata_curr) == 0: 43 | print('No metadata found!') 44 | return 45 | q = metadata_curr[grasp_idx % len(metadata_curr)] 46 | print(f"joint values: {q}") 47 | 48 | name = object_name.split('+') 49 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') # visual mesh 50 | # object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/coacd_allinone.obj') # collision mesh 51 | object_trimesh = trimesh.load_mesh(object_path) 52 | server.scene.add_mesh_simple( 53 | 'object', 54 | object_trimesh.vertices, 55 | object_trimesh.faces, 56 | color=(239, 132, 167), 57 | opacity=1 58 | ) 59 | 60 | hand = create_hand_model(robot_name) 61 | robot_trimesh = hand.get_trimesh_q(q)["visual"] 62 | server.scene.add_mesh_simple( 63 | 'robot', 64 | robot_trimesh.vertices, 65 | robot_trimesh.faces, 66 | color=(102, 192, 255), 67 | opacity=0.8 68 | ) 69 | 70 | server = viser.ViserServer(host='127.0.0.1', port=8080) 71 | 72 | robot_slider = server.gui.add_slider( 73 | label='robot', 74 | min=0, 75 | max=len(robot_names) - 1, 76 | step=1, 77 | initial_value=0 78 | ) 79 | object_slider = server.gui.add_slider( 80 | label='object', 81 | min=0, 82 | max=len(object_names) - 1, 83 | step=1, 84 | initial_value=0 85 | ) 86 | grasp_slider = server.gui.add_slider( 87 | label='grasp', 88 | min=0, 89 | max=199, 90 | step=1, 91 | initial_value=0 92 | ) 93 | robot_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 94 | object_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 95 | grasp_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 96 | 97 | while True: 98 | time.sleep(1) 99 | -------------------------------------------------------------------------------- /visualization/vis_hand_joint.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Visualizes hand joint motion within joint range (upper & lower limits). 3 | """ 4 | 5 | import os 6 | import sys 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | import time 10 | import argparse 11 | import torch 12 | import viser 13 | from utils.hand_model import create_hand_model 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--robot_name', type=str, default='shadowhand') 17 | args = parser.parse_args() 18 | robot_name = args.robot_name 19 | 20 | hand = create_hand_model(robot_name) 21 | pk_chain = hand.pk_chain 22 | lower, upper = pk_chain.get_joint_limits() 23 | 24 | server = viser.ViserServer(host='127.0.0.1', port=8080) 25 | 26 | canonical_trimesh = hand.get_trimesh_q(hand.get_canonical_q())["visual"] 27 | server.scene.add_mesh_simple( 28 | robot_name, 29 | canonical_trimesh.vertices, 30 | canonical_trimesh.faces, 31 | color=(102, 192, 255), 32 | opacity=0.8 33 | ) 34 | 35 | def update(q): 36 | trimesh = hand.get_trimesh_q(q)["visual"] 37 | server.scene.add_mesh_simple( 38 | robot_name, 39 | trimesh.vertices, 40 | trimesh.faces, 41 | color=(102, 192, 255), 42 | opacity=0.8 43 | ) 44 | 45 | gui_joints = [] 46 | for i, joint_name in enumerate(hand.get_joint_orders()): 47 | slider = server.gui.add_slider( 48 | label=joint_name, 49 | min=round(lower[i], 2), 50 | max=round(upper[i], 2), 51 | step=(upper[i] - lower[i]) / 100, 52 | initial_value=0 if i < 6 else lower[i] * 0.75 + upper[i] * 0.25, 53 | ) 54 | slider.on_update(lambda _: update(torch.tensor([gui.value for gui in gui_joints]))) 55 | gui_joints.append(slider) 56 | 57 | while True: 58 | time.sleep(1) 59 | -------------------------------------------------------------------------------- /visualization/vis_hand_link.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize hand links to 
remove abundant links in removed_links.json. 3 | """ 4 | 5 | import os 6 | import sys 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | import time 10 | import argparse 11 | import viser 12 | from utils.hand_model import create_hand_model 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--robot_name', type=str, default='shadowhand') 16 | args = parser.parse_args() 17 | robot_name = args.robot_name 18 | 19 | hand = create_hand_model(robot_name) 20 | meshes = hand.get_trimesh_q(hand.get_canonical_q())['parts'] 21 | 22 | server = viser.ViserServer(host='127.0.0.1', port=8080) 23 | 24 | for name, mesh in meshes.items(): 25 | server.scene.add_mesh_simple( 26 | name, 27 | mesh.vertices, 28 | mesh.faces, 29 | color=(102, 192, 255), 30 | opacity=0.8 31 | ) 32 | 33 | while True: 34 | time.sleep(1) 35 | -------------------------------------------------------------------------------- /visualization/vis_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import warnings 6 | import time 7 | import random 8 | import argparse 9 | import viser 10 | import torch 11 | 12 | from utils.hand_model import create_hand_model 13 | from utils.optimization import * 14 | from utils.se3_transform import compute_link_pose 15 | 16 | 17 | def main(robot_name): 18 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 19 | metadata = torch.load(dataset_path, map_location=torch.device('cpu'))['metadata'] 20 | metadata = [m for m in metadata if m[2] == robot_name] 21 | q = random.choice(metadata)[0] 22 | 23 | hand = create_hand_model(robot_name, device='cpu') 24 | initial_q = hand.get_initial_q(q) 25 | pc_initial = hand.get_transformed_links_pc(initial_q)[:, :3] 26 | pc_target = 
hand.get_transformed_links_pc(q)[:, :3] 27 | 28 | transform, _ = compute_link_pose(hand.links_pc, pc_target.unsqueeze(0), is_train=False) 29 | optim_transform = process_transform(hand.pk_chain, transform) 30 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 31 | predict_q = optimization(hand.pk_chain, layer, initial_q.unsqueeze(0), optim_transform)[0] 32 | pc_optimize = hand.get_transformed_links_pc(predict_q)[:, :3] 33 | 34 | server = viser.ViserServer(host='127.0.0.1', port=8080) 35 | 36 | server.scene.add_point_cloud( 37 | 'pc_initial', 38 | pc_initial.numpy(), 39 | point_size=0.001, 40 | point_shape="circle", 41 | colors=(102, 192, 255), 42 | visible=False 43 | ) 44 | 45 | server.scene.add_point_cloud( 46 | 'pc_optimize', 47 | pc_optimize.numpy(), 48 | point_size=0.001, 49 | point_shape="circle", 50 | colors=(0, 0, 200) 51 | ) 52 | 53 | server.scene.add_point_cloud( 54 | 'pc_target', 55 | pc_target.numpy(), 56 | point_size=0.001, 57 | point_shape="circle", 58 | colors=(200, 0, 0) 59 | ) 60 | 61 | while True: 62 | time.sleep(1) 63 | 64 | 65 | if __name__ == '__main__': 66 | warnings.simplefilter(action='ignore', category=FutureWarning) 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--robot_name', default='shadowhand', type=str) 69 | args = parser.parse_args() 70 | 71 | main(args.robot_name) 72 | -------------------------------------------------------------------------------- /visualization/vis_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import time 5 | import viser 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from model.network import create_encoder_network 13 | from data_utils.CMapDataset import CMapDataset 14 | from utils.pretrain_utils import dist2weight, infonce_loss 15 | from utils.hand_model 
import create_hand_model 16 | 17 | 18 | def main(robot_name): 19 | encoder = create_encoder_network(emb_dim=512, pretrain='pretrain_3robots.pth') 20 | 21 | dataset = CMapDataset( 22 | batch_size=1, 23 | robot_names=[robot_name], 24 | is_train=True, 25 | debug_object_names=None 26 | ) 27 | data = dataset[0] 28 | q_1 = data['initial_q'][0] 29 | q_2 = data['target_q'][0] 30 | pc_1 = data['robot_pc_initial'] 31 | pc_2 = data['robot_pc_target'] 32 | 33 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 34 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 35 | 36 | emb_1 = encoder(pc_1).detach() 37 | emb_2 = encoder(pc_2).detach() 38 | 39 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 40 | loss, similarity = infonce_loss( 41 | emb_1, emb_2, weights=weight, temperature=0.1 42 | ) 43 | 44 | match_idx = torch.argmax(similarity[0], dim=0) 45 | 46 | # offset for clearer visualization result 47 | offset = torch.tensor([0, 0.3, 0]) 48 | vis_pc_1 = data['robot_pc_initial'][0] 49 | vis_pc_2 = data['robot_pc_target'][0] + offset 50 | q_2[:3] += offset 51 | 52 | # match_tgt = vis_pc_2[match_idx] 53 | # match_vec = match_tgt - vis_pc_1 54 | 55 | server = viser.ViserServer(host='127.0.0.1', port=8080) 56 | 57 | z_values = vis_pc_1[:, 1] 58 | z_normalized = (z_values - z_values.min()) / (z_values.max() - z_values.min()) 59 | cmap = plt.get_cmap('rainbow') 60 | initial_colors = cmap(z_normalized)[:, :3] 61 | target_colors = initial_colors[match_idx] 62 | 63 | server.scene.add_point_cloud( 64 | 'initial pc', 65 | vis_pc_1[:, :3].numpy(), 66 | point_size=0.002, 67 | point_shape="circle", 68 | colors=initial_colors 69 | ) 70 | 71 | server.scene.add_point_cloud( 72 | 'target pc', 73 | vis_pc_2[:, :3].numpy(), 74 | point_size=0.002, 75 | point_shape="circle", 76 | colors=target_colors 77 | ) 78 | 79 | hand = create_hand_model(robot_name) 80 | 81 | robot_trimesh = hand.get_trimesh_q(q_1)["visual"] 82 | server.scene.add_mesh_simple( 83 | 'robot_initial', 84 | 
robot_trimesh.vertices, 85 | robot_trimesh.faces, 86 | color=(102, 192, 255), 87 | opacity=0.2 88 | ) 89 | 90 | robot_trimesh = hand.get_trimesh_q(q_2)["visual"] 91 | server.scene.add_mesh_simple( 92 | 'robot_target', 93 | robot_trimesh.vertices, 94 | robot_trimesh.faces, 95 | color=(102, 192, 255), 96 | opacity=0.2 97 | ) 98 | 99 | while True: 100 | time.sleep(1) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--robot_name', type=str, default='shadowhand') 106 | args = parser.parse_args() 107 | 108 | main(args.robot_name) 109 | -------------------------------------------------------------------------------- /visualization/vis_validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation visualization results will be saved in the 'vis_info/' folder. 3 | This code is used to visualize the saved information. 4 | """ 5 | 6 | import os 7 | import sys 8 | import time 9 | import viser 10 | import trimesh 11 | import torch 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(ROOT_DIR) 15 | 16 | from utils.hand_model import create_hand_model 17 | 18 | 19 | def main(): 20 | # substitute your filename here, which should be automatically saved in vis_info/ by validation.py 21 | file_name = 'vis_info/3robots_epoch10.pt' 22 | vis_info = torch.load(os.path.join(ROOT_DIR, file_name), map_location='cpu') 23 | 24 | def on_update(idx): 25 | invalid = True 26 | for info in vis_info: 27 | if idx >= info['predict_q'].shape[0]: 28 | idx -= info['predict_q'].shape[0] 29 | else: 30 | invalid = False 31 | break 32 | if invalid: 33 | print('Invalid index!') 34 | return 35 | 36 | print(info['robot_name'], info['object_name'], idx) 37 | print('result:', info['success'][idx]) 38 | 39 | object_name = info['object_name'].split('+') 40 | object_path = os.path.join(ROOT_DIR, 
f'data/data_urdf/object/{object_name[0]}/{object_name[1]}/{object_name[1]}.stl') 41 | object_trimesh = trimesh.load_mesh(object_path) 42 | server.scene.add_mesh_simple( 43 | 'object', 44 | object_trimesh.vertices, 45 | object_trimesh.faces, 46 | color=(239, 132, 167), 47 | opacity=0.8 48 | ) 49 | 50 | server.scene.add_point_cloud( 51 | 'object_pc', 52 | info['object_pc'][idx].numpy(), 53 | point_size=0.0008, 54 | point_shape="circle", 55 | colors=(255, 0, 0), 56 | visible=False 57 | ) 58 | 59 | server.scene.add_point_cloud( 60 | 'mlat_pc', 61 | info['mlat_pc'][idx].numpy(), 62 | point_size=0.001, 63 | point_shape="circle", 64 | colors=(0, 0, 200), 65 | visible=False 66 | ) 67 | 68 | hand = create_hand_model(info['robot_name']) 69 | 70 | robot_transform_trimesh = hand.get_trimesh_se3(info['predict_transform'], idx) 71 | server.scene.add_mesh_trimesh('transform', robot_transform_trimesh, visible=False) 72 | 73 | robot_trimesh = hand.get_trimesh_q(info['predict_q'][idx])['visual'] 74 | server.scene.add_mesh_simple( 75 | 'robot_predict', 76 | robot_trimesh.vertices, 77 | robot_trimesh.faces, 78 | color=(102, 192, 255), 79 | opacity=0.8, 80 | visible=False 81 | ) 82 | 83 | robot_trimesh = hand.get_trimesh_q(info['isaac_q'][idx])['visual'] 84 | server.scene.add_mesh_simple( 85 | 'robot_isaac', 86 | robot_trimesh.vertices, 87 | robot_trimesh.faces, 88 | color=(102, 192, 255), 89 | opacity=0.8 90 | ) 91 | 92 | server = viser.ViserServer(host='127.0.0.1', port=8080) 93 | 94 | grasp_num = 0 95 | for info in vis_info: 96 | grasp_num += info['predict_q'].shape[0] 97 | 98 | slider = server.gui.add_slider( 99 | label='grasp_idx', 100 | min=0, 101 | max=grasp_num, 102 | step=1, 103 | initial_value=0 104 | ) 105 | slider.on_update(lambda _: on_update(slider.value)) 106 | 107 | while True: 108 | time.sleep(1) 109 | 110 | if __name__ == '__main__': 111 | main() 112 | --------------------------------------------------------------------------------