├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config.py ├── create_gnn_dataset.py ├── generate_grasps_for_all_objs.py ├── generate_grasps_for_obj.py ├── models ├── geomatch.py ├── gnn.py └── mlp.py ├── train.py ├── utils ├── general_utils.py ├── gnn_utils.py ├── gripper_utils.py └── math_utils.py ├── utils_data ├── augmentors.py └── gnn_dataset.py └── visualize_geomatch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeoMatch: Geometry Matching for Multi-Embodiment Grasping 2 | 3 | While significant progress has been made on the problem of generating grasps, many existing learning-based approaches still concentrate on a single embodiment, provide limited generalization to higher-DoF end-effectors, and cannot capture a diverse set of grasp modes. In this paper, we tackle the problem of multi-embodiment grasping from the viewpoint of learning rich geometric representations for both objects and end-effectors using Graph Neural Networks (GNNs). Our novel method - GeoMatch - applies supervised learning on grasping data from multiple embodiments, learning end-to-end contact point likelihood maps as well as conditional autoregressive prediction of grasps keypoint-by-keypoint. We compare our method against 3 baselines that provide multi-embodiment support. Our approach performs better across 3 end-effectors, while also providing competitive diversity of grasps. Examples can be found at geomatch.github.io. 4 | 5 | This is the source code for the paper: [Geometry Matching for Multi-Embodiment Grasping](https://arxiv.org/abs/2312.03864). 6 | 7 | ## Installation 8 | 9 | To get started, creating an Anaconda or virtual environment is recommended. 10 | 11 | This repository was developed with PyTorch 1.13, among other dependencies: 12 | 13 | ```pip install torch==1.13.1 pytorch-kinematics matplotlib transforms3d numpy scipy plotly trimesh urdf_parser_py tqdm argparse``` 14 | 15 | For this work, we used the data from [GenDexGrasp: Generalizable Dexterous Grasping](https://github.com/tengyu-liu/GenDexGrasp/tree/main). Please follow the instructions at this link to download the data. 16 | 17 | ## Usage 18 | 19 | To train our model, run: 20 | 21 | ```python3 train.py --epochs=XXX --batch_size=YYY``` 22 | 23 | To generate grasps for a given object and all end-effectors, run: 24 | 25 | ```python3 generate_grasps_for_obj.py --saved_model_dir= --object_name=``` 26 | 27 | Optionally, you can plot grasps as they're generated by passing in `--plot_grasps` to the command above. 28 | 29 | To generate grasps for all objects in the eval set and all end-effectors, run: 30 | 31 | ```python3 generate_grasps_for_all_objs.py --saved_model_dir=``` 32 | 33 | 34 | ## Citing this work 35 | 36 | If you found our repository useful, please cite us: 37 | 38 | ``` 39 | @inproceedings{attarian2023geometry, 40 | title={Geometry Matching for Multi-Embodiment Grasping}, 41 | author={Attarian, Maria and Asif, Muhammad Adil and Liu, Jingzhou and Hari, Ruthrash and Garg, Animesh and Gilitschenski, Igor and Tompson, Jonathan}, 42 | booktitle={Proceedings of the 7th Conference on Robot Learning (CoRL)}, 43 | year={2023} 44 | } 45 | ``` 46 | 47 | ## License and disclaimer 48 | 49 | Copyright 2023 DeepMind Technologies Limited 50 | 51 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 52 | you may not use this file except in compliance with the Apache 2.0 license. 53 | You may obtain a copy of the Apache 2.0 license at: 54 | https://www.apache.org/licenses/LICENSE-2.0 55 | 56 | All other materials are licensed under the Creative Commons Attribution 4.0 57 | International License (CC-BY).
You may obtain a copy of the CC-BY license at: 58 | https://creativecommons.org/licenses/by/4.0/legalcode 59 | 60 | Unless required by applicable law or agreed to in writing, all software and 61 | materials distributed here under the Apache 2.0 or CC-BY licenses are 62 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 63 | either express or implied. See the licenses for the specific language governing 64 | permissions and limitations under those licenses. 65 | 66 | This is not an official Google product. 67 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Config file with parameters.""" 17 | 18 | obj_pc_n = 2048 19 | robot_pc_n = 6 20 | keypoint_n = 6 21 | 22 | hidden_n = 256 23 | obj_in_feats = 3 24 | robot_in_feats = obj_in_feats 25 | obj_out_feats = 512 26 | robot_out_feats = obj_out_feats 27 | robot_weighting = 500.0 28 | matchnet_weighting = 200.0 29 | num_hidden = 3 30 | 31 | dist_loss_weight = 0.5 32 | match_loss_weight = 0.5 33 | -------------------------------------------------------------------------------- /create_gnn_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Script that preprocesses the data into the format expected by GeoMatch.""" 17 | 18 | import argparse 19 | import os 20 | import numpy as np 21 | import torch 22 | import torch.utils.data 23 | from tqdm import tqdm 24 | import trimesh as tm 25 | from utils.general_utils import get_handmodel 26 | from utils.gnn_utils import euclidean_min_dist 27 | from utils.gnn_utils import generate_adj_mat_feats 28 | from utils.gnn_utils import generate_contact_maps 29 | 30 | device = 'cpu' 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--dataset_basedir', type=str, default='/data/grasp_gnn') 35 | parser.add_argument( 36 | '--object_mesh_basedir', type=str, default='/data/grasp_gnn/object' 37 | ) 38 | parser.add_argument('--gnn_dataset_basedir', type=str, default='data') 39 | 40 | args = parser.parse_args() 41 | 42 | cmap_data = torch.load( 43 | os.path.join( 44 | args.dataset_basedir, 'CMapDataset-sqrt_align/cmap_dataset.pt' 45 | ), 46 | map_location=torch.device('cpu'), 47 | ) 48 | object_data = torch.load( 49 | os.path.join( 50 | args.dataset_basedir, 'CMapDataset-sqrt_align/object_point_clouds.pt' 51 | ) 52 | ) 53 | 54 | robot_name_list = [ 55 | 'ezgripper', 56 | 'barrett', 57 | 'robotiq_3finger', 58 | 'allegro', 59 | 'shadowhand', 60 | ] 61 | hand_model = {} 62 | robot_key_point_idx = {} 63 | surface_points_per_robot = {} 64 | threshold = 0.04 65 | 66 | data_dict = {} 67 | for robot_name in tqdm(robot_name_list): 68 | hand_model[robot_name] = get_handmodel( 69 | robot_name, 1, 'cpu', 1.0, data_dir=args.dataset_basedir 70 | ) 71 | joint_lower = np.array( 72 | hand_model[robot_name].revolute_joints_q_lower.cpu().reshape(-1) 73 | ) 74 | joint_upper = np.array( 75 | hand_model[robot_name].revolute_joints_q_upper.cpu().reshape(-1) 76 | ) 77 | joint_mid = (joint_lower + joint_upper) / 2 78 | joints_q = (joint_mid + joint_lower) / 2 79 | rest_pose = ( 80 | torch.from_numpy( 81 | np.concatenate([np.array([0, 0, 0, 1, 0, 0, 0, 1, 0]), joints_q]) 82 | ) 83 | .unsqueeze(0) 84 | .to(device) 85 | .float() 86 | ) 87 | surface_points = ( 88 | hand_model[robot_name] 89 | .get_surface_points(rest_pose, downsample=True) 90 | .cpu() 91 | .squeeze(0) 92 | ) 93 | key_points, key_point_idx_dict, surface_sample_kp_idx = hand_model[ 94 | robot_name 95 | ].get_static_key_points(rest_pose, surface_points) 96 | robot_key_point_idx[robot_name] = key_point_idx_dict 97 | surface_points_per_robot[robot_name] = surface_points 98 | robot_adj, robot_features = generate_adj_mat_feats(surface_points, knn=8) 99 | data_dict[robot_name] = ( 100 | robot_adj, 101 | robot_features, 102 | rest_pose, 103 | surface_sample_kp_idx, 104 | key_point_idx_dict, 105 | robot_name, 106 | ) 107 | 108 | torch.save( 109 | data_dict, 110 | os.path.join( 111 | args.gnn_dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt' 112 | ), 113 | ) 114 | 115 | data_dict = {} 116 | for obj_name in tqdm(object_data): 117 | object_mesh_path = os.path.join( 118 | args.object_mesh_basedir, 119 | f'{obj_name.split("+")[0]}', 120 | f'{obj_name.split("+")[1]}', 121 | f'{obj_name.split("+")[1]}.stl', 122 | ) 123 | obj_point_cloud = object_data[obj_name] 124 | obj_mesh = tm.load(object_mesh_path) 125 | normals = [] 126 | 127 | for p in obj_point_cloud: 128 | dist, indices = euclidean_min_dist(p, obj_mesh.vertices) 129 | normals.append(obj_mesh.vertex_normals[indices[0]]) 130 | 131 | normals = np.stack(normals, axis=0) 132 | 133 
| obj_adj, obj_features = generate_adj_mat_feats(obj_point_cloud, knn=8) 134 | data_dict[obj_name] = (obj_adj, obj_features, torch.tensor(normals)) 135 | 136 | torch.save( 137 | data_dict, 138 | os.path.join(args.gnn_dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt'), 139 | ) 140 | 141 | data_list = [] 142 | for metadata in tqdm(cmap_data['metadata']): 143 | _, q, object_name, robot_name = metadata 144 | q = q.unsqueeze(0) 145 | obj_point_cloud = object_data[object_name] 146 | 147 | robot_grasp_kps, _, _ = hand_model[robot_name].get_static_key_points( 148 | q, surface_points_per_robot[robot_name] 149 | ) 150 | obj_contact_map = np.zeros((6, obj_point_cloud.shape[0])) 151 | full_obj_contact_map = np.zeros((obj_point_cloud.shape[0], 1)) 152 | 153 | point_dists_idxs = [ 154 | euclidean_min_dist(x, obj_point_cloud) for x in robot_grasp_kps 155 | ] 156 | robot_contact_map = np.array( 157 | [int(x[0] < threshold) for x in point_dists_idxs] 158 | ).reshape(-1, 1) 159 | top_obj_contact_kps = torch.stack( 160 | [obj_point_cloud[x[1][0]] for x in point_dists_idxs], dim=0 161 | ) 162 | top_obj_contact_verts = torch.tensor( 163 | [x[1][0] for x in point_dists_idxs] 164 | ).long() 165 | 166 | for i in range(top_obj_contact_verts.shape[0]): 167 | obj_contact_map[i, point_dists_idxs[i][1][:20]] = 1 168 | full_obj_contact_map[point_dists_idxs[i][1][:20]] = 1 169 | 170 | obj_contacts = generate_contact_maps(obj_contact_map) 171 | robot_contacts = generate_contact_maps(robot_contact_map) 172 | 173 | data_point = ( 174 | obj_contacts, 175 | robot_contacts, 176 | top_obj_contact_kps, 177 | top_obj_contact_verts, 178 | full_obj_contact_map, 179 | q.squeeze(0), 180 | object_name, 181 | robot_name, 182 | ) 183 | data_list.append(data_point) 184 | 185 | data_dict = { 186 | 'info': [ 187 | 'obj_contacts', 188 | 'robot_contacts', 189 | 'top_obj_contact_kps', 190 | 'top_obj_contact_verts', 191 | 'full_obj_contact_map', 192 | 'q', 193 | 'object_name', 194 | 'robot_name', 195 | ], 196 | 'metadata': data_list, 197 | } 198 | torch.save( 199 | data_dict, 200 | os.path.join( 201 | args.gnn_dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt' 202 | ), 203 | ) 204 | -------------------------------------------------------------------------------- /generate_grasps_for_all_objs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Batch script that generates grasps for all eval objects and all end-effectors.""" 17 | 18 | import argparse 19 | import json 20 | import os 21 | import subprocess 22 | import numpy as np 23 | import torch 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 29 | parser.add_argument( 30 | '--data_dir', 31 | type=str, 32 | default='/data/grasp_gnn', 33 | help='Base data directory.', 34 | ) 35 | parser.add_argument( 36 | '--saved_model_dir', 37 | type=str, 38 | default=( 39 | 'logs_train/exp-pos_weight_500_200_6_kps_final-1683209255.5473607/' 40 | ), 41 | ) 42 | parser.add_argument('--output_dir', type=str, default='logs_out_grasps/') 43 | 44 | args = parser.parse_args() 45 | 46 | np.random.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | 49 | data_dir = args.data_dir 50 | print(f'Saved model dir: {args.saved_model_dir}') 51 | 52 | object_list = json.load( 53 | open( 54 | os.path.join( 55 | data_dir, 56 | 'CMapDataset-sqrt_align/split_train_validate_objects.json', 57 | ), 58 | 'rb', 59 | ) 60 | )['validate'] 61 | 62 | ps = [ 63 | subprocess.Popen( 64 | [ 65 | 'python', 66 | 'generate_grasps_for_obj.py', 67 | '--object_name', 68 | object_name, 69 | '--data_dir', 70 | data_dir, 71 | '--saved_model_dir', 72 | args.saved_model_dir, 73 | '--output_dir', 74 | args.output_dir, 75 | ], 76 | stdout=subprocess.PIPE, 77 | ) 78 | for object_name in object_list 79 | ] 80 | 81 | exit_codes = [p.wait() for p in ps] 82 | 83 | finished_proc_inds = [i for i, p in enumerate(ps) if p.poll() is not None] 84 | 85 | print(f'Exit codes of all processes: {exit_codes}') 86 | print(f'All processes finished?: {finished_proc_inds}') 87 | -------------------------------------------------------------------------------- /generate_grasps_for_obj.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Script that performs inference on GeoMatch to produce keypoints for a given object and all end-effectors, followed by the IK used for the paper. 17 | 18 | Note: Any other IK could be used. 
19 | """ 20 | 21 | import argparse 22 | import json 23 | import os 24 | import config 25 | from models.geomatch import GeoMatch 26 | import numpy as np 27 | import plotly.graph_objects as go 28 | import scipy 29 | import torch 30 | from torch import nn 31 | import trimesh as tm 32 | from utils.general_utils import get_handmodel 33 | from utils.gnn_utils import euclidean_min_dist 34 | from utils.gnn_utils import plot_mesh 35 | from utils.gnn_utils import plot_point_cloud 36 | 37 | 38 | def compute_pose_from_rotation_matrix(t_pose, r_matrix): 39 | """Computes a 6D pose from a rotation matrix.""" 40 | 41 | batch = r_matrix.shape[0] 42 | joint_num = 2 43 | r_matrices = ( 44 | r_matrix.view(batch, 1, 3, 3) 45 | .expand(batch, joint_num, 3, 3) 46 | .contiguous() 47 | .view(batch * joint_num, 3, 3) 48 | ) 49 | src_poses = ( 50 | t_pose.view(1, joint_num, 3, 1) 51 | .expand(batch, joint_num, 3, 1) 52 | .contiguous() 53 | .view(batch * joint_num, 3, 1) 54 | ) 55 | 56 | out_poses = torch.matmul(r_matrices, src_poses.double()) 57 | 58 | return out_poses.view(batch, joint_num * 3) 59 | 60 | 61 | def rotation_matrix_from_vectors(vec1, vec2): 62 | """Returns the rotation matrix that aligns two vectors. 63 | 64 | Source: 65 | https://stackoverflow.com/questions/45142959/calculate-rotation-matrix-to-align-two-vectors-in-3d-space 66 | 67 | Args: 68 | vec1: first vector. 69 | vec2: second vector. 70 | 71 | Returns: 72 | A 3x3 rotation_matrix/ 73 | """ 74 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), ( 75 | vec2 / np.linalg.norm(vec2) 76 | ).reshape(3) 77 | v = np.cross(a, b) 78 | c = np.dot(a, b) 79 | s = np.linalg.norm(v) 80 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 81 | rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2)) 82 | return rotation_matrix 83 | 84 | 85 | def get_heuristic_init_pose( 86 | gripper_model, centroids, object_point_cloud, robot_point_cloud, q_rest 87 | ): 88 | """Gets an initial pose based on heuristics, to start IK.""" 89 | 90 | _, sorted_vert = euclidean_min_dist(centroids, robot_point_cloud) 91 | hand_closest_point = robot_point_cloud[sorted_vert[0]] 92 | center_mass = object_point_cloud.mean(dim=0) 93 | 94 | _, sorted_vert_obj = euclidean_min_dist(centroids, object_point_cloud) 95 | obj_closest_point = object_point_cloud[sorted_vert_obj[0]] 96 | 97 | hand_vec_normal = hand_closest_point - torch.tensor(centroids) 98 | obj_vec_normal = obj_closest_point - center_mass 99 | 100 | rot_mat = rotation_matrix_from_vectors(hand_vec_normal, obj_vec_normal) 101 | 102 | new_q_rot = compute_pose_from_rotation_matrix( 103 | q_rest.squeeze()[3:9], torch.tensor(rot_mat[None]) 104 | ) 105 | 106 | new_hand_closest = np.matmul(rot_mat, hand_closest_point) 107 | trans = obj_closest_point - new_hand_closest 108 | 109 | gripper_model.global_rotation = torch.tensor(rot_mat[None]) 110 | gripper_model.global_translation = trans[None] 111 | 112 | heuristic_q = np.concatenate( 113 | (trans, new_q_rot.squeeze(), q_rest.squeeze()[9:]) 114 | ) 115 | heuristic_q = torch.tensor(heuristic_q).float() 116 | 117 | return heuristic_q 118 | 119 | 120 | def autoregressive_inference( 121 | contact_map_pred, 122 | match_model, 123 | top_k, 124 | obj_pc, 125 | robot_embed, 126 | obj_embed, 127 | ): 128 | """Performs the autoregressive inference part of GeoMatch.""" 129 | 130 | max_topk = max(top_k) 131 | 132 | with torch.no_grad(): 133 | max_per_kp = torch.topk(contact_map_pred, k=(max_topk + 1), dim=1) 134 | all_grasps = [] 135 | 136 | for k in top_k: 137 | 
pred_curr = None 138 | grasp_points = [] 139 | contact_or_not = [] 140 | 141 | obj_proj_embed = match_model.obj_proj(obj_embed) 142 | robot_proj_embed = match_model.robot_proj(robot_embed) 143 | 144 | for i_prev in range(config.keypoint_n - 1): 145 | model_kp = match_model.kp_ar_model_1 146 | 147 | if i_prev == 1: 148 | model_kp = match_model.kp_ar_model_2 149 | elif i_prev == 2: 150 | model_kp = match_model.kp_ar_model_3 151 | elif i_prev == 3: 152 | model_kp = match_model.kp_ar_model_4 153 | elif i_prev == 4: 154 | model_kp = match_model.kp_ar_model_5 155 | 156 | xyz_prev = torch.gather( 157 | obj_pc[None], 158 | 1, 159 | max_per_kp.indices[:, k, i_prev, :].repeat(1, 1, 3), 160 | ) 161 | 162 | if i_prev == 0: 163 | grasp_points.append(xyz_prev.squeeze()) 164 | contact_or_not.append(torch.tensor(1)) 165 | else: 166 | xyz_prev = torch.stack(grasp_points, dim=0)[None] 167 | 168 | pred_curr = model_kp( 169 | obj_proj_embed, obj_pc[None], robot_proj_embed, xyz_prev 170 | ) 171 | pred_prob = nn.Sigmoid()(pred_curr) 172 | vert_pred = torch.max(pred_prob[..., 0], dim=-1) 173 | min_idx = vert_pred.indices[0] 174 | contact_or_not.append(torch.tensor(int(vert_pred.values[0] >= 0.5))) 175 | 176 | pred_curr = obj_pc[min_idx] 177 | grasp_points.append(pred_curr) 178 | 179 | grasp_points = torch.stack(grasp_points, dim=0) 180 | contact_or_not = torch.stack(contact_or_not, dim=0) 181 | final_grasp_points = torch.cat( 182 | (grasp_points, contact_or_not[..., None]), dim=-1 183 | ) 184 | all_grasps.append(final_grasp_points) 185 | 186 | return torch.stack(all_grasps, dim=0) 187 | 188 | 189 | def inference( 190 | geomatch_model, 191 | top_k, 192 | obj_pc, 193 | obj_adjacency, 194 | robot_point_cloud, 195 | robot_adjacency, 196 | keypoints_idx, 197 | ): 198 | """Performs full inference for GeoMatch.""" 199 | 200 | with torch.no_grad(): 201 | obj_embed = geomatch_model.encode_embed( 202 | geomatch_model.obj_encoder, obj_pc[None], obj_adjacency[None] 203 | ) 204 | robot_embed = geomatch_model.encode_embed( 205 | geomatch_model.robot_encoder, 206 | robot_point_cloud[None], 207 | robot_adjacency[None], 208 | ) 209 | 210 | robot_feat_size = robot_embed.shape[2] 211 | keypoint_feat = torch.gather( 212 | robot_embed, 213 | 1, 214 | keypoints_idx[..., None].long().repeat(1, 1, robot_feat_size), 215 | ) 216 | contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[ 217 | ..., None 218 | ] 219 | 220 | top_obj_contact_kps_pred = autoregressive_inference( 221 | contact_map_pred, 222 | geomatch_model, 223 | top_k, 224 | obj_pc, 225 | robot_embed, 226 | obj_embed, 227 | ) 228 | pred_points = top_obj_contact_kps_pred 229 | 230 | return pred_points 231 | 232 | 233 | def inverse_kinematics_optimization(gripper_model, q_init, target_points): 234 | """Function performing IK optimization and returning a pose.""" 235 | 236 | def optimize_target(q): 237 | q = torch.tensor(q).float() 238 | source_points, _, _ = gripper_model.get_static_key_points(q.unsqueeze(0)) 239 | e = [ 240 | np.linalg.norm(source_points[i] - target_points[i]) 241 | for i in range(len(source_points)) 242 | ] 243 | return e 244 | 245 | real_bounds = [] 246 | 247 | for _ in range(3): 248 | real_bounds.append((-0.5, 0.5)) 249 | 250 | for _ in range(6): 251 | real_bounds.append((-np.pi, np.pi)) 252 | 253 | for idx, _ in enumerate(gripper_model.revolute_joints_q_lower.squeeze()): 254 | real_bounds.append(( 255 | gripper_model.revolute_joints_q_lower[:, idx].squeeze().item(), 256 | gripper_model.revolute_joints_q_upper[:, idx].squeeze().item(), 
257 | )) 258 | 259 | result = scipy.optimize.least_squares( 260 | optimize_target, q_init, method='trf', bounds=tuple(zip(*real_bounds)) 261 | ) 262 | return result 263 | 264 | 265 | if __name__ == '__main__': 266 | parser = argparse.ArgumentParser() 267 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 268 | parser.add_argument( 269 | '--device', type=str, default='cpu', help='Use cuda if available' 270 | ) 271 | parser.add_argument( 272 | '--object_name', 273 | type=str, 274 | default='contactdb+rubber_duck', 275 | help='Which object to calculate grasp for.', 276 | ) 277 | parser.add_argument( 278 | '--data_dir', 279 | type=str, 280 | default='/data/grasp_gnn', 281 | help='Base data directory.', 282 | ) 283 | parser.add_argument( 284 | '--saved_model_dir', 285 | type=str, 286 | default=( 287 | 'logs_train/exp-pos_weight_500_200_6_kps_final-1683209255.5473607/' 288 | ), 289 | ) 290 | parser.add_argument('--output_dir', type=str, default='logs_out_grasps/') 291 | parser.add_argument('--plot_grasps', default=False, action='store_true') 292 | args = parser.parse_args() 293 | 294 | np.random.seed(args.seed) 295 | torch.manual_seed(args.seed) 296 | 297 | data_dir = args.data_dir 298 | output_dir = os.path.join(args.output_dir, args.object_name) 299 | os.makedirs(output_dir, exist_ok=True) 300 | 301 | object_name = args.object_name 302 | object_mesh_basedir = os.path.join(data_dir, 'object') 303 | object_mesh_path = os.path.join( 304 | object_mesh_basedir, 305 | f'{args.object_name.split("+")[0]}', 306 | f'{args.object_name.split("+")[1]}', 307 | f'{args.object_name.split("+")[1]}.stl', 308 | ) 309 | obj_mesh = tm.load(object_mesh_path) 310 | obj_normals = obj_mesh.vertex_normals 311 | plot_grasps = args.plot_grasps 312 | 313 | robot_centroids = json.load( 314 | open(os.path.join(data_dir, 'robot_centroids.json')) 315 | ) 316 | robot_list = [ 317 | 'ezgripper', 318 | 'barrett', 319 | 'shadowhand', 320 | ] 321 | top_ks = [0, 20, 50, 100] 322 | 323 | model = GeoMatch(config) 324 | model.load_state_dict( 325 | torch.load( 326 | os.path.join(args.saved_model_dir, 'weights/grasp_gnn.pth'), 327 | map_location=torch.device('cpu'), 328 | ) 329 | ) 330 | 331 | model.eval() 332 | 333 | robot_pc_adj = torch.load( 334 | os.path.join(data_dir, 'gnn_robot_adj_point_clouds_new.pt') 335 | ) 336 | 337 | object_pc_adj = torch.load( 338 | os.path.join(data_dir, 'gnn_obj_adj_point_clouds_new.pt') 339 | ) 340 | new_object_pc_adj = torch.load( 341 | os.path.join(data_dir, 'gnn_obj_adj_point_clouds_new_unseen.pt') 342 | ) 343 | object_pc_adj.update(new_object_pc_adj) 344 | 345 | object_pc = object_pc_adj[object_name][1] 346 | obj_adj = object_pc_adj[object_name][0] 347 | generated_grasps = [] 348 | 349 | for robot_name in robot_list: 350 | hand_model = get_handmodel(robot_name, 1, 'cpu', 1.0) 351 | robot_pc = robot_pc_adj[robot_name][1] 352 | robot_adj = robot_pc_adj[robot_name][0] 353 | 354 | robot_keypoints_idx = robot_pc_adj[robot_name][3] 355 | rest_pose = hand_model.rest_pose 356 | hand_centroids = robot_centroids[robot_name] 357 | 358 | q_heuristic = get_heuristic_init_pose( 359 | hand_model, hand_centroids, object_pc, robot_pc, rest_pose 360 | ) 361 | 362 | all_grasps_predicted_keypoints = inference( 363 | model, 364 | top_ks, 365 | object_pc, 366 | obj_adj, 367 | robot_pc, 368 | robot_adj, 369 | robot_keypoints_idx, 370 | ) 371 | 372 | for i in range(all_grasps_predicted_keypoints.shape[0]): 373 | predicted_keypoints = all_grasps_predicted_keypoints[i] 374 | closest_mesh_idxs = [ 375 | 
euclidean_min_dist(x, obj_mesh.vertices)[1][0] 376 | for x in predicted_keypoints[:, :3] 377 | ] 378 | closest_normals = obj_normals[closest_mesh_idxs] 379 | 380 | pregrasp_pred_keypoints = predicted_keypoints[ 381 | :, :3 382 | ] + 0.005 * closest_normals.astype('float32') 383 | 384 | res = inverse_kinematics_optimization( 385 | hand_model, q_heuristic, pregrasp_pred_keypoints 386 | ) 387 | q_calc = res.x 388 | q_calc = torch.tensor(q_calc).float() 389 | 390 | calc_key_points, _, _ = hand_model.get_static_key_points( 391 | q_calc.unsqueeze(0) 392 | ) 393 | 394 | sample = { 395 | 'object_name': object_name, 396 | 'robot_name': robot_name, 397 | 'pred_keypoints': predicted_keypoints, 398 | 'robot_final_keypoints': calc_key_points, 399 | 'init_pose': q_heuristic, 400 | 'pred_grasp_pose': q_calc, 401 | 'sample_idx': i, 402 | 'scale': 1.0, 403 | } 404 | 405 | if plot_grasps: 406 | data = [ 407 | plot_mesh(mesh=tm.load(object_mesh_path), opacity=1.0, color='blue') 408 | ] 409 | data += hand_model.get_plotly_data( 410 | q=q_calc.unsqueeze(0), opacity=1.0, color='pink' 411 | ) 412 | data += [plot_point_cloud(predicted_keypoints, color='green')] 413 | data += [plot_point_cloud(calc_key_points, color='purple')] 414 | 415 | fig = go.Figure(data=data) 416 | fig.show() 417 | 418 | generated_grasps.append(sample) 419 | 420 | data_dict = { 421 | 'info': [ 422 | 'object_name', 423 | 'robot_name', 424 | 'pred_keypoints', 425 | 'robot_final_keypoints', 426 | 'init_pose', 427 | 'pred_grasp_pose', 428 | 'sample_idx', 429 | ], 430 | 'metadata': generated_grasps, 431 | } 432 | torch.save(data_dict, os.path.join(output_dir, 'gen_grasps.pt')) 433 | -------------------------------------------------------------------------------- /models/geomatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """GeoMatch model definition.""" 17 | 18 | from models.gnn import GCN 19 | from models.mlp import MLP 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class GeoMatchARModule(nn.Module): 25 | """Autoregressive module class for GeoMatch.""" 26 | 27 | def __init__(self, config, n_kp) -> None: 28 | super().__init__() 29 | 30 | self.config = config 31 | self.n_kp = n_kp 32 | self.final_fc = MLP(128 + 3 * self.n_kp, 1, 3, 256) 33 | 34 | def forward(self, obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev): 35 | robot_i_embed = ( 36 | robot_proj_embed[:, self.n_kp][..., None] 37 | .transpose(2, 1) 38 | .repeat(1, self.config.obj_pc_n, 1) 39 | ) 40 | obj_robot_embed = torch.cat((obj_proj_embed, robot_i_embed), dim=-1) 41 | 42 | diff_xyz_tensor = [] 43 | for i in range(self.n_kp): 44 | diff_xyz = obj_pc - xyz_prev[:, i, :][..., None].transpose(2, 1) 45 | diff_xyz_tensor.append(diff_xyz) 46 | 47 | diff_xyz_tensor = torch.stack(diff_xyz_tensor, dim=-1) 48 | diff_xyz_tensor = diff_xyz_tensor.view( 49 | diff_xyz_tensor.shape[0], diff_xyz_tensor.shape[1], -1 50 | ) 51 | inp = torch.cat((obj_robot_embed, diff_xyz_tensor), dim=-1) 52 | pred_curr = self.final_fc(inp) 53 | 54 | return pred_curr 55 | 56 | def calc_loss(self, pred, label): 57 | pred = pred.view(pred.shape[0] * pred.shape[1], 1) 58 | label = label.view(label.shape[0] * label.shape[1], 1) 59 | 60 | pos_weight = torch.tensor([1000.0]).cuda() 61 | loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, label) 62 | return torch.mean(loss) 63 | 64 | 65 | class GeoMatch(nn.Module): 66 | """GeoMatch model class.""" 67 | 68 | def __init__(self, config) -> None: 69 | super().__init__() 70 | 71 | self.config = config 72 | self.n_kp = config.keypoint_n 73 | self.robot_weighting = config.robot_weighting 74 | self.match_weighting = config.matchnet_weighting 75 | self.dist_loss_weight = config.dist_loss_weight 76 | self.match_loss_weight = config.match_loss_weight 77 | 78 | self.obj_encoder = GCN( 79 | nfeat=config.obj_in_feats, 80 | nhid=config.hidden_n, 81 | nout=config.obj_out_feats, 82 | dropout=0.5, 83 | num_hidden=config.num_hidden, 84 | ) 85 | 86 | self.robot_encoder = GCN( 87 | nfeat=config.robot_in_feats, 88 | nhid=config.hidden_n, 89 | nout=config.robot_out_feats, 90 | dropout=0.5, 91 | num_hidden=config.num_hidden, 92 | ) 93 | 94 | self.obj_proj = nn.Linear(self.config.obj_out_feats, 64, bias=False) 95 | self.robot_proj = nn.Linear(self.config.robot_out_feats, 64, bias=False) 96 | self.kp_ar_model_1 = GeoMatchARModule(config, 1) 97 | self.kp_ar_model_2 = GeoMatchARModule(config, 2) 98 | self.kp_ar_model_3 = GeoMatchARModule(config, 3) 99 | self.kp_ar_model_4 = GeoMatchARModule(config, 4) 100 | self.kp_ar_model_5 = GeoMatchARModule(config, 5) 101 | 102 | def encode_embed(self, encoder, feature, adj_mat, normalize_emb=True): 103 | x = encoder(feature, adj_mat) 104 | if normalize_emb: 105 | x = x.clone() / torch.norm(x, dim=-1, keepdim=True) 106 | return x 107 | 108 | def forward( 109 | self, obj_pc, robot_pc, robot_key_point_idx, obj_adj, robot_adj, xyz_prev 110 | ): 111 | obj_embed = self.encode_embed(self.obj_encoder, obj_pc, obj_adj) 112 | robot_embed = self.encode_embed(self.robot_encoder, robot_pc, robot_adj) 113 | 114 | robot_feat_size = robot_embed.shape[2] 115 | keypoint_feat = torch.gather( 116 | robot_embed, 117 | 1, 118 | robot_key_point_idx[..., None].long().repeat(1, 1, robot_feat_size), 119 | ) 120 | contact_map_pred = 
torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[ 121 | ..., None 122 | ] 123 | 124 | obj_proj_embed = self.obj_proj(obj_embed) 125 | robot_proj_embed = self.robot_proj(robot_embed) 126 | 127 | output_1 = self.kp_ar_model_1( 128 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev 129 | ) 130 | output_2 = self.kp_ar_model_2( 131 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev 132 | ) 133 | output_3 = self.kp_ar_model_3( 134 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev 135 | ) 136 | output_4 = self.kp_ar_model_4( 137 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev 138 | ) 139 | output_5 = self.kp_ar_model_5( 140 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev 141 | ) 142 | 143 | output = torch.cat( 144 | (output_1, output_2, output_3, output_4, output_5), dim=-1 145 | )[..., None] 146 | 147 | return contact_map_pred, output 148 | 149 | def calc_loss(self, gt_contact_map, contact_map_pred, pred, label): 150 | flat_contact_map_pred = contact_map_pred.view( 151 | contact_map_pred.shape[0] 152 | * contact_map_pred.shape[1] 153 | * contact_map_pred.shape[2], 154 | 1, 155 | ) 156 | flat_gt_contact_map = gt_contact_map.view( 157 | gt_contact_map.shape[0] 158 | * gt_contact_map.shape[1] 159 | * gt_contact_map.shape[2], 160 | 1, 161 | ) 162 | 163 | pos_weight = torch.Tensor([self.robot_weighting]).cuda() 164 | loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)( 165 | flat_contact_map_pred, flat_gt_contact_map 166 | ) 167 | l_dist = torch.mean(loss) 168 | 169 | pos_weight = torch.tensor([self.match_weighting]).cuda() 170 | 171 | loss = [] 172 | for i in range(self.n_kp - 1): 173 | pred_i = pred[:, :, i] 174 | label_i = label[:, :, i] 175 | pred_i = pred_i.view(pred_i.shape[0] * pred_i.shape[1], 1) 176 | label_i = label_i.view(label_i.shape[0] * label_i.shape[1], 1) 177 | loss.append(nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred_i, label_i)) 178 | 179 | loss = torch.stack(loss) 180 | l_match = torch.mean(loss) 181 | 182 | return self.dist_loss_weight * l_dist + self.match_loss_weight * l_match 183 | -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Implementation of Graph Convolutional Neural Networks.""" 17 | 18 | import copy 19 | import math 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | 24 | 25 | def clones(module, n): 26 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 27 | 28 | 29 | class GraphConvolution(nn.Module): 30 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" 31 | 32 | def __init__(self, in_features, out_features, bias=True): 33 | super().__init__() 34 | self.in_features = in_features 35 | self.out_features = out_features 36 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 37 | if bias: 38 | self.bias = nn.Parameter(torch.FloatTensor(out_features)) 39 | else: 40 | self.register_parameter('bias', None) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | stdv = 1.0 / math.sqrt(self.weight.size(1)) 45 | self.weight.data.uniform_(-stdv, stdv) 46 | if self.bias is not None: 47 | self.bias.data.uniform_(-stdv, stdv) 48 | 49 | def forward(self, inp, adj): 50 | support = torch.matmul(inp, self.weight) 51 | output = torch.matmul(adj.to_dense(), support) 52 | if self.bias is not None: 53 | return output + self.bias 54 | else: 55 | return output 56 | 57 | def __repr__(self): 58 | return ( 59 | self.__class__.__name__ 60 | + ' (' 61 | + str(self.in_features) 62 | + ' -> ' 63 | + str(self.out_features) 64 | + ')' 65 | ) 66 | 67 | 68 | class GCN(nn.Module): 69 | """Graph Convolutional Neural Network class.""" 70 | 71 | def __init__(self, nfeat, nhid, nout, dropout, num_hidden): 72 | super().__init__() 73 | 74 | self.gc0 = GraphConvolution(nfeat, nhid) 75 | self.gc_layers = clones(GraphConvolution(nhid, nhid), num_hidden) 76 | self.out = nn.Linear(nhid, nout) 77 | self.dropout = dropout 78 | 79 | def forward(self, x, adj): 80 | x = F.relu(self.gc0(x, adj)) 81 | 82 | for i, _ in enumerate(self.gc_layers): 83 | x = F.relu(self.gc_layers[i](x, adj)) 84 | return self.out(x) 85 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Implementation of a Multi-Layer Perceptron.""" 17 | 18 | import copy 19 | from torch import nn 20 | import torch.nn.functional as F 21 | 22 | 23 | def clones(module, n): 24 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 25 | 26 | 27 | class MLP(nn.Module): 28 | """MLP class.""" 29 | 30 | def __init__(self, in_features, out_features, num_hidden, hidden_dim) -> None: 31 | super().__init__() 32 | 33 | self.layer0 = nn.Linear(in_features, hidden_dim) 34 | self.layers = clones(nn.Linear(hidden_dim, hidden_dim), num_hidden) 35 | self.out = nn.Linear(hidden_dim, out_features) 36 | 37 | def forward(self, x): 38 | x = F.relu(self.layer0(x)) 39 | 40 | for l in self.layers: 41 | x = F.relu(l(x)) 42 | 43 | return self.out(x) 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Training script for GeoMatch.""" 17 | 18 | import argparse 19 | import dataclasses 20 | import os 21 | import shutil 22 | import sys 23 | import time 24 | import config 25 | import matplotlib.pyplot as plt 26 | from models.geomatch import GeoMatch 27 | import numpy as np 28 | import torch 29 | from torch import nn 30 | from torch import optim 31 | from torch.utils.data import DataLoader 32 | from torch.utils.tensorboard import SummaryWriter 33 | from tqdm import tqdm 34 | from utils.gnn_utils import square 35 | from utils.gnn_utils import train_metrics 36 | from utils_data.gnn_dataset import GNNDataset 37 | 38 | 39 | @dataclasses.dataclass 40 | class TrainState: 41 | geomatch_model: GeoMatch 42 | train_step: int 43 | eval_step: int 44 | writer: SummaryWriter 45 | optimizer: optim.Optimizer 46 | epoch: int 47 | 48 | 49 | def train(global_args, state: TrainState, dataloader: DataLoader): 50 | """Full training function.""" 51 | 52 | state.geomatch_model.train() 53 | state.optimizer.zero_grad() 54 | 55 | loss_history = [] 56 | acc_history = [] 57 | 58 | for _, data in enumerate( 59 | tqdm(dataloader, desc=f'EPOCH[{state.epoch}/{global_args.epochs}]') 60 | ): 61 | state.train_step += 1 62 | 63 | ( 64 | obj_adj, 65 | obj_features, 66 | obj_contacts, 67 | robot_adj, 68 | robot_features, 69 | robot_key_point_idx, 70 | robot_contacts, 71 | top_obj_contact_kps, 72 | _, 73 | _, 74 | _, 75 | _, 76 | ) = data 77 | 78 | if global_args.device == 'cuda': 79 | obj_features = obj_features.cuda() 80 | obj_adj = obj_adj.cuda() 81 | obj_contacts = obj_contacts.cuda() 82 | robot_adj = robot_adj.cuda() 83 | robot_features = robot_features.cuda() 84 | robot_key_point_idx = robot_key_point_idx.cuda().long() 85 | robot_contacts = robot_contacts.cuda() 86 | top_obj_contact_kps = 
top_obj_contact_kps.cuda() 87 | 88 | gt_contact_map = ( 89 | (obj_contacts * robot_contacts.repeat(1, 1, config.obj_pc_n)) 90 | .transpose(2, 1)[..., None] 91 | .contiguous() 92 | ) 93 | 94 | contact_map_pred, pred_curr = state.geomatch_model( 95 | obj_features, 96 | robot_features, 97 | robot_key_point_idx, 98 | obj_adj, 99 | robot_adj, 100 | top_obj_contact_kps, 101 | ) 102 | 103 | loss_train = state.geomatch_model.calc_loss( 104 | gt_contact_map, 105 | contact_map_pred, 106 | pred_curr, 107 | gt_contact_map[:, :, 1 : config.keypoint_n, :], 108 | ) 109 | ( 110 | acc, 111 | _, 112 | _, 113 | true_positives, 114 | true_negatives, 115 | false_positives, 116 | false_negatives, 117 | ) = train_metrics(pred_curr, gt_contact_map[:, :, 1 : config.keypoint_n, :]) 118 | 119 | state.optimizer.zero_grad() 120 | loss_train.backward() 121 | 122 | nn.utils.clip_grad_value_(state.geomatch_model.parameters(), clip_value=1.0) 123 | state.optimizer.step() 124 | 125 | loss_history.append(loss_train) 126 | acc_history.append(acc) 127 | 128 | if state.train_step % 10 == 0: 129 | loss = torch.mean(torch.stack(loss_history)) 130 | square(true_positives, false_positives, true_negatives, false_negatives) 131 | precision_recall_square = plt.imread('square.jpg').transpose(2, 0, 1) 132 | 133 | state.writer.add_scalar( 134 | 'train/loss', loss.item(), global_step=state.train_step 135 | ) 136 | state.writer.add_image( 137 | 'train/precision_recall', 138 | precision_recall_square, 139 | global_step=state.train_step, 140 | ) 141 | 142 | plt.close() 143 | 144 | epoch_loss = torch.mean(torch.stack(loss_history)) 145 | epoch_accuracy = torch.mean(torch.stack(acc_history)) 146 | 147 | if state.epoch % 1 == 0: 148 | print( 149 | f'[train] loss on {state.epoch}: {epoch_loss}\n' 150 | f' accuracy: {epoch_accuracy}\n' 151 | ) 152 | 153 | 154 | def validate(global_args, state: TrainState, dataloader: DataLoader): 155 | """Full evaluation function.""" 156 | 157 | with torch.no_grad(): 158 | state.geomatch_model.eval() 159 | 160 | loss_history = [] 161 | acc_history = [] 162 | 163 | for data in tqdm( 164 | dataloader, desc=f'EPOCH[{state.epoch}/{global_args.epochs}]' 165 | ): 166 | state.eval_step += 1 167 | 168 | ( 169 | obj_adj, 170 | obj_features, 171 | obj_contacts, 172 | robot_adj, 173 | robot_features, 174 | robot_key_point_idx, 175 | robot_contacts, 176 | top_obj_contact_kps, 177 | _, 178 | _, 179 | _, 180 | _, 181 | ) = data 182 | 183 | if global_args.device == 'cuda': 184 | obj_features = obj_features.cuda() 185 | obj_adj = obj_adj.cuda() 186 | obj_contacts = obj_contacts.cuda() 187 | robot_adj = robot_adj.cuda() 188 | robot_features = robot_features.cuda() 189 | robot_key_point_idx = robot_key_point_idx.cuda().long() 190 | robot_contacts = robot_contacts.cuda() 191 | top_obj_contact_kps = top_obj_contact_kps.cuda() 192 | 193 | gt_contact_map = ( 194 | (obj_contacts * robot_contacts.repeat(1, 1, config.obj_pc_n)) 195 | .transpose(2, 1)[..., None] 196 | .contiguous() 197 | ) 198 | 199 | contact_map_pred, pred_curr = state.geomatch_model( 200 | obj_features, 201 | robot_features, 202 | robot_key_point_idx, 203 | obj_adj, 204 | robot_adj, 205 | top_obj_contact_kps, 206 | ) 207 | 208 | loss_val = state.geomatch_model.calc_loss( 209 | gt_contact_map, 210 | contact_map_pred, 211 | pred_curr, 212 | gt_contact_map[:, :, 1 : config.keypoint_n, :], 213 | ) 214 | ( 215 | acc, 216 | _, 217 | _, 218 | true_positives, 219 | true_negatives, 220 | false_positives, 221 | false_negatives, 222 | ) = train_metrics( 223 | pred_curr, 
gt_contact_map[:, :, 1 : config.keypoint_n, :]
224 | )
225 |
226 | loss_history.append(loss_val)
227 | acc_history.append(acc)
228 |
229 | if state.eval_step % 10 == 0:
230 | loss = torch.mean(torch.stack(loss_history))
231 | square(true_positives, false_positives, true_negatives, false_negatives)
232 | precision_recall_square = plt.imread('square.jpg').transpose(2, 0, 1)
233 |
234 | state.writer.add_scalar(
235 | 'validate/loss', loss.item(), global_step=state.eval_step
236 | )
237 | state.writer.add_image(
238 | 'validate/precision_recall',
239 | precision_recall_square,
240 | global_step=state.eval_step,
241 | )
242 |
243 | plt.close()
244 |
245 | epoch_loss = torch.mean(torch.stack(loss_history))
246 | epoch_accuracy = torch.mean(torch.stack(acc_history))
247 | print(
248 | f'[validate] loss: {epoch_loss}\n'
249 | f'            accuracy: {epoch_accuracy}\n'
250 | )
251 |
252 |
253 | if __name__ == '__main__':
254 | start_time = time.time()
255 | parser = argparse.ArgumentParser()
256 | parser.add_argument(
257 | '--exp_name', type=str, default='all_end_effectors_object_split'
258 | )
259 | parser.add_argument('--seed', type=int, default=42, help='Random seed.')
260 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
261 | parser.add_argument(
262 | '--device', type=str, default='cuda', help='Use cuda if available'
263 | )
264 | parser.add_argument(
265 | '--epochs', type=int, default=200, help='Number of epochs to train.'
266 | )
267 | parser.add_argument(
268 | '--lr', type=float, default=1e-4, help='Initial learning rate.'
269 | )
270 | parser.add_argument(
271 | '--weight_decay',
272 | type=float,
273 | default=0.0,
274 | help='Weight decay (L2 loss on parameters).',
275 | )
276 | parser.add_argument(
277 | '--out_features',
278 | type=int,
279 | default=512,
280 | help='Object and end-effector feature dimension.',
281 | )
282 | parser.add_argument(
283 | '--robot_weighting',
284 | type=int,
285 | default=500,
286 | help='Weight for full distribution BCE class loss.',
287 | )
288 | parser.add_argument(
289 | '--matchnet_weighting',
290 | type=int,
291 | default=200,
292 | help='Weight for matching BCE class loss.',
293 | )
294 | parser.add_argument(
295 | '--exclude_ee_list',
296 | nargs='+',
297 | help='End-effectors to be excluded from the default list.',
298 | )
299 |
300 | parser.add_argument(
301 | '--dataset_basedir',
302 | type=str,
303 | default='/data/grasp_gnn',
304 | help='Path to data.',
305 | )
306 |
307 | args = parser.parse_args()
308 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
309 |
310 | np.random.seed(args.seed)
311 | torch.manual_seed(args.seed)
312 |
313 | if args.device == 'cuda':
314 | torch.cuda.manual_seed_all(args.seed)
315 |
316 | config.obj_out_feats = args.out_features
317 | config.robot_out_feats = config.obj_out_feats
318 | config.robot_weighting = args.robot_weighting
319 | config.matchnet_weighting = args.matchnet_weighting
320 |
321 | robot_name_list = [
322 | 'ezgripper',
323 | 'barrett',
324 | 'robotiq_3finger',
325 | 'allegro',
326 | 'shadowhand',
327 | ]
328 |
329 | if args.exclude_ee_list is not None:
330 | for ee in args.exclude_ee_list:
331 | robot_name_list.remove(ee)
332 |
333 | log_dir = os.path.join('logs_train', f'exp-{args.exp_name}-{str(start_time)}')
334 | weight_dir = os.path.join(log_dir, 'weights')
335 | tb_dir = os.path.join(log_dir, 'tb_dir')
336 | shutil.rmtree(log_dir, ignore_errors=True)
337 | os.makedirs(log_dir, exist_ok=True)
338 | os.makedirs(weight_dir, exist_ok=True)
339 |
os.makedirs(tb_dir, exist_ok=True) 340 | f = open(os.path.join(log_dir, 'command.txt'), 'w') 341 | f.write(' '.join(sys.argv)) 342 | f.close() 343 | writer = SummaryWriter(log_dir=tb_dir) 344 | 345 | dataset_basedir = args.dataset_basedir 346 | 347 | batchsize = args.batch_size 348 | train_dataset = GNNDataset( 349 | dataset_basedir=dataset_basedir, 350 | mode='train', 351 | device=args.device, 352 | robot_name_list=robot_name_list, 353 | ) 354 | train_dataloader = DataLoader( 355 | dataset=train_dataset, 356 | batch_size=batchsize, 357 | shuffle=True, 358 | num_workers=0, 359 | ) 360 | 361 | validate_dataset = GNNDataset( 362 | dataset_basedir=dataset_basedir, 363 | mode='validate', 364 | device=args.device, 365 | robot_name_list=robot_name_list, 366 | ) 367 | validate_dataloader = DataLoader( 368 | dataset=validate_dataset, 369 | batch_size=batchsize, 370 | shuffle=True, 371 | num_workers=0, 372 | ) 373 | 374 | geomatch_model = GeoMatch(config) 375 | geomatch_model = geomatch_model.to(args.device) 376 | 377 | optimizer = optim.Adam( 378 | list(geomatch_model.parameters()), 379 | lr=args.lr, 380 | weight_decay=args.weight_decay, 381 | betas=(0.9, 0.99), 382 | ) 383 | 384 | torch.save( 385 | geomatch_model.state_dict(), os.path.join(weight_dir, 'grasp_gnn.pth') 386 | ) 387 | 388 | train_step = 0 389 | val_step = 0 390 | 391 | train_state = TrainState( 392 | geomatch_model=geomatch_model, 393 | writer=writer, 394 | optimizer=optimizer, 395 | train_step=train_step, 396 | eval_step=val_step, 397 | epoch=0, 398 | ) 399 | 400 | for i_epoch in range(args.epochs): 401 | train(args, train_state, train_dataloader) 402 | validate(args, train_state, validate_dataloader) 403 | 404 | torch.save( 405 | geomatch_model.state_dict(), os.path.join(weight_dir, 'grasp_gnn.pth') 406 | ) 407 | writer.close() 408 | print(f'consuming time: {time.time() - start_time}') 409 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """General utilities.""" 17 | 18 | import json 19 | import os 20 | import random 21 | import numpy as np 22 | import torch 23 | from utils import gripper_utils 24 | 25 | 26 | def get_handmodel( 27 | robot, batch_size, device, hand_scale=1.0, data_dir='/data/grasp_gnn' 28 | ): 29 | """Fetches the hand model object for a given gripper.""" 30 | urdf_assets_meta = json.load( 31 | open(os.path.join(data_dir, 'urdf/urdf_assets_meta.json')) 32 | ) 33 | urdf_path = urdf_assets_meta['urdf_path'][robot].replace('data', data_dir) 34 | meshes_path = urdf_assets_meta['meshes_path'][robot].replace('data', data_dir) 35 | hand_model = gripper_utils.HandModel( 36 | robot, 37 | urdf_path, 38 | meshes_path, 39 | batch_size=batch_size, 40 | device=device, 41 | hand_scale=hand_scale, 42 | data_dir=data_dir, 43 | ) 44 | return hand_model 45 | 46 | 47 | def set_global_seed(seed=42): 48 | torch.cuda.manual_seed_all(seed) 49 | torch.manual_seed(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | -------------------------------------------------------------------------------- /utils/gnn_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """GNN utilities.""" 17 | 18 | import os 19 | import matplotlib 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import plotly.graph_objects as go 23 | import scipy.sparse as sp 24 | from scipy.spatial import distance 25 | import torch 26 | from torch import nn 27 | import trimesh as tm 28 | 29 | colors = ['blue', 'red', 'yellow', 'pink', 'gray', 'orange'] 30 | 31 | 32 | def plot_mesh(mesh, color='lightblue', opacity=1.0): 33 | return go.Mesh3d( 34 | x=mesh.vertices[:, 0], 35 | y=mesh.vertices[:, 1], 36 | z=mesh.vertices[:, 2], 37 | i=mesh.faces[:, 0], 38 | j=mesh.faces[:, 1], 39 | k=mesh.faces[:, 2], 40 | color=color, 41 | opacity=opacity, 42 | ) 43 | 44 | 45 | def plot_point_cloud( 46 | pts, color='lightblue', mode='markers', colorscale='Viridis', size=3 47 | ): 48 | return go.Scatter3d( 49 | x=pts[:, 0], 50 | y=pts[:, 1], 51 | z=pts[:, 2], 52 | mode=mode, 53 | marker=dict(color=color, colorscale=colorscale, size=size), 54 | ) 55 | 56 | 57 | def plot_grasp( 58 | hand_model, 59 | key_points_idx_dict, 60 | q, 61 | object_name, 62 | obj_pc, 63 | obj_contact_map, 64 | selected_kp_idx, 65 | selected_vert, 66 | object_mesh_basedir='/data/grasp_gnn/object', 67 | ): 68 | """Plots a given grasp.""" 69 | 70 | robot_keypoints_trans = hand_model.get_key_points_from_indices( 71 | key_points_idx_dict, q=q 72 | ) 73 | vis_data = hand_model.get_plotly_data(q=q, opacity=0.5) 74 | object_mesh_path = os.path.join( 75 | object_mesh_basedir, 76 | f'{object_name.split("+")[0]}', 77 | f'{object_name.split("+")[1]}', 78 | f'{object_name.split("+")[1]}.stl', 79 | ) 80 | vis_data += [plot_mesh(mesh=tm.load(object_mesh_path))] 81 | vis_data += [plot_point_cloud(obj_pc, obj_contact_map.squeeze())] 82 | vis_data += [plot_point_cloud(selected_vert, color='red')] 83 | vis_data += [plot_point_cloud(robot_keypoints_trans, color='black')] 84 | vis_data += [ 85 | plot_point_cloud( 86 | robot_keypoints_trans[selected_kp_idx, :][None], color='red' 87 | ) 88 | ] 89 | fig = go.Figure(data=vis_data) 90 | fig.show() 91 | 92 | 93 | def plot_hand_only(hand_model, q, key_points, selected_kp): 94 | vis_data = hand_model.get_plotly_data(q=q, opacity=0.5) 95 | vis_data += [plot_point_cloud(key_points, color='black')] 96 | vis_data += [plot_point_cloud(selected_kp, color='red')] 97 | fig = go.Figure(data=vis_data) 98 | fig.show() 99 | 100 | 101 | def plot_obj_only(obj_pc, obj_contact_map, selected_vert): 102 | vis_data = [plot_point_cloud(obj_pc, obj_contact_map.squeeze())] 103 | vis_data += [plot_point_cloud(selected_vert, color='red')] 104 | fig = go.Figure(data=vis_data) 105 | fig.show() 106 | 107 | 108 | def encode_onehot(labels, threshold=0.5): 109 | if isinstance(labels, np.ndarray): 110 | labels = torch.Tensor(labels) 111 | labels_onehot = torch.where(labels > threshold, 1.0, 0.0) 112 | return labels_onehot 113 | 114 | 115 | def euclidean_min_dist(point, point_cloud): 116 | dist_array = np.linalg.norm( 117 | np.array(point_cloud) - np.array(point).reshape((1, 3)), axis=-1 118 | ) 119 | return np.min(dist_array), np.argsort(dist_array) 120 | 121 | 122 | def generate_contact_maps(contact_map): 123 | labels = contact_map 124 | labels = torch.FloatTensor(encode_onehot(labels)) 125 | return labels 126 | 127 | 128 | def normalize_pc(points, scale_fact): 129 | centroid = torch.mean(points, dim=0) 130 | points -= centroid 131 | points /= scale_fact 132 | 133 | return points 134 | 135 | 136 | def 
generate_adj_mat_feats(point_cloud, knn=8): 137 | """Generates a graph from a given point cloud based on k-NN.""" 138 | 139 | features = sp.csr_matrix(point_cloud, dtype=np.float32) 140 | 141 | # build graph 142 | dist = distance.squareform(distance.pdist(np.asarray(point_cloud))) 143 | closest = np.argsort(dist, axis=1) 144 | adj = np.zeros(closest.shape) 145 | 146 | for i in range(adj.shape[0]): 147 | adj[i, closest[i, 0 : knn + 1]] = 1 148 | 149 | adj = sp.coo_matrix(adj) 150 | 151 | # build symmetric adjacency matrix 152 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 153 | 154 | # features = normalize(features) 155 | adj = normalize(adj + sp.eye(adj.shape[0])) 156 | 157 | features = torch.FloatTensor(np.array(features.todense())) 158 | adj = sparse_mx_to_torch_sparse_tensor(adj) 159 | 160 | return adj, features 161 | 162 | 163 | def normalize(mx): 164 | """Row-normalize sparse matrix.""" 165 | rowsum = np.array(mx.sum(1)) 166 | r_inv = np.power(rowsum, -1).flatten() 167 | r_inv[np.isinf(r_inv)] = 0.0 168 | r_mat_inv = sp.diags(r_inv) 169 | mx = r_mat_inv.dot(mx) 170 | return mx 171 | 172 | 173 | def train_metrics(output, labels, threshold=0.5): 174 | """Generates all training metrics.""" 175 | 176 | batch_size = labels.shape[0] 177 | size = labels.shape[1] * labels.shape[2] 178 | total_num = batch_size * size 179 | preds = nn.Sigmoid()(output) 180 | preds = encode_onehot(preds, threshold=threshold) 181 | 182 | true_positives = (preds * labels).sum() 183 | false_positives = (preds * (1 - labels)).sum() 184 | false_negatives = ((1 - preds) * labels).sum() 185 | true_negatives = ((1 - preds) * (1 - labels)).sum() 186 | 187 | precision = true_positives / ((true_positives + false_positives) + 1e-6) 188 | recall = true_positives / (true_positives + false_negatives) 189 | acc = (true_negatives + true_positives) / total_num 190 | return ( 191 | acc, 192 | precision, 193 | recall, 194 | true_positives, 195 | true_negatives, 196 | false_positives, 197 | false_negatives, 198 | ) 199 | 200 | 201 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 202 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 203 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 204 | indices = torch.from_numpy( 205 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 206 | ) 207 | values = torch.from_numpy(sparse_mx.data) 208 | shape = torch.Size(sparse_mx.shape) 209 | return torch.sparse.FloatTensor(indices, values, shape) 210 | 211 | 212 | def graph_cross_convolution(inp, kernel, inp_adj, krn_adj): 213 | """Experimental: Graph Cross Convolution.""" 214 | kernel = torch.matmul(krn_adj.to_dense(), kernel) 215 | support = torch.matmul(inp, kernel.transpose(-2, -1)) 216 | output = torch.matmul(inp_adj.to_dense(), support) 217 | return output 218 | 219 | 220 | def matplotlib_imshow(img, one_channel=False): 221 | if one_channel: 222 | img = img.mean(dim=0) 223 | img = img / 2 + 0.5 # unnormalize 224 | npimg = img.numpy() 225 | if one_channel: 226 | plt.imshow(npimg, cmap='Greys') 227 | else: 228 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 229 | 230 | 231 | def square(true_positives, false_positives, true_negatives, false_negatives): 232 | """Defines a precision/recall square for Tensorboard plotting.""" 233 | 234 | fig = plt.figure(figsize=(9, 9)) 235 | ax = fig.add_subplot(111) 236 | rect1 = matplotlib.patches.Rectangle((0, 2), 2, 2, color='green') 237 | rect2 = matplotlib.patches.Rectangle((2, 2), 2, 2, color='red') 238 | rect3 = matplotlib.patches.Rectangle((0, 0), 2, 2, 
color='red') 239 | rect4 = matplotlib.patches.Rectangle((2, 0), 2, 2, color='green') 240 | 241 | ax.add_patch(rect1) 242 | ax.add_patch(rect2) 243 | ax.add_patch(rect3) 244 | ax.add_patch(rect4) 245 | rectangles = [rect1, rect2, rect3, rect4] 246 | tags = [ 247 | 'True Positives=' + str(true_positives.item()), 248 | 'False Positives=' + str(false_positives.item()), 249 | 'True Negatives=' + str(true_negatives.item()), 250 | 'False Negatives=' + str(false_negatives.item()), 251 | ] 252 | 253 | for r in range(4): 254 | ax.add_artist(rectangles[r]) 255 | rx, ry = rectangles[r].get_xy() 256 | cx = rx + rectangles[r].get_width() / 2.0 257 | cy = ry + rectangles[r].get_height() / 2.0 258 | 259 | ax.annotate( 260 | tags[r], 261 | (cx, cy), 262 | color='w', 263 | weight='bold', 264 | fontsize=12, 265 | ha='center', 266 | va='center', 267 | ) 268 | 269 | plt.xlim([0, 4]) 270 | plt.ylim([0, 4]) 271 | plt.savefig('square.jpg') 272 | -------------------------------------------------------------------------------- /utils/gripper_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities to represent an end-effector.""" 17 | 18 | import json 19 | import os 20 | import numpy as np 21 | from plotly import graph_objects as go 22 | import pytorch_kinematics as pk 23 | from pytorch_kinematics.urdf_parser_py.urdf import Box 24 | from pytorch_kinematics.urdf_parser_py.urdf import Cylinder 25 | from pytorch_kinematics.urdf_parser_py.urdf import Mesh 26 | from pytorch_kinematics.urdf_parser_py.urdf import Sphere 27 | from pytorch_kinematics.urdf_parser_py.urdf import URDF 28 | import torch 29 | import torch.nn 30 | import transforms3d 31 | import trimesh as tm 32 | import trimesh.sample 33 | import urdf_parser_py.urdf as URDF_PARSER 34 | from utils import math_utils 35 | 36 | 37 | class HandModel: 38 | """Hand model class based on: https://github.com/tengyu-liu/GenDexGrasp/blob/main/utils_model/HandModel.py.""" 39 | 40 | def __init__( 41 | self, 42 | robot_name, 43 | urdf_filename, 44 | mesh_path, 45 | batch_size=1, 46 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 47 | hand_scale=2.0, 48 | data_dir='data', 49 | ): 50 | self.device = device 51 | self.batch_size = batch_size 52 | self.data_dir = data_dir 53 | 54 | self.robot = pk.build_chain_from_urdf(open(urdf_filename).read()).to( 55 | dtype=torch.float, device=self.device 56 | ) 57 | self.robot_full = URDF_PARSER.URDF.from_xml_file(urdf_filename) 58 | 59 | if robot_name == 'allegro_right': 60 | self.robot_name = 'allegro_right' 61 | robot_name = 'allegro' 62 | else: 63 | self.robot_name = robot_name 64 | 65 | self.global_translation = None 66 | self.global_rotation = None 67 | self.softmax = torch.nn.Softmax(dim=-1) 68 | 69 | self.contact_point_dict = json.load( 70 
| open(os.path.join(self.data_dir, 'urdf/contact_%s.json' % robot_name)) 71 | ) 72 | self.contact_point_basis = {} 73 | self.contact_normals = {} 74 | self.surface_points = {} 75 | self.surface_points_normal = {} 76 | visual = URDF.from_xml_string(open(urdf_filename).read()) 77 | self.centroids = json.load( 78 | open(os.path.join(self.data_dir, 'robot_centroids.json')) 79 | )[robot_name] 80 | self.keypoints = json.load( 81 | open(os.path.join(self.data_dir, 'robot_keypoints.json')) 82 | )[self.robot_name] 83 | self.key_point_idx_dict = {} 84 | self.mesh_verts = {} 85 | self.mesh_faces = {} 86 | 87 | self.canon_verts = [] 88 | self.canon_faces = [] 89 | self.idx_vert_faces = [] 90 | self.face_normals = [] 91 | self.links = [link.name for link in visual.links] 92 | 93 | for i_link, link in enumerate(visual.links): 94 | print(f'Processing link #{i_link}: {link.name}') 95 | 96 | if not link.visuals: 97 | continue 98 | if isinstance(link.visuals[0].geometry, Mesh): 99 | if ( 100 | robot_name == 'shadowhand' 101 | or robot_name == 'allegro' 102 | or robot_name == 'barrett' 103 | ): 104 | filename = link.visuals[0].geometry.filename.split('/')[-1] 105 | elif robot_name == 'allegro': 106 | filename = f"{link.visuals[0].geometry.filename.split('/')[-2]}/{link.visuals[0].geometry.filename.split('/')[-1]}" 107 | else: 108 | filename = link.visuals[0].geometry.filename 109 | mesh = tm.load( 110 | os.path.join(mesh_path, filename), force='mesh', process=False 111 | ) 112 | elif isinstance(link.visuals[0].geometry, Cylinder): 113 | mesh = tm.primitives.Cylinder( 114 | radius=link.visuals[0].geometry.radius, 115 | height=link.visuals[0].geometry.length, 116 | ) 117 | elif isinstance(link.visuals[0].geometry, Box): 118 | mesh = tm.primitives.Box(extents=link.visuals[0].geometry.size) 119 | elif isinstance(link.visuals[0].geometry, Sphere): 120 | mesh = tm.primitives.Sphere(radius=link.visuals[0].geometry.radius) 121 | else: 122 | print(type(link.visuals[0].geometry)) 123 | raise NotImplementedError 124 | try: 125 | scale = np.array(link.visuals[0].geometry.scale).reshape([1, 3]) 126 | except Exception: # pylint: disable=broad-exception-caught 127 | scale = np.array([[1, 1, 1]]) 128 | try: 129 | rotation = transforms3d.euler.euler2mat(*link.visuals[0].origin.rpy) 130 | translation = np.reshape(link.visuals[0].origin.xyz, [1, 3]) 131 | 132 | except Exception: # pylint: disable=broad-exception-caught 133 | rotation = transforms3d.euler.euler2mat(0, 0, 0) 134 | translation = np.array([[0, 0, 0]]) 135 | 136 | if self.robot_name == 'shadowhand': 137 | pts, pts_face_index = trimesh.sample.sample_surface(mesh=mesh, count=64) 138 | pts_normal = np.array( 139 | [mesh.face_normals[x] for x in pts_face_index], dtype=float 140 | ) 141 | else: 142 | pts, pts_face_index = trimesh.sample.sample_surface( 143 | mesh=mesh, count=128 144 | ) 145 | pts_normal = np.array( 146 | [mesh.face_normals[x] for x in pts_face_index], dtype=float 147 | ) 148 | 149 | pts *= scale 150 | if robot_name == 'shadowhand': 151 | pts = pts[:, [0, 2, 1]] 152 | pts_normal = pts_normal[:, [0, 2, 1]] 153 | pts[:, 1] *= -1 154 | pts_normal[:, 1] *= -1 155 | 156 | pts = np.matmul(rotation, pts.T).T + translation 157 | pts = np.concatenate([pts, np.ones([len(pts), 1])], axis=-1) 158 | pts_normal = np.concatenate( 159 | [pts_normal, np.ones([len(pts_normal), 1])], axis=-1 160 | ) 161 | self.surface_points[link.name] = ( 162 | torch.from_numpy(pts) 163 | .to(device) 164 | .float() 165 | .unsqueeze(0) 166 | .repeat(batch_size, 1, 1) 167 | ) 168 | 
self.surface_points_normal[link.name] = ( 169 | torch.from_numpy(pts_normal) 170 | .to(device) 171 | .float() 172 | .unsqueeze(0) 173 | .repeat(batch_size, 1, 1) 174 | ) 175 | 176 | # visualization mesh 177 | self.mesh_verts[link.name] = np.array(mesh.vertices) * scale 178 | if robot_name == 'shadowhand': 179 | self.mesh_verts[link.name] = self.mesh_verts[link.name][:, [0, 2, 1]] 180 | self.mesh_verts[link.name][:, 1] *= -1 181 | self.mesh_verts[link.name] = ( 182 | np.matmul(rotation, self.mesh_verts[link.name].T).T + translation 183 | ) 184 | self.mesh_faces[link.name] = np.array(mesh.faces) 185 | 186 | # contact point 187 | if link.name in self.contact_point_dict: 188 | cpb = np.array(self.contact_point_dict[link.name]) 189 | if len(cpb.shape) > 1: 190 | cpb = cpb[np.random.randint(cpb.shape[0], size=1)][0] 191 | 192 | cp_basis = mesh.vertices[cpb] * scale 193 | if robot_name == 'shadowhand': 194 | cp_basis = cp_basis[:, [0, 2, 1]] 195 | cp_basis[:, 1] *= -1 196 | cp_basis = np.matmul(rotation, cp_basis.T).T + translation 197 | cp_basis = torch.cat( 198 | [ 199 | torch.from_numpy(cp_basis).to(device).float(), 200 | torch.ones([4, 1]).to(device).float(), 201 | ], 202 | dim=-1, 203 | ) 204 | self.contact_point_basis[link.name] = cp_basis.unsqueeze(0).repeat( 205 | batch_size, 1, 1 206 | ) 207 | v1 = cp_basis[1, :3] - cp_basis[0, :3] 208 | v2 = cp_basis[2, :3] - cp_basis[0, :3] 209 | v1 = v1 / torch.norm(v1) 210 | v2 = v2 / torch.norm(v2) 211 | self.contact_normals[link.name] = torch.cross(v1, v2).view([1, 3]) 212 | self.contact_normals[link.name] = ( 213 | self.contact_normals[link.name] 214 | .unsqueeze(0) 215 | .repeat(batch_size, 1, 1) 216 | ) 217 | 218 | self.scale = hand_scale 219 | 220 | # new 2.1 221 | self.revolute_joints = [] 222 | for i, _ in enumerate(self.robot_full.joints): 223 | if self.robot_full.joints[i].joint_type == 'revolute': 224 | self.revolute_joints.append(self.robot_full.joints[i]) 225 | self.revolute_joints_q_mid = [] 226 | self.revolute_joints_q_var = [] 227 | self.revolute_joints_q_upper = [] 228 | self.revolute_joints_q_lower = [] 229 | for i, _ in enumerate(self.robot.get_joint_parameter_names()): 230 | for j, _ in enumerate(self.revolute_joints): 231 | if ( 232 | self.revolute_joints[j].name 233 | == self.robot.get_joint_parameter_names()[i] 234 | ): 235 | joint = self.revolute_joints[j] 236 | assert joint.name == self.robot.get_joint_parameter_names()[i] 237 | self.revolute_joints_q_mid.append( 238 | (joint.limit.lower + joint.limit.upper) / 2 239 | ) 240 | self.revolute_joints_q_var.append( 241 | ((joint.limit.upper - joint.limit.lower) / 2) ** 2 242 | ) 243 | self.revolute_joints_q_lower.append(joint.limit.lower) 244 | self.revolute_joints_q_upper.append(joint.limit.upper) 245 | 246 | joint_lower = np.array(self.revolute_joints_q_lower) 247 | joint_upper = np.array(self.revolute_joints_q_upper) 248 | joint_mid = (joint_lower + joint_upper) / 2 249 | joints_q = (joint_mid + joint_lower) / 2 250 | self.rest_pose = ( 251 | torch.from_numpy( 252 | np.concatenate([np.array([0, 0, 0, 1, 0, 0, 0, 1, 0]), joints_q]) 253 | ) 254 | .unsqueeze(0) 255 | .to(device) 256 | .float() 257 | ) 258 | 259 | self.revolute_joints_q_lower = ( 260 | torch.Tensor(self.revolute_joints_q_lower) 261 | .repeat([self.batch_size, 1]) 262 | .to(device) 263 | ) 264 | self.revolute_joints_q_upper = ( 265 | torch.Tensor(self.revolute_joints_q_upper) 266 | .repeat([self.batch_size, 1]) 267 | .to(device) 268 | ) 269 | 270 | self.rest_pose = self.rest_pose.repeat([self.batch_size, 1]) 271 | 
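# Layout of the pose vector q (see update_kinematics below): q[:, :3] is the
# global translation, q[:, 3:9] an ortho6d rotation, and q[:, 9:] the revolute
# joint angles. rest_pose above uses the same layout with an identity rotation
# and each joint set halfway between its lower limit and mid-range.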
272 | self.current_status = None 273 | self.canonical_keypoints = self.get_canonical_keypoints().to(device) 274 | 275 | def update_kinematics(self, q): 276 | self.global_translation = q[:, :3] 277 | 278 | self.global_rotation = ( 279 | math_utils.robust_compute_rotation_matrix_from_ortho6d(q[:, 3:9]) 280 | ) 281 | self.current_status = self.robot.forward_kinematics(q[:, 9:]) 282 | 283 | def get_surface_points(self, q=None, downsample=False): 284 | """Returns surface points on the end-effector on a given pose.""" 285 | 286 | if q is not None: 287 | self.update_kinematics(q) 288 | surface_points = [] 289 | for link_name in self.surface_points: 290 | if self.robot_name == 'robotiq_3finger' and link_name == 'gripper_palm': 291 | continue 292 | if ( 293 | self.robot_name == 'robotiq_3finger_real_robot' 294 | and link_name == 'palm' 295 | ): 296 | continue 297 | trans_matrix = self.current_status[link_name].get_matrix() 298 | surface_points.append( 299 | torch.matmul( 300 | trans_matrix, self.surface_points[link_name].transpose(1, 2) 301 | ).transpose(1, 2)[..., :3] 302 | ) 303 | surface_points = torch.cat(surface_points, 1) 304 | surface_points = torch.matmul( 305 | self.global_rotation, surface_points.transpose(1, 2) 306 | ).transpose(1, 2) + self.global_translation.unsqueeze(1) 307 | if downsample: 308 | surface_points = surface_points[ 309 | :, torch.randperm(surface_points.shape[1]) 310 | ][:, :1000] 311 | return surface_points * self.scale 312 | 313 | def get_canonical_keypoints(self): 314 | """Returns canonical keypoints aka the N user-selected keypoints.""" 315 | 316 | self.update_kinematics(self.rest_pose) 317 | key_points = np.array([ 318 | np.array(keypoint[str(i)][0]) 319 | for i, keypoint in enumerate(self.keypoints) 320 | ]) 321 | key_points = torch.tensor(key_points).unsqueeze(0).float().to(self.device) 322 | key_points = key_points.repeat(self.batch_size, 1, 1) 323 | key_points -= self.global_translation.unsqueeze(1) 324 | key_points = torch.matmul( 325 | torch.inverse(self.global_rotation), key_points.transpose(1, 2) 326 | ).transpose(1, 2) 327 | new_key_points = [] 328 | 329 | for i, keypoint in enumerate(self.keypoints): 330 | curr_keypoint = key_points[0, i, :] 331 | curr_keypoint_link_name = keypoint[str(i)][1] 332 | curr_keypoint = torch.cat( 333 | (curr_keypoint.clone().detach(), torch.tensor([1.0]).to(self.device)) 334 | ).float() 335 | 336 | trans_matrix = self.current_status[curr_keypoint_link_name].get_matrix() 337 | # Address batch size if present. 
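# If the transform comes back without the full batch dimension, tile it so the
# inverse transform below maps the rest-pose keypoint into the link's local
# frame for every element of the batch.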
338 | if trans_matrix.shape[0] != self.batch_size: 339 | trans_matrix = trans_matrix.repeat(self.batch_size, 1, 1) 340 | 341 | self.key_point_idx_dict[curr_keypoint_link_name] = [] 342 | new_key_points.append( 343 | torch.matmul( 344 | torch.inverse(trans_matrix), 345 | curr_keypoint[None, None].transpose(1, 2), 346 | ).transpose(1, 2)[..., :3] 347 | ) 348 | new_key_points = torch.cat(new_key_points, 1) 349 | 350 | return new_key_points 351 | 352 | def get_static_key_points(self, q, surface_pt_sample=None): 353 | """Returns the canonical keypoints when in a given end-effector pose.""" 354 | 355 | final_key_points = [] 356 | final_key_points_idx = [] 357 | 358 | self.update_kinematics(q) 359 | 360 | for i in range(self.canonical_keypoints.shape[1]): 361 | curr_keypoint = self.canonical_keypoints[0, i, :] 362 | curr_keypoint_link_name = self.keypoints[i][str(i)][1] 363 | curr_keypoint = torch.cat( 364 | (curr_keypoint.clone().detach(), torch.tensor([1.0]).to(self.device)) 365 | ).float() 366 | 367 | if surface_pt_sample is not None: 368 | self.key_point_idx_dict[curr_keypoint_link_name] = [] 369 | 370 | trans_matrix = self.current_status[curr_keypoint_link_name].get_matrix() 371 | final_key_points.append( 372 | torch.matmul( 373 | trans_matrix, curr_keypoint[None, None].transpose(1, 2) 374 | ).transpose(1, 2)[..., :3] 375 | ) 376 | 377 | final_key_points = torch.cat(final_key_points, 1) 378 | final_key_points = torch.matmul( 379 | self.global_rotation, final_key_points.transpose(1, 2) 380 | ).transpose(1, 2) + self.global_translation.unsqueeze(1) 381 | final_key_points = (final_key_points * self.scale).squeeze(0) 382 | 383 | if surface_pt_sample is not None: 384 | for i, final_kp in enumerate(final_key_points): 385 | curr_keypoint_link_name = self.keypoints[i][str(i)][1] 386 | closest_vert_idx = np.argsort( 387 | np.linalg.norm( 388 | np.array(self.mesh_verts[curr_keypoint_link_name]) 389 | - np.array(final_kp).reshape((1, 3)), 390 | axis=-1, 391 | ) 392 | )[0] 393 | self.key_point_idx_dict[curr_keypoint_link_name].append( 394 | torch.tensor(closest_vert_idx) 395 | ) 396 | closest_surface_sample_idx = np.argsort( 397 | np.linalg.norm( 398 | np.array(surface_pt_sample) 399 | - np.array(final_kp).reshape((1, 3)), 400 | axis=-1, 401 | ) 402 | )[0] 403 | final_key_points_idx.append(closest_surface_sample_idx) 404 | 405 | for k in self.key_point_idx_dict: 406 | self.key_point_idx_dict[k] = torch.tensor(self.key_point_idx_dict[k]) 407 | 408 | return ( 409 | final_key_points, 410 | self.key_point_idx_dict, 411 | torch.Tensor(final_key_points_idx), 412 | ) 413 | 414 | def get_key_points_from_indices(self, key_point_idx_dict, q=None): 415 | """Returns keypoints from a set of indices when in a given pose.""" 416 | 417 | if q is not None: 418 | self.update_kinematics(q) 419 | 420 | key_points = [] 421 | for link_name in key_point_idx_dict: 422 | trans_matrix = self.current_status[link_name].get_matrix() 423 | pts = np.concatenate( 424 | [ 425 | self.mesh_verts[link_name], 426 | np.ones([len(self.mesh_verts[link_name]), 1]), 427 | ], 428 | axis=-1, 429 | ) 430 | surface_points = ( 431 | torch.from_numpy(pts).float().unsqueeze(0).repeat(1, 1, 1) 432 | ) 433 | 434 | key_point_idx = key_point_idx_dict[link_name] 435 | key_points.append( 436 | torch.matmul( 437 | trans_matrix, 438 | surface_points[:, key_point_idx.long(), :].transpose(1, 2), 439 | ).transpose(1, 2) 440 | ) 441 | 442 | key_points = torch.cat(key_points, dim=1)[..., :3] 443 | key_points = torch.matmul( 444 | self.global_rotation, 
key_points.transpose(1, 2) 445 | ).transpose(1, 2) + self.global_translation.unsqueeze(1) 446 | return (key_points * self.scale).squeeze(0) 447 | 448 | def get_meshes_from_q(self, q=None, i=0): 449 | """Returns gripper meshes in a given pose.""" 450 | 451 | data = [] 452 | if q is not None: 453 | self.update_kinematics(q) 454 | for _, link_name in enumerate(self.mesh_verts): 455 | trans_matrix = self.current_status[link_name].get_matrix() 456 | trans_matrix = ( 457 | trans_matrix[min(len(trans_matrix) - 1, i)].detach().cpu().numpy() 458 | ) 459 | v = self.mesh_verts[link_name] 460 | transformed_v = np.concatenate([v, np.ones([len(v), 1])], axis=-1) 461 | transformed_v = np.matmul(trans_matrix, transformed_v.T).T[..., :3] 462 | transformed_v = np.matmul( 463 | self.global_rotation[i].detach().cpu().numpy(), transformed_v.T 464 | ).T + np.expand_dims(self.global_translation[i].detach().cpu().numpy(), 0) 465 | transformed_v = transformed_v * self.scale 466 | f = self.mesh_faces[link_name] 467 | data.append(tm.Trimesh(vertices=transformed_v, faces=f)) 468 | return data 469 | 470 | def get_plotly_data(self, q=None, i=0, color='lightblue', opacity=1.0): 471 | """Returns plot data for the gripper in a given pose.""" 472 | 473 | data = [] 474 | if q is not None: 475 | self.update_kinematics(q) 476 | for _, link_name in enumerate(self.mesh_verts): 477 | trans_matrix = self.current_status[link_name].get_matrix() 478 | trans_matrix = ( 479 | trans_matrix[min(len(trans_matrix) - 1, i)].detach().cpu().numpy() 480 | ) 481 | v = self.mesh_verts[link_name] 482 | transformed_v = np.concatenate([v, np.ones([len(v), 1])], axis=-1) 483 | transformed_v = np.matmul(trans_matrix, transformed_v.T).T[..., :3] 484 | transformed_v = np.matmul( 485 | self.global_rotation[i].detach().cpu().numpy(), transformed_v.T 486 | ).T + np.expand_dims(self.global_translation[i].detach().cpu().numpy(), 0) 487 | transformed_v = transformed_v * self.scale 488 | f = self.mesh_faces[link_name] 489 | data.append( 490 | go.Mesh3d( 491 | x=transformed_v[:, 0], 492 | y=transformed_v[:, 1], 493 | z=transformed_v[:, 2], 494 | i=f[:, 0], 495 | j=f[:, 1], 496 | k=f[:, 2], 497 | color=color, 498 | opacity=opacity, 499 | ) 500 | ) 501 | return data 502 | -------------------------------------------------------------------------------- /utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Math utilities for rotations and vector algebra.""" 17 | 18 | import numpy as np 19 | import torch 20 | import transforms3d 21 | 22 | 23 | def get_rot6d_from_rot3d(rot3d): 24 | global_rotation = np.array( 25 | transforms3d.euler.euler2mat(rot3d[0], rot3d[1], rot3d[2]) 26 | ) 27 | return global_rotation.T.reshape(9)[:6] 28 | 29 | 30 | def robust_compute_rotation_matrix_from_ortho6d(poses): 31 | """TODO(jmattarian): Code from: XXXXXXX.""" 32 | 33 | x_raw = poses[:, 0:3] 34 | y_raw = poses[:, 3:6] 35 | 36 | x = normalize_vector(x_raw) 37 | y = normalize_vector(y_raw) 38 | middle = normalize_vector(x + y) 39 | orthmid = normalize_vector(x - y) 40 | x = normalize_vector(middle + orthmid) 41 | y = normalize_vector(middle - orthmid) 42 | z = normalize_vector(cross_product(x, y)) 43 | 44 | x = x.view(-1, 3, 1) 45 | y = y.view(-1, 3, 1) 46 | z = z.view(-1, 3, 1) 47 | matrix = torch.cat((x, y, z), 2) 48 | return matrix 49 | 50 | 51 | def normalize_vector(v): 52 | batch = v.shape[0] 53 | v_mag = torch.sqrt(v.pow(2).sum(1)) # batch 54 | v_mag = torch.max(v_mag, v.new([1e-8])) 55 | v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) 56 | v = v / v_mag 57 | return v 58 | 59 | 60 | def cross_product(u, v): 61 | batch = u.shape[0] 62 | i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] 63 | j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] 64 | k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] 65 | 66 | out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) 67 | 68 | return out 69 | -------------------------------------------------------------------------------- /utils_data/augmentors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Augmentation classes for point clouds.""" 17 | 18 | import numpy as np 19 | import torch 20 | 21 | 22 | def angle_axis(angle: float, axis: np.ndarray): 23 | """Returns a 4x4 rotation matrix that performs a rotation around axis by angle.""" 24 | 25 | u = axis / np.linalg.norm(axis) 26 | cosval, sinval = np.cos(angle), np.sin(angle) 27 | 28 | cross_prod_mat = np.array( 29 | [[0.0, -u[2], u[1]], [u[2], 0.0, -u[0]], [-u[1], u[0], 0.0]] 30 | ) 31 | 32 | r = torch.from_numpy( 33 | cosval * np.eye(3) 34 | + sinval * cross_prod_mat 35 | + (1.0 - cosval) * np.outer(u, u) 36 | ) 37 | return r.float() 38 | 39 | 40 | class PointcloudJitter(object): 41 | """Adds jitter with a given std to a point cloud.""" 42 | 43 | def __init__(self, std=0.005, clip=0.025): 44 | self.std, self.clip = std, clip 45 | 46 | def __call__(self, points): 47 | jittered_data = ( 48 | points.new(points.size(0), 3) 49 | .normal_(mean=0.0, std=self.std) 50 | .clamp_(-self.clip, self.clip) 51 | ) 52 | points[:, 0:3] += jittered_data 53 | return points 54 | 55 | 56 | class PointcloudScale(object): 57 | 58 | def __init__(self, lo=0.8, hi=1.25): 59 | self.lo, self.hi = lo, hi 60 | 61 | def __call__(self, points): 62 | scaler = np.random.uniform(self.lo, self.hi) 63 | points[:, 0:3] *= scaler 64 | return points 65 | 66 | 67 | class PointcloudTranslate(object): 68 | 69 | def __init__(self, translate_range=0.1): 70 | self.translate_range = translate_range 71 | 72 | def __call__(self, points): 73 | translation = np.random.uniform(-self.translate_range, self.translate_range) 74 | points[:, 0:3] += translation 75 | return points 76 | 77 | 78 | class PointcloudRotatePerturbation(object): 79 | """Applies a random rotation to a point cloud.""" 80 | 81 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 82 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip 83 | 84 | def _get_angles(self): 85 | angles = np.clip( 86 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip 87 | ) 88 | 89 | return angles 90 | 91 | def __call__(self, points): 92 | angles = self._get_angles() 93 | rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0])) 94 | ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0])) 95 | rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0])) 96 | 97 | rotation_matrix = torch.matmul(torch.matmul(rz, ry), rx) 98 | 99 | normals = points.size(1) > 3 100 | if not normals: 101 | return torch.matmul(points, rotation_matrix.t()) 102 | else: 103 | pc_xyz = points[:, 0:3] 104 | pc_normals = points[:, 3:] 105 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 106 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 107 | 108 | return points 109 | -------------------------------------------------------------------------------- /utils_data/gnn_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implements Pytorch dataloader class needed for GeoMatch.""" 17 | 18 | import json 19 | import os 20 | import torch 21 | from torch.utils import data 22 | from utils_data import augmentors 23 | 24 | 25 | class GNNDataset(data.Dataset): 26 | """Dataloader class for GeoMatch.""" 27 | 28 | def __init__( 29 | self, 30 | dataset_basedir, 31 | object_npts=2048, 32 | device='cuda' if torch.cuda.is_available() else 'cpu', 33 | mode='train', 34 | robot_name_list=None, 35 | ): 36 | self.device = device 37 | self.dataset_basedir = dataset_basedir 38 | self.object_npts = object_npts 39 | 40 | if not robot_name_list: 41 | self.robot_name_list = [ 42 | 'ezgripper', 43 | 'barrett', 44 | 'robotiq_3finger', 45 | 'allegro', 46 | 'shadowhand', 47 | ] 48 | else: 49 | self.robot_name_list = robot_name_list 50 | 51 | print('loading object point clouds and adjacency matrices....') 52 | self.object_pc_adj = torch.load( 53 | os.path.join(dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt') 54 | ) 55 | 56 | print('loading robot point clouds and adjacency matrices...') 57 | self.robot_pc_adj = torch.load( 58 | os.path.join(dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt') 59 | ) 60 | 61 | print('loading object/robot cmaps....') 62 | cmap_dataset = torch.load( 63 | os.path.join(dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt') 64 | )['metadata'] 65 | 66 | self.metadata = cmap_dataset 67 | 68 | if mode == 'train': 69 | self.object_list = json.load( 70 | open( 71 | os.path.join( 72 | dataset_basedir, 73 | 'CMapDataset-sqrt_align/split_train_validate_objects.json', 74 | ), 75 | 'rb', 76 | ) 77 | )[mode] 78 | self.metadata = [ 79 | t 80 | for t in self.metadata 81 | if t[6] in self.object_list and t[7] in self.robot_name_list 82 | ] 83 | elif mode == 'validate': 84 | self.object_list = json.load( 85 | open( 86 | os.path.join( 87 | dataset_basedir, 88 | 'CMapDataset-sqrt_align/split_train_validate_objects.json', 89 | ), 90 | 'rb', 91 | ) 92 | )[mode] 93 | self.metadata = [ 94 | t 95 | for t in self.metadata 96 | if t[6] in self.object_list and t[7] in self.robot_name_list 97 | ] 98 | elif mode == 'full': 99 | self.object_list = ( 100 | json.load( 101 | open( 102 | os.path.join( 103 | dataset_basedir, 104 | 'CMapDataset-sqrt_align/split_train_validate_objects.json', 105 | ), 106 | 'rb', 107 | ) 108 | )['train'] 109 | + json.load( 110 | open( 111 | os.path.join( 112 | dataset_basedir, 'split_train_validate_objects.json' 113 | ), 114 | 'rb', 115 | ) 116 | )['validate'] 117 | ) 118 | self.metadata = [ 119 | t 120 | for t in self.metadata 121 | if t[6] in self.object_list and t[7] in self.robot_name_list 122 | ] 123 | else: 124 | raise NotImplementedError() 125 | print(f'object selection: {self.object_list}') 126 | 127 | self.mode = mode 128 | 129 | self.datasize = len(self.metadata) 130 | print('finish loading dataset....') 131 | 132 | self.rotate = augmentors.PointcloudRotatePerturbation( 133 | angle_sigma=0.03, angle_clip=0.1 134 | ) 135 | self.translate = augmentors.PointcloudTranslate(translate_range=0.01) 136 | self.jitter = augmentors.PointcloudJitter(std=0.04, clip=0.1) 137 | 138 | def __len__(self): 139 | return self.datasize 140 | 141 | def __getitem__(self, item): 142 | object_name = self.metadata[item][6] 143 | robot_name = self.metadata[item][7] 144 | 145 | obj_adj = 
self.object_pc_adj[object_name][0] 146 | obj_contacts = self.metadata[item][0] 147 | obj_features = self.object_pc_adj[object_name][1] 148 | 149 | if self.mode in ['train']: 150 | obj_features = self.rotate(obj_features) 151 | 152 | robot_adj = self.robot_pc_adj[robot_name][0] 153 | robot_features = self.robot_pc_adj[robot_name][1] 154 | robot_key_point_idx = self.robot_pc_adj[robot_name][3] 155 | assert robot_key_point_idx.shape[0] == 6 156 | robot_contacts = self.metadata[item][1] 157 | top_obj_contact_kps = self.metadata[item][2] 158 | assert top_obj_contact_kps.shape[0] == 6 159 | top_obj_contact_verts = self.metadata[item][3] 160 | assert top_obj_contact_verts.shape[0] == 6 161 | full_obj_contact_map = self.metadata[item][4] 162 | 163 | if self.mode in ['train']: 164 | robot_features = self.rotate(robot_features) 165 | 166 | return ( 167 | obj_adj, 168 | obj_features, 169 | obj_contacts, 170 | robot_adj, 171 | robot_features, 172 | robot_key_point_idx, 173 | robot_contacts, 174 | top_obj_contact_kps, 175 | top_obj_contact_verts, 176 | full_obj_contact_map, 177 | object_name, 178 | robot_name, 179 | ) 180 | -------------------------------------------------------------------------------- /visualize_geomatch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 |
16 | """Visualization helper for GeoMatch predictions."""
17 |
18 | import argparse
19 | import itertools
20 | import json
21 | import os
22 | import random
23 | import config
24 | from models.geomatch import GeoMatch
25 | from plotly.subplots import make_subplots
26 | import torch
27 | from torch import nn
28 | from utils.general_utils import get_handmodel
29 | from utils.gnn_utils import plot_point_cloud
30 |
31 |
32 | def return_random_obj_ee_pair(
33 | obj_name, rbt_name, obj_data, rbt_data, contact_map_data_gt
34 | ):
35 | """Returns a random object-gripper pair to use for visualization."""
36 |
37 | obj_adj = obj_data[obj_name][0]
38 | obj_pc = obj_data[obj_name][1]
39 | robot_adj = rbt_data[rbt_name][0]
40 | robot_pc = rbt_data[rbt_name][1]
41 | rest_pose = rbt_data[rbt_name][2]
42 | keypoints_idx = rbt_data[rbt_name][3]
43 | keypoints_idx_dict = rbt_data[rbt_name][4]
44 |
45 | found = False
46 | idx_list = []
47 | for i, data in enumerate(contact_map_data_gt['metadata']):
48 | if data[6] == obj_name and data[7] == rbt_name:
49 | idx_list.append(i)
50 | found = True
51 | break
52 |
53 | if not found:
54 | raise ModuleNotFoundError('Did not find a matching combination, try again!')
55 |
56 | rand_idx = random.choice(idx_list)
57 | obj_cmap = contact_map_data_gt['metadata'][rand_idx][0]
58 | robot_cmap = contact_map_data_gt['metadata'][rand_idx][1]
59 | top_obj_contact_kps = contact_map_data_gt['metadata'][rand_idx][2]
60 | top_obj_contact_verts = contact_map_data_gt['metadata'][rand_idx][3]
61 | q = contact_map_data_gt['metadata'][rand_idx][4]
62 |
63 | return (
64 | obj_adj,
65 | obj_pc,
66 | robot_adj,
67 | robot_pc,
68 | rest_pose,
69 | keypoints_idx,
70 | obj_cmap,
71 | robot_cmap,
72 | top_obj_contact_kps,
73 | top_obj_contact_verts,
74 | q,
75 | keypoints_idx_dict,
76 | )
77 |
78 |
79 | def autoregressive_inference(
80 | contact_map_pred, match_model, obj_pc, robot_embed, obj_embed, top_k=0
81 | ):
82 | """Performs the autoregressive inference of GeoMatch."""
83 |
84 | with torch.no_grad():
85 | max_per_kp = torch.topk(contact_map_pred, k=3, dim=1)
86 | pred_curr = None
87 | grasp_points = []
88 | contact_or_not = []
89 |
90 | obj_proj_embed = match_model.obj_proj(obj_embed)
91 | robot_proj_embed = match_model.robot_proj(robot_embed)
92 |
93 | for i_prev in range(config.keypoint_n - 1):
94 | model_kp = match_model.kp_ar_model_1
95 |
96 | if i_prev == 1:
97 | model_kp = match_model.kp_ar_model_2
98 | elif i_prev == 2:
99 | model_kp = match_model.kp_ar_model_3
100 | elif i_prev == 3:
101 | model_kp = match_model.kp_ar_model_4
102 | elif i_prev == 4:
103 | model_kp = match_model.kp_ar_model_5
104 |
105 | xyz_prev = torch.gather(
106 | obj_pc[None],
107 | 1,
108 | max_per_kp.indices[:, top_k, i_prev, :].repeat(1, 1, 3),
109 | )
110 |
111 | if i_prev == 0:
112 | grasp_points.append(xyz_prev.squeeze())
113 | contact_or_not.append(torch.tensor(1))
114 | else:
115 | xyz_prev = torch.stack(grasp_points, dim=0)[None]
116 |
117 | pred_curr = model_kp(
118 | obj_proj_embed, obj_pc[None], robot_proj_embed, xyz_prev
119 | )
120 | pred_prob = nn.Sigmoid()(pred_curr)
121 | vert_pred = torch.max(pred_prob[..., 0], dim=-1)
122 | min_idx = vert_pred.indices[0]
123 | contact_or_not.append(torch.tensor(int(vert_pred.values[0] >= 0.5)))
124 |
125 | # Projected on object
126 | pred_curr = obj_pc[min_idx]
127 | grasp_points.append(pred_curr)
128 |
129 | grasp_points = torch.stack(grasp_points, dim=0)
130 | contact_or_not = torch.stack(contact_or_not, dim=0)
131 |
132 | return torch.cat((grasp_points, contact_or_not[..., None]), dim=-1)
133 |
134 |
135 | def plot_side_by_side(
136 | point_cloud,
137 | contact_map,
138 | hand_data,
139 | i_keypoint,
140 | save_dir,
141 | save_plot,
142 | gt_contact_map=None,
143 | pred_points=None,
144 | top_obj_contact_kps=None,
145 | ):
146 | """Plots side-by-side views of the object point cloud with the predicted keypoints, the gripper with its canonical keypoints, and a GT sample for comparison.
147 |
148 | Each keypoint generates a new plot. The current keypoint is depicted with
149 | a different color.
150 |
151 | Args:
152 | point_cloud: the object point cloud
153 | contact_map: the predicted contact map
154 | hand_data: gripper data to plot - mesh, point cloud etc.
155 | i_keypoint: i-th keypoint to plot data for
156 | save_dir: directory to save plots in
157 | save_plot: bool, whether to save plots
158 | gt_contact_map: ground truth contact map for comparison
159 | pred_points: predicted keypoints to plot
160 | top_obj_contact_kps: ground truth contact points to plot
161 | """
162 | fig = make_subplots(
163 | rows=1,
164 | cols=3,
165 | specs=[
166 | [{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]
167 | ],
168 | )
169 | fig.add_trace(
170 | plot_point_cloud(point_cloud, contact_map.squeeze()), row=1, col=1
171 | )
172 |
173 | if pred_points is not None:
174 | pred_points = pred_points.detach().numpy()
175 |
176 | for i in range(pred_points.shape[0]):
177 | c = 'black'
178 | if pred_points[i, 3] == 1.0:
179 | c = 'red'
180 | fig.add_trace(
181 | plot_point_cloud(pred_points[i, :3][None], color=c, size=5),
182 | row=1,
183 | col=1,
184 | )
185 | fig.add_trace(
186 | plot_point_cloud(
187 | pred_points[i_keypoint][None], color='magenta', size=5
188 | ),
189 | row=1,
190 | col=1,
191 | )
192 |
193 | for d in hand_data:
194 | fig.add_trace(d, row=1, col=2)
195 |
196 | if gt_contact_map is not None:
197 | fig.add_trace(
198 | plot_point_cloud(point_cloud, gt_contact_map.squeeze()), row=1, col=3
199 | )
200 |
201 | if top_obj_contact_kps is not None:
202 | fig.add_trace(
203 | plot_point_cloud(
204 | top_obj_contact_kps[i_keypoint, :][None], color='red'
205 | ),
206 | row=1,
207 | col=3,
208 | )
209 |
210 | fig.update_layout(
211 | height=800,
212 | width=1800,
213 | title_text=(
214 | f'Prediction on keypoint {i_keypoint} and one GT grasp for'
215 | ' comparison.'
216 | ), 217 | ) 218 | if not save_plot: 219 | fig.show() 220 | else: 221 | fig.write_image( 222 | os.path.join(save_dir, f'prediction+rand_gt_keypoint_{i_keypoint}'), 223 | 'jpg', 224 | scale=2, 225 | ) 226 | 227 | 228 | def plot_predicted_keypoints( 229 | obj_name, 230 | rbt_name, 231 | obj_data, 232 | rbt_data, 233 | contact_map_data_gt, 234 | geomatch_model, 235 | save_dir, 236 | save_plot, 237 | data_basedir, 238 | top_k=0, 239 | ): 240 | """Generates plots for a predicted grasp for a given object-gripper pair.""" 241 | ( 242 | obj_adj, 243 | obj_pc, 244 | robot_adj, 245 | robot_pc, 246 | rest_pose, 247 | keypoints_idx, 248 | obj_cmap, 249 | robot_cmap, 250 | top_obj_contact_kps, 251 | _, 252 | _, 253 | _, 254 | ) = return_random_obj_ee_pair( 255 | obj_name, rbt_name, obj_data, rbt_data, contact_map_data_gt 256 | ) 257 | 258 | with torch.no_grad(): 259 | obj_embed = geomatch_model.encode_embed( 260 | geomatch_model.obj_encoder, obj_pc[None], obj_adj[None] 261 | ) 262 | robot_embed = geomatch_model.encode_embed( 263 | geomatch_model.robot_encoder, robot_pc[None], robot_adj[None] 264 | ) 265 | 266 | robot_feat_size = robot_embed.shape[2] 267 | keypoint_feat = torch.gather( 268 | robot_embed, 269 | 1, 270 | keypoints_idx[..., None].long().repeat(1, 1, robot_feat_size), 271 | ) 272 | contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[ 273 | ..., None 274 | ] 275 | gt_contact_map = ( 276 | (obj_cmap * robot_cmap.repeat(1, config.obj_pc_n)) 277 | .transpose(1, 0)[..., None] 278 | .contiguous() 279 | ) 280 | 281 | top_obj_contact_kps_pred = autoregressive_inference( 282 | contact_map_pred, model, obj_pc, robot_embed, obj_embed, top_k 283 | ) 284 | pred_points = top_obj_contact_kps_pred 285 | 286 | hand_model = get_handmodel(rbt_name, 1, 'cpu', 1.0, data_dir=data_basedir) 287 | print('PREDICTION: ', pred_points) 288 | 289 | for i in range(contact_map_pred.shape[2]): 290 | gt_contact_map_i = gt_contact_map[:, i, :] 291 | 292 | obj_kp_cmap = contact_map_pred[:, :, i, :] 293 | obj_kp_cmap_labels = torch.nn.Sigmoid()(obj_kp_cmap) 294 | 295 | selected_kp = robot_pc[keypoints_idx[i].long(), :][None] 296 | vis_data = hand_model.get_plotly_data(q=rest_pose, opacity=0.5) 297 | vis_data += [ 298 | plot_point_cloud(robot_pc[keypoints_idx.long(), :].cpu(), color='black') 299 | ] 300 | vis_data += [plot_point_cloud(selected_kp.cpu(), color='red')] 301 | 302 | plot_side_by_side( 303 | obj_pc, 304 | obj_kp_cmap_labels.detach().numpy(), 305 | vis_data, 306 | i, 307 | save_dir, 308 | save_plot, 309 | gt_contact_map_i, 310 | pred_points, 311 | top_obj_contact_kps, 312 | ) 313 | 314 | 315 | if __name__ == '__main__': 316 | parser = argparse.ArgumentParser() 317 | parser.add_argument('--object_name', type=str, default='') 318 | parser.add_argument('--robot_name', type=str, default='') 319 | parser.add_argument('--random_example', default=True, action='store_true') 320 | parser.add_argument('--dataset_dir', type=str, default='/data/grasp_gnn') 321 | parser.add_argument('--save_plots', default=False, action='store_true') 322 | parser.add_argument('--top_k_idx', type=int, default=0) 323 | parser.add_argument( 324 | '--saved_model_dir', 325 | type=str, 326 | default='logs_train/exp-pos_weight_500_200_6_kps-1683055568.7644374/', 327 | ) 328 | args = parser.parse_args() 329 | 330 | dataset_basedir = args.dataset_dir 331 | saved_model_dir = args.saved_model_dir 332 | top_k_idx = args.top_k_idx 333 | 334 | saved_model_dir = args.saved_model_dir 335 | save_plot_dir = os.path.join(saved_model_dir, 
'plots') 336 | 337 | if not os.path.exists(save_plot_dir): 338 | os.mkdir(save_plot_dir) 339 | 340 | device = 'cpu' 341 | object_data = torch.load( 342 | os.path.join(dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt') 343 | ) 344 | robot_data = torch.load( 345 | os.path.join(dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt') 346 | ) 347 | cmap_data_gt = torch.load( 348 | os.path.join(dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt'), 349 | map_location=torch.device('cpu'), 350 | ) 351 | 352 | eval_object_list = json.load( 353 | open( 354 | os.path.join( 355 | dataset_basedir, 356 | 'CMapDataset-sqrt_align/split_train_validate_objects.json', 357 | ), 358 | 'rb', 359 | ) 360 | )['validate'] 361 | robot_name_list = [ 362 | 'ezgripper', 363 | 'barrett', 364 | 'robotiq_3finger', 365 | 'allegro', 366 | 'shadowhand', 367 | ] 368 | 369 | obj_robot_pairs = list(itertools.product(eval_object_list, robot_name_list)) 370 | 371 | model = GeoMatch(config) 372 | model.load_state_dict( 373 | torch.load( 374 | os.path.join(saved_model_dir, 'weights/grasp_gnn.pth'), 375 | map_location=torch.device('cpu'), 376 | ) 377 | ) 378 | 379 | model.eval() 380 | 381 | object_name = args.object_name 382 | robot_name = args.robot_name 383 | 384 | plot_predicted_keypoints( 385 | object_name, 386 | robot_name, 387 | object_data, 388 | robot_data, 389 | cmap_data_gt, 390 | model, 391 | save_plot_dir, 392 | args.save_plots, 393 | dataset_basedir, 394 | top_k_idx, 395 | ) 396 | --------------------------------------------------------------------------------
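For reference, a minimal end-to-end inference sketch that chains the pieces above: it loads the preprocessed tensors and a trained checkpoint, reuses the helpers from visualize_geomatch.py, and prints the predicted contact keypoints for one object-gripper pair. The experiment folder and the object name below are placeholders, and the sketch assumes the .pt files under /data/grasp_gnn and a trained grasp_gnn.pth already exist.

import os

import torch

import config
from models.geomatch import GeoMatch
from visualize_geomatch import autoregressive_inference
from visualize_geomatch import return_random_obj_ee_pair

DATA_DIR = '/data/grasp_gnn'               # default used throughout the repo
CKPT_DIR = 'logs_train/<your-experiment>'  # placeholder experiment folder
OBJECT_NAME = '<category+instance>'        # placeholder; pick one from the validate split
ROBOT_NAME = 'shadowhand'

# Preprocessed graphs/point clouds and ground-truth contact maps, loaded the
# same way as in visualize_geomatch.py.
obj_data = torch.load(os.path.join(DATA_DIR, 'gnn_obj_adj_point_clouds_new.pt'))
robot_data = torch.load(
    os.path.join(DATA_DIR, 'gnn_robot_adj_point_clouds_new.pt')
)
cmap_gt = torch.load(
    os.path.join(DATA_DIR, 'gnn_obj_cmap_robot_cmap_adj_new.pt'),
    map_location=torch.device('cpu'),
)

model = GeoMatch(config)
model.load_state_dict(
    torch.load(
        os.path.join(CKPT_DIR, 'weights/grasp_gnn.pth'),
        map_location=torch.device('cpu'),
    )
)
model.eval()

(
    obj_adj,
    obj_pc,
    robot_adj,
    robot_pc,
    _rest_pose,
    keypoints_idx,
    _obj_cmap,
    _robot_cmap,
    _top_kps,
    _top_verts,
    _q,
    _kp_idx_dict,
) = return_random_obj_ee_pair(
    OBJECT_NAME, ROBOT_NAME, obj_data, robot_data, cmap_gt
)

with torch.no_grad():
  # Per-node embeddings for the object and gripper graphs, then the dense
  # keypoint-vs-object contact logits, exactly as in plot_predicted_keypoints.
  obj_embed = model.encode_embed(model.obj_encoder, obj_pc[None], obj_adj[None])
  robot_embed = model.encode_embed(
      model.robot_encoder, robot_pc[None], robot_adj[None]
  )
  keypoint_feat = torch.gather(
      robot_embed,
      1,
      keypoints_idx[..., None].long().repeat(1, 1, robot_embed.shape[2]),
  )
  contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[
      ..., None
  ]

  # Autoregressive keypoint decoding; each row of the result is an xyz contact
  # point on the object plus a 0/1 contact flag.
  pred = autoregressive_inference(
      contact_map_pred, model, obj_pc, robot_embed, obj_embed
  )

print(pred)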