├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config.py
├── create_gnn_dataset.py
├── generate_grasps_for_all_objs.py
├── generate_grasps_for_obj.py
├── models
│   ├── geomatch.py
│   ├── gnn.py
│   └── mlp.py
├── train.py
├── utils
│   ├── general_utils.py
│   ├── gnn_utils.py
│   ├── gripper_utils.py
│   └── math_utils.py
├── utils_data
│   ├── augmentors.py
│   └── gnn_dataset.py
└── visualize_geomatch.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GeoMatch: Geometry Matching for Multi-Embodiment Grasping
2 |
3 | While significant progress has been made on the problem of generating grasps, many existing learning-based approaches still concentrate on a single embodiment, provide limited generalization to higher-DoF end-effectors, and cannot capture a diverse set of grasp modes. In this paper, we tackle the problem of multi-embodiment grasping through the lens of learning rich geometric representations for both objects and end-effectors using Graph Neural Networks (GNNs). Our novel method, GeoMatch, applies supervised learning on grasping data from multiple embodiments, learning end-to-end contact point likelihood maps as well as conditional autoregressive prediction of grasps keypoint-by-keypoint. We compare our method against 3 baselines that provide multi-embodiment support. Our approach performs better across 3 end-effectors, while also providing competitive diversity of grasps. Examples can be found at geomatch.github.io.
4 |
5 | This is the source code for the paper: [Geometry Matching for Multi-Embodiment Grasping](https://arxiv.org/abs/2312.03864).
6 |
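At its core, GeoMatch (see `models/geomatch.py`) encodes the object and end-effector point clouds with GCN encoders, gathers the embeddings of a fixed set of end-effector keypoints, and scores every object point against every keypoint with an inner product to obtain the contact likelihood maps; the autoregressive modules then predict contact keypoints one by one, conditioned on the previously selected points. A minimal sketch of the scoring step, using random stand-in embeddings and shapes taken from `config.py`:

```
import torch

obj_pc_n, keypoint_n, feat_dim = 2048, 6, 512  # values from config.py

# Stand-in embeddings; the real model produces these with its GCN encoders.
obj_embed = torch.randn(1, obj_pc_n, feat_dim)
keypoint_embed = torch.randn(1, keypoint_n, feat_dim)

# Normalize, then score every object point against every end-effector keypoint.
obj_embed = obj_embed / obj_embed.norm(dim=-1, keepdim=True)
keypoint_embed = keypoint_embed / keypoint_embed.norm(dim=-1, keepdim=True)
contact_map_logits = torch.matmul(obj_embed, keypoint_embed.transpose(2, 1))

print(contact_map_logits.shape)  # torch.Size([1, 2048, 6])
```
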
7 | ## Installation
8 |
9 | To get started, we recommend creating an Anaconda or virtual environment.
10 |
11 | This repository was developed with PyTorch 1.13, among other dependencies:
12 |
13 | ```pip install torch==1.13.1 pytorch-kinematics matplotlib transforms3d numpy scipy plotly trimesh urdf_parser_py tqdm argparse```
14 |
15 | For this work, we used the data from [GenDexGrasp: Generalizable Dexterous Grasping](https://github.com/tengyu-liu/GenDexGrasp/tree/main). Please follow the instructions at that link to download the data.
16 |
17 | ## Usage
18 |
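Before training, the downloaded data can be preprocessed into the graph format GeoMatch expects using `create_gnn_dataset.py`; a sketch using that script's default paths (point `--dataset_basedir` and `--object_mesh_basedir` at wherever you placed the GenDexGrasp data):

```python3 create_gnn_dataset.py --dataset_basedir=/data/grasp_gnn --object_mesh_basedir=/data/grasp_gnn/object --gnn_dataset_basedir=data```
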
19 | To train our model, run:
20 |
21 | ```python3 train.py --epochs=XXX --batch_size=YYY```
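
Additional options such as `--lr`, `--weight_decay`, `--device`, `--seed`, and `--exp_name` are defined in the argument parser of `train.py`.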
22 |
23 | To generate grasps for a given object and all end-effectors, run:
24 |
25 | ```python3 generate_grasps_for_obj.py --saved_model_dir=<saved_model_dir> --object_name=<object_name>```
26 |
27 | Optionally, you can plot grasps as they're generated by passing in `--plot_grasps` to the command above.
28 |
29 | To generate grasps for all objects of the eval set and all end-effectors, run:
30 |
31 | ```python3 generate_grasps_for_all_objs.py --saved_model_dir=<saved_model_dir>```
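
Generated grasps are written with `torch.save` to `<output_dir>/<object_name>/gen_grasps.pt` (see `generate_grasps_for_obj.py`). A minimal sketch for inspecting such a file, assuming the default `logs_out_grasps/` output directory and the default `contactdb+rubber_duck` object:

```
import os
import torch

# Default output layout of generate_grasps_for_obj.py (adjust to your run).
grasp_file = os.path.join('logs_out_grasps', 'contactdb+rubber_duck', 'gen_grasps.pt')

data = torch.load(grasp_file, map_location='cpu')
print(data['info'])  # names of the per-sample fields

for sample in data['metadata']:
    # Each sample holds the predicted keypoints and the IK-refined grasp pose.
    print(sample['robot_name'], sample['sample_idx'], sample['pred_grasp_pose'].shape)
```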
32 |
33 |
34 | ## Citing this work
35 |
36 | If you find our repository useful, please cite us:
37 |
38 | ```
39 | @inproceedings{attarian2023geometry,
40 | title={Geometry Matching for Multi-Embodiment Grasping},
41 | author={Attarian, Maria and Asif, Muhammad Adil and Liu, Jingzhou and Hari, Ruthrash and Garg, Animesh and Gilitschenski, Igor and Tompson, Jonathan},
42 | booktitle={Proceedings of the 7th Conference on Robot Learning (CoRL)},
43 | year={2023}
44 | }
45 | ```
46 |
47 | ## License and disclaimer
48 |
49 | Copyright 2023 DeepMind Technologies Limited
50 |
51 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
52 | you may not use this file except in compliance with the Apache 2.0 license.
53 | You may obtain a copy of the Apache 2.0 license at:
54 | https://www.apache.org/licenses/LICENSE-2.0
55 |
56 | All other materials are licensed under the Creative Commons Attribution 4.0
57 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
58 | https://creativecommons.org/licenses/by/4.0/legalcode
59 |
60 | Unless required by applicable law or agreed to in writing, all software and
61 | materials distributed here under the Apache 2.0 or CC-BY licenses are
62 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
63 | either express or implied. See the licenses for the specific language governing
64 | permissions and limitations under those licenses.
65 |
66 | This is not an official Google product.
67 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Config file with parameters."""
17 |
18 | obj_pc_n = 2048
19 | robot_pc_n = 6
20 | keypoint_n = 6
21 |
22 | hidden_n = 256
23 | obj_in_feats = 3
24 | robot_in_feats = obj_in_feats
25 | obj_out_feats = 512
26 | robot_out_feats = obj_out_feats
27 | robot_weighting = 500.0
28 | matchnet_weighting = 200.0
29 | num_hidden = 3
30 |
31 | dist_loss_weight = 0.5
32 | match_loss_weight = 0.5
33 |
--------------------------------------------------------------------------------
/create_gnn_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Script that preprocesses the data into the format expected by GeoMatch."""
17 |
18 | import argparse
19 | import os
20 | import numpy as np
21 | import torch
22 | import torch.utils.data
23 | from tqdm import tqdm
24 | import trimesh as tm
25 | from utils.general_utils import get_handmodel
26 | from utils.gnn_utils import euclidean_min_dist
27 | from utils.gnn_utils import generate_adj_mat_feats
28 | from utils.gnn_utils import generate_contact_maps
29 |
30 | device = 'cpu'
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--dataset_basedir', type=str, default='/data/grasp_gnn')
35 | parser.add_argument(
36 | '--object_mesh_basedir', type=str, default='/data/grasp_gnn/object'
37 | )
38 | parser.add_argument('--gnn_dataset_basedir', type=str, default='data')
39 |
40 | args = parser.parse_args()
41 |
42 | cmap_data = torch.load(
43 | os.path.join(
44 | args.dataset_basedir, 'CMapDataset-sqrt_align/cmap_dataset.pt'
45 | ),
46 | map_location=torch.device('cpu'),
47 | )
48 | object_data = torch.load(
49 | os.path.join(
50 | args.dataset_basedir, 'CMapDataset-sqrt_align/object_point_clouds.pt'
51 | )
52 | )
53 |
54 | robot_name_list = [
55 | 'ezgripper',
56 | 'barrett',
57 | 'robotiq_3finger',
58 | 'allegro',
59 | 'shadowhand',
60 | ]
61 | hand_model = {}
62 | robot_key_point_idx = {}
63 | surface_points_per_robot = {}
64 | threshold = 0.04
65 |
66 | data_dict = {}
67 | for robot_name in tqdm(robot_name_list):
68 | hand_model[robot_name] = get_handmodel(
69 | robot_name, 1, 'cpu', 1.0, data_dir=args.dataset_basedir
70 | )
71 | joint_lower = np.array(
72 | hand_model[robot_name].revolute_joints_q_lower.cpu().reshape(-1)
73 | )
74 | joint_upper = np.array(
75 | hand_model[robot_name].revolute_joints_q_upper.cpu().reshape(-1)
76 | )
77 | joint_mid = (joint_lower + joint_upper) / 2
78 | joints_q = (joint_mid + joint_lower) / 2
79 | rest_pose = (
80 | torch.from_numpy(
81 | np.concatenate([np.array([0, 0, 0, 1, 0, 0, 0, 1, 0]), joints_q])
82 | )
83 | .unsqueeze(0)
84 | .to(device)
85 | .float()
86 | )
87 | surface_points = (
88 | hand_model[robot_name]
89 | .get_surface_points(rest_pose, downsample=True)
90 | .cpu()
91 | .squeeze(0)
92 | )
93 | key_points, key_point_idx_dict, surface_sample_kp_idx = hand_model[
94 | robot_name
95 | ].get_static_key_points(rest_pose, surface_points)
96 | robot_key_point_idx[robot_name] = key_point_idx_dict
97 | surface_points_per_robot[robot_name] = surface_points
98 | robot_adj, robot_features = generate_adj_mat_feats(surface_points, knn=8)
99 | data_dict[robot_name] = (
100 | robot_adj,
101 | robot_features,
102 | rest_pose,
103 | surface_sample_kp_idx,
104 | key_point_idx_dict,
105 | robot_name,
106 | )
107 |
108 | torch.save(
109 | data_dict,
110 | os.path.join(
111 | args.gnn_dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt'
112 | ),
113 | )
114 |
115 | data_dict = {}
116 | for obj_name in tqdm(object_data):
117 | object_mesh_path = os.path.join(
118 | args.object_mesh_basedir,
119 | f'{obj_name.split("+")[0]}',
120 | f'{obj_name.split("+")[1]}',
121 | f'{obj_name.split("+")[1]}.stl',
122 | )
123 | obj_point_cloud = object_data[obj_name]
124 | obj_mesh = tm.load(object_mesh_path)
125 | normals = []
126 |
127 | for p in obj_point_cloud:
128 | dist, indices = euclidean_min_dist(p, obj_mesh.vertices)
129 | normals.append(obj_mesh.vertex_normals[indices[0]])
130 |
131 | normals = np.stack(normals, axis=0)
132 |
133 | obj_adj, obj_features = generate_adj_mat_feats(obj_point_cloud, knn=8)
134 | data_dict[obj_name] = (obj_adj, obj_features, torch.tensor(normals))
135 |
136 | torch.save(
137 | data_dict,
138 | os.path.join(args.gnn_dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt'),
139 | )
140 |
141 | data_list = []
142 | for metadata in tqdm(cmap_data['metadata']):
143 | _, q, object_name, robot_name = metadata
144 | q = q.unsqueeze(0)
145 | obj_point_cloud = object_data[object_name]
146 |
147 | robot_grasp_kps, _, _ = hand_model[robot_name].get_static_key_points(
148 | q, surface_points_per_robot[robot_name]
149 | )
150 | obj_contact_map = np.zeros((6, obj_point_cloud.shape[0]))
151 | full_obj_contact_map = np.zeros((obj_point_cloud.shape[0], 1))
152 |
153 | point_dists_idxs = [
154 | euclidean_min_dist(x, obj_point_cloud) for x in robot_grasp_kps
155 | ]
156 | robot_contact_map = np.array(
157 | [int(x[0] < threshold) for x in point_dists_idxs]
158 | ).reshape(-1, 1)
159 | top_obj_contact_kps = torch.stack(
160 | [obj_point_cloud[x[1][0]] for x in point_dists_idxs], dim=0
161 | )
162 | top_obj_contact_verts = torch.tensor(
163 | [x[1][0] for x in point_dists_idxs]
164 | ).long()
165 |
166 | for i in range(top_obj_contact_verts.shape[0]):
167 | obj_contact_map[i, point_dists_idxs[i][1][:20]] = 1
168 | full_obj_contact_map[point_dists_idxs[i][1][:20]] = 1
169 |
170 | obj_contacts = generate_contact_maps(obj_contact_map)
171 | robot_contacts = generate_contact_maps(robot_contact_map)
172 |
173 | data_point = (
174 | obj_contacts,
175 | robot_contacts,
176 | top_obj_contact_kps,
177 | top_obj_contact_verts,
178 | full_obj_contact_map,
179 | q.squeeze(0),
180 | object_name,
181 | robot_name,
182 | )
183 | data_list.append(data_point)
184 |
185 | data_dict = {
186 | 'info': [
187 | 'obj_contacts',
188 | 'robot_contacts',
189 | 'top_obj_contact_kps',
190 | 'top_obj_contact_verts',
191 | 'full_obj_contact_map',
192 | 'q',
193 | 'object_name',
194 | 'robot_name',
195 | ],
196 | 'metadata': data_list,
197 | }
198 | torch.save(
199 | data_dict,
200 | os.path.join(
201 | args.gnn_dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt'
202 | ),
203 | )
204 |
--------------------------------------------------------------------------------
/generate_grasps_for_all_objs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Batch script that generates grasps for all eval objects and all end-effectors."""
17 |
18 | import argparse
19 | import json
20 | import os
21 | import subprocess
22 | import numpy as np
23 | import torch
24 |
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--seed', type=int, default=0, help='Random seed.')
29 | parser.add_argument(
30 | '--data_dir',
31 | type=str,
32 | default='/data/grasp_gnn',
33 | help='Base data directory.',
34 | )
35 | parser.add_argument(
36 | '--saved_model_dir',
37 | type=str,
38 | default=(
39 | 'logs_train/exp-pos_weight_500_200_6_kps_final-1683209255.5473607/'
40 | ),
41 | )
42 | parser.add_argument('--output_dir', type=str, default='logs_out_grasps/')
43 |
44 | args = parser.parse_args()
45 |
46 | np.random.seed(args.seed)
47 | torch.manual_seed(args.seed)
48 |
49 | data_dir = args.data_dir
50 | print(f'Saved model dir: {args.saved_model_dir}')
51 |
52 | object_list = json.load(
53 | open(
54 | os.path.join(
55 | data_dir,
56 | 'CMapDataset-sqrt_align/split_train_validate_objects.json',
57 | ),
58 | 'rb',
59 | )
60 | )['validate']
61 |
62 | ps = [
63 | subprocess.Popen(
64 | [
65 | 'python',
66 | 'generate_grasps_for_obj.py',
67 | '--object_name',
68 | object_name,
69 | '--data_dir',
70 | data_dir,
71 | '--saved_model_dir',
72 | args.saved_model_dir,
73 | '--output_dir',
74 | args.output_dir,
75 | ],
76 | stdout=subprocess.PIPE,
77 | )
78 | for object_name in object_list
79 | ]
80 |
81 | exit_codes = [p.wait() for p in ps]
82 |
83 | finished_proc_inds = [i for i, p in enumerate(ps) if p.poll() is not None]
84 |
85 | print(f'Exit codes of all processes: {exit_codes}')
86 |   print(f'Finished process indices: {finished_proc_inds}')
87 |
--------------------------------------------------------------------------------
/generate_grasps_for_obj.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Script that performs inference on GeoMatch to produce keypoints for a given object and all end-effectors, followed by the IK used for the paper.
17 |
18 | Note: Any other IK could be used.
19 | """
20 |
21 | import argparse
22 | import json
23 | import os
24 | import config
25 | from models.geomatch import GeoMatch
26 | import numpy as np
27 | import plotly.graph_objects as go
28 | import scipy
29 | import torch
30 | from torch import nn
31 | import trimesh as tm
32 | from utils.general_utils import get_handmodel
33 | from utils.gnn_utils import euclidean_min_dist
34 | from utils.gnn_utils import plot_mesh
35 | from utils.gnn_utils import plot_point_cloud
36 |
37 |
38 | def compute_pose_from_rotation_matrix(t_pose, r_matrix):
39 | """Computes a 6D pose from a rotation matrix."""
40 |
41 | batch = r_matrix.shape[0]
42 | joint_num = 2
43 | r_matrices = (
44 | r_matrix.view(batch, 1, 3, 3)
45 | .expand(batch, joint_num, 3, 3)
46 | .contiguous()
47 | .view(batch * joint_num, 3, 3)
48 | )
49 | src_poses = (
50 | t_pose.view(1, joint_num, 3, 1)
51 | .expand(batch, joint_num, 3, 1)
52 | .contiguous()
53 | .view(batch * joint_num, 3, 1)
54 | )
55 |
56 | out_poses = torch.matmul(r_matrices, src_poses.double())
57 |
58 | return out_poses.view(batch, joint_num * 3)
59 |
60 |
61 | def rotation_matrix_from_vectors(vec1, vec2):
62 | """Returns the rotation matrix that aligns two vectors.
63 |
64 | Source:
65 | https://stackoverflow.com/questions/45142959/calculate-rotation-matrix-to-align-two-vectors-in-3d-space
66 |
67 | Args:
68 | vec1: first vector.
69 | vec2: second vector.
70 |
71 | Returns:
72 |     A 3x3 rotation matrix.
73 | """
74 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (
75 | vec2 / np.linalg.norm(vec2)
76 | ).reshape(3)
77 | v = np.cross(a, b)
78 | c = np.dot(a, b)
79 | s = np.linalg.norm(v)
80 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
81 | rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
82 | return rotation_matrix
83 |
84 |
85 | def get_heuristic_init_pose(
86 | gripper_model, centroids, object_point_cloud, robot_point_cloud, q_rest
87 | ):
88 | """Gets an initial pose based on heuristics, to start IK."""
89 |
90 | _, sorted_vert = euclidean_min_dist(centroids, robot_point_cloud)
91 | hand_closest_point = robot_point_cloud[sorted_vert[0]]
92 | center_mass = object_point_cloud.mean(dim=0)
93 |
94 | _, sorted_vert_obj = euclidean_min_dist(centroids, object_point_cloud)
95 | obj_closest_point = object_point_cloud[sorted_vert_obj[0]]
96 |
97 | hand_vec_normal = hand_closest_point - torch.tensor(centroids)
98 | obj_vec_normal = obj_closest_point - center_mass
99 |
100 | rot_mat = rotation_matrix_from_vectors(hand_vec_normal, obj_vec_normal)
101 |
102 | new_q_rot = compute_pose_from_rotation_matrix(
103 | q_rest.squeeze()[3:9], torch.tensor(rot_mat[None])
104 | )
105 |
106 | new_hand_closest = np.matmul(rot_mat, hand_closest_point)
107 | trans = obj_closest_point - new_hand_closest
108 |
109 | gripper_model.global_rotation = torch.tensor(rot_mat[None])
110 | gripper_model.global_translation = trans[None]
111 |
112 | heuristic_q = np.concatenate(
113 | (trans, new_q_rot.squeeze(), q_rest.squeeze()[9:])
114 | )
115 | heuristic_q = torch.tensor(heuristic_q).float()
116 |
117 | return heuristic_q
118 |
119 |
120 | def autoregressive_inference(
121 | contact_map_pred,
122 | match_model,
123 | top_k,
124 | obj_pc,
125 | robot_embed,
126 | obj_embed,
127 | ):
128 | """Performs the autoregressive inference part of GeoMatch."""
129 |
130 | max_topk = max(top_k)
131 |
132 | with torch.no_grad():
133 | max_per_kp = torch.topk(contact_map_pred, k=(max_topk + 1), dim=1)
134 | all_grasps = []
135 |
136 | for k in top_k:
137 | pred_curr = None
138 | grasp_points = []
139 | contact_or_not = []
140 |
141 | obj_proj_embed = match_model.obj_proj(obj_embed)
142 | robot_proj_embed = match_model.robot_proj(robot_embed)
143 |
144 | for i_prev in range(config.keypoint_n - 1):
145 | model_kp = match_model.kp_ar_model_1
146 |
147 | if i_prev == 1:
148 | model_kp = match_model.kp_ar_model_2
149 | elif i_prev == 2:
150 | model_kp = match_model.kp_ar_model_3
151 | elif i_prev == 3:
152 | model_kp = match_model.kp_ar_model_4
153 | elif i_prev == 4:
154 | model_kp = match_model.kp_ar_model_5
155 |
156 | xyz_prev = torch.gather(
157 | obj_pc[None],
158 | 1,
159 | max_per_kp.indices[:, k, i_prev, :].repeat(1, 1, 3),
160 | )
161 |
162 | if i_prev == 0:
163 | grasp_points.append(xyz_prev.squeeze())
164 | contact_or_not.append(torch.tensor(1))
165 | else:
166 | xyz_prev = torch.stack(grasp_points, dim=0)[None]
167 |
168 | pred_curr = model_kp(
169 | obj_proj_embed, obj_pc[None], robot_proj_embed, xyz_prev
170 | )
171 | pred_prob = nn.Sigmoid()(pred_curr)
172 | vert_pred = torch.max(pred_prob[..., 0], dim=-1)
173 | min_idx = vert_pred.indices[0]
174 | contact_or_not.append(torch.tensor(int(vert_pred.values[0] >= 0.5)))
175 |
176 | pred_curr = obj_pc[min_idx]
177 | grasp_points.append(pred_curr)
178 |
179 | grasp_points = torch.stack(grasp_points, dim=0)
180 | contact_or_not = torch.stack(contact_or_not, dim=0)
181 | final_grasp_points = torch.cat(
182 | (grasp_points, contact_or_not[..., None]), dim=-1
183 | )
184 | all_grasps.append(final_grasp_points)
185 |
186 | return torch.stack(all_grasps, dim=0)
187 |
188 |
189 | def inference(
190 | geomatch_model,
191 | top_k,
192 | obj_pc,
193 | obj_adjacency,
194 | robot_point_cloud,
195 | robot_adjacency,
196 | keypoints_idx,
197 | ):
198 | """Performs full inference for GeoMatch."""
199 |
200 | with torch.no_grad():
201 | obj_embed = geomatch_model.encode_embed(
202 | geomatch_model.obj_encoder, obj_pc[None], obj_adjacency[None]
203 | )
204 | robot_embed = geomatch_model.encode_embed(
205 | geomatch_model.robot_encoder,
206 | robot_point_cloud[None],
207 | robot_adjacency[None],
208 | )
209 |
210 | robot_feat_size = robot_embed.shape[2]
211 | keypoint_feat = torch.gather(
212 | robot_embed,
213 | 1,
214 | keypoints_idx[..., None].long().repeat(1, 1, robot_feat_size),
215 | )
216 | contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[
217 | ..., None
218 | ]
219 |
220 | top_obj_contact_kps_pred = autoregressive_inference(
221 | contact_map_pred,
222 | geomatch_model,
223 | top_k,
224 | obj_pc,
225 | robot_embed,
226 | obj_embed,
227 | )
228 | pred_points = top_obj_contact_kps_pred
229 |
230 | return pred_points
231 |
232 |
233 | def inverse_kinematics_optimization(gripper_model, q_init, target_points):
234 | """Function performing IK optimization and returning a pose."""
235 |
236 | def optimize_target(q):
237 | q = torch.tensor(q).float()
238 | source_points, _, _ = gripper_model.get_static_key_points(q.unsqueeze(0))
239 | e = [
240 | np.linalg.norm(source_points[i] - target_points[i])
241 | for i in range(len(source_points))
242 | ]
243 | return e
244 |
245 | real_bounds = []
246 |
247 | for _ in range(3):
248 | real_bounds.append((-0.5, 0.5))
249 |
250 | for _ in range(6):
251 | real_bounds.append((-np.pi, np.pi))
252 |
253 | for idx, _ in enumerate(gripper_model.revolute_joints_q_lower.squeeze()):
254 | real_bounds.append((
255 | gripper_model.revolute_joints_q_lower[:, idx].squeeze().item(),
256 | gripper_model.revolute_joints_q_upper[:, idx].squeeze().item(),
257 | ))
258 |
259 | result = scipy.optimize.least_squares(
260 | optimize_target, q_init, method='trf', bounds=tuple(zip(*real_bounds))
261 | )
262 | return result
263 |
264 |
265 | if __name__ == '__main__':
266 | parser = argparse.ArgumentParser()
267 | parser.add_argument('--seed', type=int, default=0, help='Random seed.')
268 | parser.add_argument(
269 | '--device', type=str, default='cpu', help='Use cuda if available'
270 | )
271 | parser.add_argument(
272 | '--object_name',
273 | type=str,
274 | default='contactdb+rubber_duck',
275 | help='Which object to calculate grasp for.',
276 | )
277 | parser.add_argument(
278 | '--data_dir',
279 | type=str,
280 | default='/data/grasp_gnn',
281 | help='Base data directory.',
282 | )
283 | parser.add_argument(
284 | '--saved_model_dir',
285 | type=str,
286 | default=(
287 | 'logs_train/exp-pos_weight_500_200_6_kps_final-1683209255.5473607/'
288 | ),
289 | )
290 | parser.add_argument('--output_dir', type=str, default='logs_out_grasps/')
291 | parser.add_argument('--plot_grasps', default=False, action='store_true')
292 | args = parser.parse_args()
293 |
294 | np.random.seed(args.seed)
295 | torch.manual_seed(args.seed)
296 |
297 | data_dir = args.data_dir
298 | output_dir = os.path.join(args.output_dir, args.object_name)
299 | os.makedirs(output_dir, exist_ok=True)
300 |
301 | object_name = args.object_name
302 | object_mesh_basedir = os.path.join(data_dir, 'object')
303 | object_mesh_path = os.path.join(
304 | object_mesh_basedir,
305 | f'{args.object_name.split("+")[0]}',
306 | f'{args.object_name.split("+")[1]}',
307 | f'{args.object_name.split("+")[1]}.stl',
308 | )
309 | obj_mesh = tm.load(object_mesh_path)
310 | obj_normals = obj_mesh.vertex_normals
311 | plot_grasps = args.plot_grasps
312 |
313 | robot_centroids = json.load(
314 | open(os.path.join(data_dir, 'robot_centroids.json'))
315 | )
316 | robot_list = [
317 | 'ezgripper',
318 | 'barrett',
319 | 'shadowhand',
320 | ]
321 | top_ks = [0, 20, 50, 100]
322 |
323 | model = GeoMatch(config)
324 | model.load_state_dict(
325 | torch.load(
326 | os.path.join(args.saved_model_dir, 'weights/grasp_gnn.pth'),
327 | map_location=torch.device('cpu'),
328 | )
329 | )
330 |
331 | model.eval()
332 |
333 | robot_pc_adj = torch.load(
334 | os.path.join(data_dir, 'gnn_robot_adj_point_clouds_new.pt')
335 | )
336 |
337 | object_pc_adj = torch.load(
338 | os.path.join(data_dir, 'gnn_obj_adj_point_clouds_new.pt')
339 | )
340 | new_object_pc_adj = torch.load(
341 | os.path.join(data_dir, 'gnn_obj_adj_point_clouds_new_unseen.pt')
342 | )
343 | object_pc_adj.update(new_object_pc_adj)
344 |
345 | object_pc = object_pc_adj[object_name][1]
346 | obj_adj = object_pc_adj[object_name][0]
347 | generated_grasps = []
348 |
349 | for robot_name in robot_list:
350 | hand_model = get_handmodel(robot_name, 1, 'cpu', 1.0)
351 | robot_pc = robot_pc_adj[robot_name][1]
352 | robot_adj = robot_pc_adj[robot_name][0]
353 |
354 | robot_keypoints_idx = robot_pc_adj[robot_name][3]
355 | rest_pose = hand_model.rest_pose
356 | hand_centroids = robot_centroids[robot_name]
357 |
358 | q_heuristic = get_heuristic_init_pose(
359 | hand_model, hand_centroids, object_pc, robot_pc, rest_pose
360 | )
361 |
362 | all_grasps_predicted_keypoints = inference(
363 | model,
364 | top_ks,
365 | object_pc,
366 | obj_adj,
367 | robot_pc,
368 | robot_adj,
369 | robot_keypoints_idx,
370 | )
371 |
372 | for i in range(all_grasps_predicted_keypoints.shape[0]):
373 | predicted_keypoints = all_grasps_predicted_keypoints[i]
374 | closest_mesh_idxs = [
375 | euclidean_min_dist(x, obj_mesh.vertices)[1][0]
376 | for x in predicted_keypoints[:, :3]
377 | ]
378 | closest_normals = obj_normals[closest_mesh_idxs]
379 |
380 | pregrasp_pred_keypoints = predicted_keypoints[
381 | :, :3
382 | ] + 0.005 * closest_normals.astype('float32')
383 |
384 | res = inverse_kinematics_optimization(
385 | hand_model, q_heuristic, pregrasp_pred_keypoints
386 | )
387 | q_calc = res.x
388 | q_calc = torch.tensor(q_calc).float()
389 |
390 | calc_key_points, _, _ = hand_model.get_static_key_points(
391 | q_calc.unsqueeze(0)
392 | )
393 |
394 | sample = {
395 | 'object_name': object_name,
396 | 'robot_name': robot_name,
397 | 'pred_keypoints': predicted_keypoints,
398 | 'robot_final_keypoints': calc_key_points,
399 | 'init_pose': q_heuristic,
400 | 'pred_grasp_pose': q_calc,
401 | 'sample_idx': i,
402 | 'scale': 1.0,
403 | }
404 |
405 | if plot_grasps:
406 | data = [
407 | plot_mesh(mesh=tm.load(object_mesh_path), opacity=1.0, color='blue')
408 | ]
409 | data += hand_model.get_plotly_data(
410 | q=q_calc.unsqueeze(0), opacity=1.0, color='pink'
411 | )
412 | data += [plot_point_cloud(predicted_keypoints, color='green')]
413 | data += [plot_point_cloud(calc_key_points, color='purple')]
414 |
415 | fig = go.Figure(data=data)
416 | fig.show()
417 |
418 | generated_grasps.append(sample)
419 |
420 | data_dict = {
421 | 'info': [
422 | 'object_name',
423 | 'robot_name',
424 | 'pred_keypoints',
425 | 'robot_final_keypoints',
426 | 'init_pose',
427 | 'pred_grasp_pose',
428 | 'sample_idx',
429 | ],
430 | 'metadata': generated_grasps,
431 | }
432 | torch.save(data_dict, os.path.join(output_dir, 'gen_grasps.pt'))
433 |
--------------------------------------------------------------------------------
/models/geomatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """GeoMatch model definition."""
17 |
18 | from models.gnn import GCN
19 | from models.mlp import MLP
20 | import torch
21 | from torch import nn
22 |
23 |
24 | class GeoMatchARModule(nn.Module):
25 | """Autoregressive module class for GeoMatch."""
26 |
27 | def __init__(self, config, n_kp) -> None:
28 | super().__init__()
29 |
30 | self.config = config
31 | self.n_kp = n_kp
32 | self.final_fc = MLP(128 + 3 * self.n_kp, 1, 3, 256)
33 |
34 | def forward(self, obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev):
35 | robot_i_embed = (
36 | robot_proj_embed[:, self.n_kp][..., None]
37 | .transpose(2, 1)
38 | .repeat(1, self.config.obj_pc_n, 1)
39 | )
40 | obj_robot_embed = torch.cat((obj_proj_embed, robot_i_embed), dim=-1)
41 |
42 | diff_xyz_tensor = []
43 | for i in range(self.n_kp):
44 | diff_xyz = obj_pc - xyz_prev[:, i, :][..., None].transpose(2, 1)
45 | diff_xyz_tensor.append(diff_xyz)
46 |
47 | diff_xyz_tensor = torch.stack(diff_xyz_tensor, dim=-1)
48 | diff_xyz_tensor = diff_xyz_tensor.view(
49 | diff_xyz_tensor.shape[0], diff_xyz_tensor.shape[1], -1
50 | )
51 | inp = torch.cat((obj_robot_embed, diff_xyz_tensor), dim=-1)
52 | pred_curr = self.final_fc(inp)
53 |
54 | return pred_curr
55 |
56 | def calc_loss(self, pred, label):
57 | pred = pred.view(pred.shape[0] * pred.shape[1], 1)
58 | label = label.view(label.shape[0] * label.shape[1], 1)
59 |
60 | pos_weight = torch.tensor([1000.0]).cuda()
61 | loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, label)
62 | return torch.mean(loss)
63 |
64 |
65 | class GeoMatch(nn.Module):
66 | """GeoMatch model class."""
67 |
68 | def __init__(self, config) -> None:
69 | super().__init__()
70 |
71 | self.config = config
72 | self.n_kp = config.keypoint_n
73 | self.robot_weighting = config.robot_weighting
74 | self.match_weighting = config.matchnet_weighting
75 | self.dist_loss_weight = config.dist_loss_weight
76 | self.match_loss_weight = config.match_loss_weight
77 |
78 | self.obj_encoder = GCN(
79 | nfeat=config.obj_in_feats,
80 | nhid=config.hidden_n,
81 | nout=config.obj_out_feats,
82 | dropout=0.5,
83 | num_hidden=config.num_hidden,
84 | )
85 |
86 | self.robot_encoder = GCN(
87 | nfeat=config.robot_in_feats,
88 | nhid=config.hidden_n,
89 | nout=config.robot_out_feats,
90 | dropout=0.5,
91 | num_hidden=config.num_hidden,
92 | )
93 |
94 | self.obj_proj = nn.Linear(self.config.obj_out_feats, 64, bias=False)
95 | self.robot_proj = nn.Linear(self.config.robot_out_feats, 64, bias=False)
96 | self.kp_ar_model_1 = GeoMatchARModule(config, 1)
97 | self.kp_ar_model_2 = GeoMatchARModule(config, 2)
98 | self.kp_ar_model_3 = GeoMatchARModule(config, 3)
99 | self.kp_ar_model_4 = GeoMatchARModule(config, 4)
100 | self.kp_ar_model_5 = GeoMatchARModule(config, 5)
101 |
102 | def encode_embed(self, encoder, feature, adj_mat, normalize_emb=True):
103 | x = encoder(feature, adj_mat)
104 | if normalize_emb:
105 | x = x.clone() / torch.norm(x, dim=-1, keepdim=True)
106 | return x
107 |
108 | def forward(
109 | self, obj_pc, robot_pc, robot_key_point_idx, obj_adj, robot_adj, xyz_prev
110 | ):
111 | obj_embed = self.encode_embed(self.obj_encoder, obj_pc, obj_adj)
112 | robot_embed = self.encode_embed(self.robot_encoder, robot_pc, robot_adj)
113 |
114 | robot_feat_size = robot_embed.shape[2]
115 | keypoint_feat = torch.gather(
116 | robot_embed,
117 | 1,
118 | robot_key_point_idx[..., None].long().repeat(1, 1, robot_feat_size),
119 | )
120 | contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[
121 | ..., None
122 | ]
123 |
124 | obj_proj_embed = self.obj_proj(obj_embed)
125 | robot_proj_embed = self.robot_proj(robot_embed)
126 |
127 | output_1 = self.kp_ar_model_1(
128 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev
129 | )
130 | output_2 = self.kp_ar_model_2(
131 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev
132 | )
133 | output_3 = self.kp_ar_model_3(
134 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev
135 | )
136 | output_4 = self.kp_ar_model_4(
137 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev
138 | )
139 | output_5 = self.kp_ar_model_5(
140 | obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev
141 | )
142 |
143 | output = torch.cat(
144 | (output_1, output_2, output_3, output_4, output_5), dim=-1
145 | )[..., None]
146 |
147 | return contact_map_pred, output
148 |
149 | def calc_loss(self, gt_contact_map, contact_map_pred, pred, label):
150 | flat_contact_map_pred = contact_map_pred.view(
151 | contact_map_pred.shape[0]
152 | * contact_map_pred.shape[1]
153 | * contact_map_pred.shape[2],
154 | 1,
155 | )
156 | flat_gt_contact_map = gt_contact_map.view(
157 | gt_contact_map.shape[0]
158 | * gt_contact_map.shape[1]
159 | * gt_contact_map.shape[2],
160 | 1,
161 | )
162 |
163 | pos_weight = torch.Tensor([self.robot_weighting]).cuda()
164 | loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(
165 | flat_contact_map_pred, flat_gt_contact_map
166 | )
167 | l_dist = torch.mean(loss)
168 |
169 | pos_weight = torch.tensor([self.match_weighting]).cuda()
170 |
171 | loss = []
172 | for i in range(self.n_kp - 1):
173 | pred_i = pred[:, :, i]
174 | label_i = label[:, :, i]
175 | pred_i = pred_i.view(pred_i.shape[0] * pred_i.shape[1], 1)
176 | label_i = label_i.view(label_i.shape[0] * label_i.shape[1], 1)
177 | loss.append(nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred_i, label_i))
178 |
179 | loss = torch.stack(loss)
180 | l_match = torch.mean(loss)
181 |
182 | return self.dist_loss_weight * l_dist + self.match_loss_weight * l_match
183 |
--------------------------------------------------------------------------------
/models/gnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implementation of Graph Convolutional Neural Networks."""
17 |
18 | import copy
19 | import math
20 | import torch
21 | from torch import nn
22 | import torch.nn.functional as F
23 |
24 |
25 | def clones(module, n):
26 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
27 |
28 |
29 | class GraphConvolution(nn.Module):
30 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
31 |
32 | def __init__(self, in_features, out_features, bias=True):
33 | super().__init__()
34 | self.in_features = in_features
35 | self.out_features = out_features
36 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
37 | if bias:
38 | self.bias = nn.Parameter(torch.FloatTensor(out_features))
39 | else:
40 | self.register_parameter('bias', None)
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | stdv = 1.0 / math.sqrt(self.weight.size(1))
45 | self.weight.data.uniform_(-stdv, stdv)
46 | if self.bias is not None:
47 | self.bias.data.uniform_(-stdv, stdv)
48 |
49 | def forward(self, inp, adj):
50 | support = torch.matmul(inp, self.weight)
51 | output = torch.matmul(adj.to_dense(), support)
52 | if self.bias is not None:
53 | return output + self.bias
54 | else:
55 | return output
56 |
57 | def __repr__(self):
58 | return (
59 | self.__class__.__name__
60 | + ' ('
61 | + str(self.in_features)
62 | + ' -> '
63 | + str(self.out_features)
64 | + ')'
65 | )
66 |
67 |
68 | class GCN(nn.Module):
69 | """Graph Convolutional Neural Network class."""
70 |
71 | def __init__(self, nfeat, nhid, nout, dropout, num_hidden):
72 | super().__init__()
73 |
74 | self.gc0 = GraphConvolution(nfeat, nhid)
75 | self.gc_layers = clones(GraphConvolution(nhid, nhid), num_hidden)
76 | self.out = nn.Linear(nhid, nout)
77 | self.dropout = dropout
78 |
79 | def forward(self, x, adj):
80 | x = F.relu(self.gc0(x, adj))
81 |
82 | for i, _ in enumerate(self.gc_layers):
83 | x = F.relu(self.gc_layers[i](x, adj))
84 | return self.out(x)
85 |
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implementation of a Multi-Layer Perceptron."""
17 |
18 | import copy
19 | from torch import nn
20 | import torch.nn.functional as F
21 |
22 |
23 | def clones(module, n):
24 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
25 |
26 |
27 | class MLP(nn.Module):
28 | """MLP class."""
29 |
30 | def __init__(self, in_features, out_features, num_hidden, hidden_dim) -> None:
31 | super().__init__()
32 |
33 | self.layer0 = nn.Linear(in_features, hidden_dim)
34 | self.layers = clones(nn.Linear(hidden_dim, hidden_dim), num_hidden)
35 | self.out = nn.Linear(hidden_dim, out_features)
36 |
37 | def forward(self, x):
38 | x = F.relu(self.layer0(x))
39 |
40 | for l in self.layers:
41 | x = F.relu(l(x))
42 |
43 | return self.out(x)
44 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Training script for GeoMatch."""
17 |
18 | import argparse
19 | import dataclasses
20 | import os
21 | import shutil
22 | import sys
23 | import time
24 | import config
25 | import matplotlib.pyplot as plt
26 | from models.geomatch import GeoMatch
27 | import numpy as np
28 | import torch
29 | from torch import nn
30 | from torch import optim
31 | from torch.utils.data import DataLoader
32 | from torch.utils.tensorboard import SummaryWriter
33 | from tqdm import tqdm
34 | from utils.gnn_utils import square
35 | from utils.gnn_utils import train_metrics
36 | from utils_data.gnn_dataset import GNNDataset
37 |
38 |
39 | @dataclasses.dataclass
40 | class TrainState:
41 | geomatch_model: GeoMatch
42 | train_step: int
43 | eval_step: int
44 | writer: SummaryWriter
45 | optimizer: optim.Optimizer
46 | epoch: int
47 |
48 |
49 | def train(global_args, state: TrainState, dataloader: DataLoader):
50 | """Full training function."""
51 |
52 | state.geomatch_model.train()
53 | state.optimizer.zero_grad()
54 |
55 | loss_history = []
56 | acc_history = []
57 |
58 | for _, data in enumerate(
59 | tqdm(dataloader, desc=f'EPOCH[{state.epoch}/{global_args.epochs}]')
60 | ):
61 | state.train_step += 1
62 |
63 | (
64 | obj_adj,
65 | obj_features,
66 | obj_contacts,
67 | robot_adj,
68 | robot_features,
69 | robot_key_point_idx,
70 | robot_contacts,
71 | top_obj_contact_kps,
72 | _,
73 | _,
74 | _,
75 | _,
76 | ) = data
77 |
78 | if global_args.device == 'cuda':
79 | obj_features = obj_features.cuda()
80 | obj_adj = obj_adj.cuda()
81 | obj_contacts = obj_contacts.cuda()
82 | robot_adj = robot_adj.cuda()
83 | robot_features = robot_features.cuda()
84 | robot_key_point_idx = robot_key_point_idx.cuda().long()
85 | robot_contacts = robot_contacts.cuda()
86 | top_obj_contact_kps = top_obj_contact_kps.cuda()
87 |
88 | gt_contact_map = (
89 | (obj_contacts * robot_contacts.repeat(1, 1, config.obj_pc_n))
90 | .transpose(2, 1)[..., None]
91 | .contiguous()
92 | )
93 |
94 | contact_map_pred, pred_curr = state.geomatch_model(
95 | obj_features,
96 | robot_features,
97 | robot_key_point_idx,
98 | obj_adj,
99 | robot_adj,
100 | top_obj_contact_kps,
101 | )
102 |
103 | loss_train = state.geomatch_model.calc_loss(
104 | gt_contact_map,
105 | contact_map_pred,
106 | pred_curr,
107 | gt_contact_map[:, :, 1 : config.keypoint_n, :],
108 | )
109 | (
110 | acc,
111 | _,
112 | _,
113 | true_positives,
114 | true_negatives,
115 | false_positives,
116 | false_negatives,
117 | ) = train_metrics(pred_curr, gt_contact_map[:, :, 1 : config.keypoint_n, :])
118 |
119 | state.optimizer.zero_grad()
120 | loss_train.backward()
121 |
122 | nn.utils.clip_grad_value_(state.geomatch_model.parameters(), clip_value=1.0)
123 | state.optimizer.step()
124 |
125 | loss_history.append(loss_train)
126 | acc_history.append(acc)
127 |
128 | if state.train_step % 10 == 0:
129 | loss = torch.mean(torch.stack(loss_history))
130 | square(true_positives, false_positives, true_negatives, false_negatives)
131 | precision_recall_square = plt.imread('square.jpg').transpose(2, 0, 1)
132 |
133 | state.writer.add_scalar(
134 | 'train/loss', loss.item(), global_step=state.train_step
135 | )
136 | state.writer.add_image(
137 | 'train/precision_recall',
138 | precision_recall_square,
139 | global_step=state.train_step,
140 | )
141 |
142 | plt.close()
143 |
144 | epoch_loss = torch.mean(torch.stack(loss_history))
145 | epoch_accuracy = torch.mean(torch.stack(acc_history))
146 |
147 | if state.epoch % 1 == 0:
148 | print(
149 | f'[train] loss at epoch {state.epoch}: {epoch_loss}\n'
150 | f' accuracy: {epoch_accuracy}\n'
151 | )
152 |
153 |
154 | def validate(global_args, state: TrainState, dataloader: DataLoader):
155 | """Full evaluation function."""
156 |
157 | with torch.no_grad():
158 | state.geomatch_model.eval()
159 |
160 | loss_history = []
161 | acc_history = []
162 |
163 | for data in tqdm(
164 | dataloader, desc=f'EPOCH[{state.epoch}/{global_args.epochs}]'
165 | ):
166 | state.eval_step += 1
167 |
168 | (
169 | obj_adj,
170 | obj_features,
171 | obj_contacts,
172 | robot_adj,
173 | robot_features,
174 | robot_key_point_idx,
175 | robot_contacts,
176 | top_obj_contact_kps,
177 | _,
178 | _,
179 | _,
180 | _,
181 | ) = data
182 |
183 | if global_args.device == 'cuda':
184 | obj_features = obj_features.cuda()
185 | obj_adj = obj_adj.cuda()
186 | obj_contacts = obj_contacts.cuda()
187 | robot_adj = robot_adj.cuda()
188 | robot_features = robot_features.cuda()
189 | robot_key_point_idx = robot_key_point_idx.cuda().long()
190 | robot_contacts = robot_contacts.cuda()
191 | top_obj_contact_kps = top_obj_contact_kps.cuda()
192 |
193 | gt_contact_map = (
194 | (obj_contacts * robot_contacts.repeat(1, 1, config.obj_pc_n))
195 | .transpose(2, 1)[..., None]
196 | .contiguous()
197 | )
198 |
199 | contact_map_pred, pred_curr = state.geomatch_model(
200 | obj_features,
201 | robot_features,
202 | robot_key_point_idx,
203 | obj_adj,
204 | robot_adj,
205 | top_obj_contact_kps,
206 | )
207 |
208 | loss_val = state.geomatch_model.calc_loss(
209 | gt_contact_map,
210 | contact_map_pred,
211 | pred_curr,
212 | gt_contact_map[:, :, 1 : config.keypoint_n, :],
213 | )
214 | (
215 | acc,
216 | _,
217 | _,
218 | true_positives,
219 | true_negatives,
220 | false_positives,
221 | false_negatives,
222 | ) = train_metrics(
223 | pred_curr, gt_contact_map[:, :, 1 : config.keypoint_n, :]
224 | )
225 |
226 | loss_history.append(loss_val)
227 | acc_history.append(acc)
228 |
229 | if state.eval_step % 10 == 0:
230 | loss = torch.mean(torch.stack(loss_history))
231 | square(true_positives, false_positives, true_negatives, false_negatives)
232 | precision_recall_square = plt.imread('square.jpg').transpose(2, 0, 1)
233 |
234 | state.writer.add_scalar(
235 | 'validate/loss', loss.item(), global_step=state.eval_step
236 | )
237 | state.writer.add_image(
238 | 'validate/precision_recall',
239 | precision_recall_square,
240 | global_step=state.eval_step,
241 | )
242 |
243 | plt.close()
244 |
245 | epoch_loss = torch.mean(torch.stack(loss_history))
246 | epoch_accuracy = torch.mean(torch.stack(acc_history))
247 | print(
248 | f'[validate] loss: {epoch_loss}\n'
249 | f' accuracy: {epoch_accuracy}\n'
250 | )
251 |
252 |
253 | if __name__ == '__main__':
254 | start_time = time.time()
255 | parser = argparse.ArgumentParser()
256 | parser.add_argument(
257 | '--exp_name', type=str, default='all_end_effectors_object_split'
258 | )
259 | parser.add_argument('--seed', type=int, default=42, help='Random seed.')
260 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
261 | parser.add_argument(
262 | '--device', type=str, default='cuda', help='Use cuda if available'
263 | )
264 | parser.add_argument(
265 | '--epochs', type=int, default=200, help='Number of epochs to train.'
266 | )
267 | parser.add_argument(
268 | '--lr', type=float, default=1e-4, help='Initial learning rate.'
269 | )
270 | parser.add_argument(
271 | '--weight_decay',
272 | type=float,
273 | default=0.0,
274 | help='Weight decay (L2 loss on parameters).',
275 | )
276 | parser.add_argument(
277 | '--out_features',
278 | type=int,
279 | default=512,
280 | help='Dimensionality of the object and end-effector features.',
281 | )
282 | parser.add_argument(
283 | '--robot_weighting',
284 | type=int,
285 | default=500,
286 | help='Weight for full distribution BCE class loss.',
287 | )
288 | parser.add_argument(
289 | '--matchnet_weighting',
290 | type=int,
291 | default=200,
292 | help='Weight for matching BCE class loss.',
293 | )
294 | parser.add_argument(
295 | '--exclude_ee_list',
296 | nargs='+',
297 | help='End-effectors to be excluded from the default list.',
298 | )
299 |
300 | parser.add_argument(
301 | '--dataset_basedir',
302 | type=str,
303 | default='/data/grasp_gnn',
304 | help='Path to data.',
305 | )
306 |
307 | args = parser.parse_args()
308 | args.device = args.device if torch.cuda.is_available() else 'cpu'
309 |
310 | np.random.seed(args.seed)
311 | torch.manual_seed(args.seed)
312 |
313 | if args.device == 'cuda':
314 | torch.cuda.manual_seed_all(args.seed)
315 |
316 | config.obj_out_feats = args.out_features
317 | config.robot_out_feats = config.obj_out_feats
318 | config.robot_weighting = args.robot_weighting
319 | config.matchnet_weighting = args.matchnet_weighting
320 |
321 | robot_name_list = [
322 | 'ezgripper',
323 | 'barrett',
324 | 'robotiq_3finger',
325 | 'allegro',
326 | 'shadowhand',
327 | ]
328 |
329 | if args.exclude_ee_list is not None:
330 | for ee in args.exclude_ee_list:
331 | robot_name_list.remove(ee)
332 |
333 | log_dir = os.path.join('logs_train', f'exp-{args.exp_name}-{str(start_time)}')
334 | weight_dir = os.path.join(log_dir, 'weights')
335 | tb_dir = os.path.join(log_dir, 'tb_dir')
336 | shutil.rmtree(log_dir, ignore_errors=True)
337 | os.makedirs(log_dir, exist_ok=True)
338 | os.makedirs(weight_dir, exist_ok=True)
339 | os.makedirs(tb_dir, exist_ok=True)
340 | with open(os.path.join(log_dir, 'command.txt'), 'w') as f:
341 | f.write(' '.join(sys.argv))
342 |
343 | writer = SummaryWriter(log_dir=tb_dir)
344 |
345 | dataset_basedir = args.dataset_basedir
346 |
347 | batchsize = args.batch_size
348 | train_dataset = GNNDataset(
349 | dataset_basedir=dataset_basedir,
350 | mode='train',
351 | device=args.device,
352 | robot_name_list=robot_name_list,
353 | )
354 | train_dataloader = DataLoader(
355 | dataset=train_dataset,
356 | batch_size=batchsize,
357 | shuffle=True,
358 | num_workers=0,
359 | )
360 |
361 | validate_dataset = GNNDataset(
362 | dataset_basedir=dataset_basedir,
363 | mode='validate',
364 | device=args.device,
365 | robot_name_list=robot_name_list,
366 | )
367 | validate_dataloader = DataLoader(
368 | dataset=validate_dataset,
369 | batch_size=batchsize,
370 | shuffle=True,
371 | num_workers=0,
372 | )
373 |
374 | geomatch_model = GeoMatch(config)
375 | geomatch_model = geomatch_model.to(args.device)
376 |
377 | optimizer = optim.Adam(
378 | list(geomatch_model.parameters()),
379 | lr=args.lr,
380 | weight_decay=args.weight_decay,
381 | betas=(0.9, 0.99),
382 | )
383 |
384 | torch.save(
385 | geomatch_model.state_dict(), os.path.join(weight_dir, 'grasp_gnn.pth')
386 | )
387 |
388 | train_step = 0
389 | val_step = 0
390 |
391 | train_state = TrainState(
392 | geomatch_model=geomatch_model,
393 | writer=writer,
394 | optimizer=optimizer,
395 | train_step=train_step,
396 | eval_step=val_step,
397 | epoch=0,
398 | )
399 |
400 | for i_epoch in range(args.epochs):
401 | train_state.epoch = i_epoch
402 | train(args, train_state, train_dataloader)
403 | validate(args, train_state, validate_dataloader)
404 | torch.save(
405 | geomatch_model.state_dict(), os.path.join(weight_dir, 'grasp_gnn.pth')
406 | )
407 | writer.close()
408 | print(f'Elapsed time: {time.time() - start_time:.2f}s')
409 |
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """General utilities."""
17 |
18 | import json
19 | import os
20 | import random
21 | import numpy as np
22 | import torch
23 | from utils import gripper_utils
24 |
25 |
26 | def get_handmodel(
27 | robot, batch_size, device, hand_scale=1.0, data_dir='/data/grasp_gnn'
28 | ):
29 | """Fetches the hand model object for a given gripper."""
30 | urdf_assets_meta = json.load(
31 | open(os.path.join(data_dir, 'urdf/urdf_assets_meta.json'))
32 | )
33 | urdf_path = urdf_assets_meta['urdf_path'][robot].replace('data', data_dir)
34 | meshes_path = urdf_assets_meta['meshes_path'][robot].replace('data', data_dir)
35 | hand_model = gripper_utils.HandModel(
36 | robot,
37 | urdf_path,
38 | meshes_path,
39 | batch_size=batch_size,
40 | device=device,
41 | hand_scale=hand_scale,
42 | data_dir=data_dir,
43 | )
44 | return hand_model
45 |
46 |
47 | def set_global_seed(seed=42):
48 | torch.cuda.manual_seed_all(seed)
49 | torch.manual_seed(seed)
50 | np.random.seed(seed)
51 | random.seed(seed)
52 |
--------------------------------------------------------------------------------
/utils/gnn_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """GNN utilities."""
17 |
18 | import os
19 | import matplotlib
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import plotly.graph_objects as go
23 | import scipy.sparse as sp
24 | from scipy.spatial import distance
25 | import torch
26 | from torch import nn
27 | import trimesh as tm
28 |
29 | colors = ['blue', 'red', 'yellow', 'pink', 'gray', 'orange']
30 |
31 |
32 | def plot_mesh(mesh, color='lightblue', opacity=1.0):
33 | return go.Mesh3d(
34 | x=mesh.vertices[:, 0],
35 | y=mesh.vertices[:, 1],
36 | z=mesh.vertices[:, 2],
37 | i=mesh.faces[:, 0],
38 | j=mesh.faces[:, 1],
39 | k=mesh.faces[:, 2],
40 | color=color,
41 | opacity=opacity,
42 | )
43 |
44 |
45 | def plot_point_cloud(
46 | pts, color='lightblue', mode='markers', colorscale='Viridis', size=3
47 | ):
48 | return go.Scatter3d(
49 | x=pts[:, 0],
50 | y=pts[:, 1],
51 | z=pts[:, 2],
52 | mode=mode,
53 | marker=dict(color=color, colorscale=colorscale, size=size),
54 | )
55 |
56 |
57 | def plot_grasp(
58 | hand_model,
59 | key_points_idx_dict,
60 | q,
61 | object_name,
62 | obj_pc,
63 | obj_contact_map,
64 | selected_kp_idx,
65 | selected_vert,
66 | object_mesh_basedir='/data/grasp_gnn/object',
67 | ):
68 | """Plots a given grasp."""
69 |
70 | robot_keypoints_trans = hand_model.get_key_points_from_indices(
71 | key_points_idx_dict, q=q
72 | )
73 | vis_data = hand_model.get_plotly_data(q=q, opacity=0.5)
74 | object_mesh_path = os.path.join(
75 | object_mesh_basedir,
76 | f'{object_name.split("+")[0]}',
77 | f'{object_name.split("+")[1]}',
78 | f'{object_name.split("+")[1]}.stl',
79 | )
80 | vis_data += [plot_mesh(mesh=tm.load(object_mesh_path))]
81 | vis_data += [plot_point_cloud(obj_pc, obj_contact_map.squeeze())]
82 | vis_data += [plot_point_cloud(selected_vert, color='red')]
83 | vis_data += [plot_point_cloud(robot_keypoints_trans, color='black')]
84 | vis_data += [
85 | plot_point_cloud(
86 | robot_keypoints_trans[selected_kp_idx, :][None], color='red'
87 | )
88 | ]
89 | fig = go.Figure(data=vis_data)
90 | fig.show()
91 |
92 |
93 | def plot_hand_only(hand_model, q, key_points, selected_kp):
94 | vis_data = hand_model.get_plotly_data(q=q, opacity=0.5)
95 | vis_data += [plot_point_cloud(key_points, color='black')]
96 | vis_data += [plot_point_cloud(selected_kp, color='red')]
97 | fig = go.Figure(data=vis_data)
98 | fig.show()
99 |
100 |
101 | def plot_obj_only(obj_pc, obj_contact_map, selected_vert):
102 | vis_data = [plot_point_cloud(obj_pc, obj_contact_map.squeeze())]
103 | vis_data += [plot_point_cloud(selected_vert, color='red')]
104 | fig = go.Figure(data=vis_data)
105 | fig.show()
106 |
107 |
108 | def encode_onehot(labels, threshold=0.5):
109 | if isinstance(labels, np.ndarray):
110 | labels = torch.Tensor(labels)
111 | labels_onehot = torch.where(labels > threshold, 1.0, 0.0)
112 | return labels_onehot
113 |
114 |
115 | def euclidean_min_dist(point, point_cloud):
116 | dist_array = np.linalg.norm(
117 | np.array(point_cloud) - np.array(point).reshape((1, 3)), axis=-1
118 | )
119 | return np.min(dist_array), np.argsort(dist_array)
120 |
121 |
122 | def generate_contact_maps(contact_map):
123 | labels = contact_map
124 | labels = torch.FloatTensor(encode_onehot(labels))
125 | return labels
126 |
127 |
128 | def normalize_pc(points, scale_fact):
129 | centroid = torch.mean(points, dim=0)
130 | points -= centroid
131 | points /= scale_fact
132 |
133 | return points
134 |
135 |
136 | def generate_adj_mat_feats(point_cloud, knn=8):
137 | """Generates a graph from a given point cloud based on k-NN."""
138 |
139 | features = sp.csr_matrix(point_cloud, dtype=np.float32)
140 |
141 | # build graph
142 | dist = distance.squareform(distance.pdist(np.asarray(point_cloud)))
143 | closest = np.argsort(dist, axis=1)
144 | adj = np.zeros(closest.shape)
145 |
146 | for i in range(adj.shape[0]):
147 | adj[i, closest[i, 0 : knn + 1]] = 1
148 |
149 | adj = sp.coo_matrix(adj)
150 |
151 | # build symmetric adjacency matrix
152 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
153 |
154 | # features = normalize(features)
155 | adj = normalize(adj + sp.eye(adj.shape[0]))
156 |
157 | features = torch.FloatTensor(np.array(features.todense()))
158 | adj = sparse_mx_to_torch_sparse_tensor(adj)
159 |
160 | return adj, features
161 |
162 |
163 | def normalize(mx):
164 | """Row-normalize sparse matrix."""
165 | rowsum = np.array(mx.sum(1))
166 | r_inv = np.power(rowsum, -1).flatten()
167 | r_inv[np.isinf(r_inv)] = 0.0
168 | r_mat_inv = sp.diags(r_inv)
169 | mx = r_mat_inv.dot(mx)
170 | return mx
171 |
172 |
173 | def train_metrics(output, labels, threshold=0.5):
174 | """Generates all training metrics."""
175 |
176 | batch_size = labels.shape[0]
177 | size = labels.shape[1] * labels.shape[2]
178 | total_num = batch_size * size
179 | preds = nn.Sigmoid()(output)
180 | preds = encode_onehot(preds, threshold=threshold)
181 |
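# Confusion counts are element-wise products of the binarized predictions and
# labels, summed over the batch. For example, preds = [1, 0, 1, 0] with
# labels = [1, 1, 0, 0] gives TP = 1, FP = 1, FN = 1, TN = 1.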
182 | true_positives = (preds * labels).sum()
183 | false_positives = (preds * (1 - labels)).sum()
184 | false_negatives = ((1 - preds) * labels).sum()
185 | true_negatives = ((1 - preds) * (1 - labels)).sum()
186 |
187 | precision = true_positives / ((true_positives + false_positives) + 1e-6)
188 | recall = true_positives / ((true_positives + false_negatives) + 1e-6)
189 | acc = (true_negatives + true_positives) / total_num
190 | return (
191 | acc,
192 | precision,
193 | recall,
194 | true_positives,
195 | true_negatives,
196 | false_positives,
197 | false_negatives,
198 | )
199 |
200 |
201 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
202 | """Convert a scipy sparse matrix to a torch sparse tensor."""
203 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
204 | indices = torch.from_numpy(
205 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
206 | )
207 | values = torch.from_numpy(sparse_mx.data)
208 | shape = torch.Size(sparse_mx.shape)
209 | return torch.sparse.FloatTensor(indices, values, shape)
210 |
211 |
212 | def graph_cross_convolution(inp, kernel, inp_adj, krn_adj):
213 | """Experimental: Graph Cross Convolution."""
214 | kernel = torch.matmul(krn_adj.to_dense(), kernel)
215 | support = torch.matmul(inp, kernel.transpose(-2, -1))
216 | output = torch.matmul(inp_adj.to_dense(), support)
217 | return output
218 |
219 |
220 | def matplotlib_imshow(img, one_channel=False):
221 | if one_channel:
222 | img = img.mean(dim=0)
223 | img = img / 2 + 0.5 # unnormalize
224 | npimg = img.numpy()
225 | if one_channel:
226 | plt.imshow(npimg, cmap='Greys')
227 | else:
228 | plt.imshow(np.transpose(npimg, (1, 2, 0)))
229 |
230 |
231 | def square(true_positives, false_positives, true_negatives, false_negatives):
232 | """Defines a precision/recall square for Tensorboard plotting."""
233 |
234 | fig = plt.figure(figsize=(9, 9))
235 | ax = fig.add_subplot(111)
236 | rect1 = matplotlib.patches.Rectangle((0, 2), 2, 2, color='green')
237 | rect2 = matplotlib.patches.Rectangle((2, 2), 2, 2, color='red')
238 | rect3 = matplotlib.patches.Rectangle((0, 0), 2, 2, color='red')
239 | rect4 = matplotlib.patches.Rectangle((2, 0), 2, 2, color='green')
240 |
241 | ax.add_patch(rect1)
242 | ax.add_patch(rect2)
243 | ax.add_patch(rect3)
244 | ax.add_patch(rect4)
245 | rectangles = [rect1, rect2, rect3, rect4]
246 | tags = [
247 | 'True Positives=' + str(true_positives.item()),
248 | 'False Positives=' + str(false_positives.item()),
249 | 'False Negatives=' + str(false_negatives.item()),
250 | 'True Negatives=' + str(true_negatives.item()),
251 | ]
252 |
253 | for r in range(4):
254 | ax.add_artist(rectangles[r])
255 | rx, ry = rectangles[r].get_xy()
256 | cx = rx + rectangles[r].get_width() / 2.0
257 | cy = ry + rectangles[r].get_height() / 2.0
258 |
259 | ax.annotate(
260 | tags[r],
261 | (cx, cy),
262 | color='w',
263 | weight='bold',
264 | fontsize=12,
265 | ha='center',
266 | va='center',
267 | )
268 |
269 | plt.xlim([0, 4])
270 | plt.ylim([0, 4])
271 | plt.savefig('square.jpg')
272 |
--------------------------------------------------------------------------------
/utils/gripper_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities to represent an end-effector."""
17 |
18 | import json
19 | import os
20 | import numpy as np
21 | from plotly import graph_objects as go
22 | import pytorch_kinematics as pk
23 | from pytorch_kinematics.urdf_parser_py.urdf import Box
24 | from pytorch_kinematics.urdf_parser_py.urdf import Cylinder
25 | from pytorch_kinematics.urdf_parser_py.urdf import Mesh
26 | from pytorch_kinematics.urdf_parser_py.urdf import Sphere
27 | from pytorch_kinematics.urdf_parser_py.urdf import URDF
28 | import torch
29 | import torch.nn
30 | import transforms3d
31 | import trimesh as tm
32 | import trimesh.sample
33 | import urdf_parser_py.urdf as URDF_PARSER
34 | from utils import math_utils
35 |
36 |
37 | class HandModel:
38 | """Hand model class based on: https://github.com/tengyu-liu/GenDexGrasp/blob/main/utils_model/HandModel.py."""
39 |
40 | def __init__(
41 | self,
42 | robot_name,
43 | urdf_filename,
44 | mesh_path,
45 | batch_size=1,
46 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
47 | hand_scale=2.0,
48 | data_dir='data',
49 | ):
50 | self.device = device
51 | self.batch_size = batch_size
52 | self.data_dir = data_dir
53 |
54 | self.robot = pk.build_chain_from_urdf(open(urdf_filename).read()).to(
55 | dtype=torch.float, device=self.device
56 | )
57 | self.robot_full = URDF_PARSER.URDF.from_xml_file(urdf_filename)
58 |
59 | if robot_name == 'allegro_right':
60 | self.robot_name = 'allegro_right'
61 | robot_name = 'allegro'
62 | else:
63 | self.robot_name = robot_name
64 |
65 | self.global_translation = None
66 | self.global_rotation = None
67 | self.softmax = torch.nn.Softmax(dim=-1)
68 |
69 | self.contact_point_dict = json.load(
70 | open(os.path.join(self.data_dir, 'urdf/contact_%s.json' % robot_name))
71 | )
72 | self.contact_point_basis = {}
73 | self.contact_normals = {}
74 | self.surface_points = {}
75 | self.surface_points_normal = {}
76 | visual = URDF.from_xml_string(open(urdf_filename).read())
77 | self.centroids = json.load(
78 | open(os.path.join(self.data_dir, 'robot_centroids.json'))
79 | )[robot_name]
80 | self.keypoints = json.load(
81 | open(os.path.join(self.data_dir, 'robot_keypoints.json'))
82 | )[self.robot_name]
83 | self.key_point_idx_dict = {}
84 | self.mesh_verts = {}
85 | self.mesh_faces = {}
86 |
87 | self.canon_verts = []
88 | self.canon_faces = []
89 | self.idx_vert_faces = []
90 | self.face_normals = []
91 | self.links = [link.name for link in visual.links]
92 |
93 | for i_link, link in enumerate(visual.links):
94 | print(f'Processing link #{i_link}: {link.name}')
95 |
96 | if not link.visuals:
97 | continue
98 | if isinstance(link.visuals[0].geometry, Mesh):
99 | if (
100 | robot_name == 'shadowhand'
101 | or robot_name == 'allegro'
102 | or robot_name == 'barrett'
103 | ):
104 | filename = link.visuals[0].geometry.filename.split('/')[-1]
105 | elif robot_name == 'allegro':
106 | filename = f"{link.visuals[0].geometry.filename.split('/')[-2]}/{link.visuals[0].geometry.filename.split('/')[-1]}"
107 | else:
108 | filename = link.visuals[0].geometry.filename
109 | mesh = tm.load(
110 | os.path.join(mesh_path, filename), force='mesh', process=False
111 | )
112 | elif isinstance(link.visuals[0].geometry, Cylinder):
113 | mesh = tm.primitives.Cylinder(
114 | radius=link.visuals[0].geometry.radius,
115 | height=link.visuals[0].geometry.length,
116 | )
117 | elif isinstance(link.visuals[0].geometry, Box):
118 | mesh = tm.primitives.Box(extents=link.visuals[0].geometry.size)
119 | elif isinstance(link.visuals[0].geometry, Sphere):
120 | mesh = tm.primitives.Sphere(radius=link.visuals[0].geometry.radius)
121 | else:
122 | print(type(link.visuals[0].geometry))
123 | raise NotImplementedError
124 | try:
125 | scale = np.array(link.visuals[0].geometry.scale).reshape([1, 3])
126 | except Exception: # pylint: disable=broad-exception-caught
127 | scale = np.array([[1, 1, 1]])
128 | try:
129 | rotation = transforms3d.euler.euler2mat(*link.visuals[0].origin.rpy)
130 | translation = np.reshape(link.visuals[0].origin.xyz, [1, 3])
131 |
132 | except Exception: # pylint: disable=broad-exception-caught
133 | rotation = transforms3d.euler.euler2mat(0, 0, 0)
134 | translation = np.array([[0, 0, 0]])
135 |
136 | if self.robot_name == 'shadowhand':
137 | pts, pts_face_index = trimesh.sample.sample_surface(mesh=mesh, count=64)
138 | pts_normal = np.array(
139 | [mesh.face_normals[x] for x in pts_face_index], dtype=float
140 | )
141 | else:
142 | pts, pts_face_index = trimesh.sample.sample_surface(
143 | mesh=mesh, count=128
144 | )
145 | pts_normal = np.array(
146 | [mesh.face_normals[x] for x in pts_face_index], dtype=float
147 | )
148 |
149 | pts *= scale
150 | if robot_name == 'shadowhand':
151 | pts = pts[:, [0, 2, 1]]
152 | pts_normal = pts_normal[:, [0, 2, 1]]
153 | pts[:, 1] *= -1
154 | pts_normal[:, 1] *= -1
155 |
156 | pts = np.matmul(rotation, pts.T).T + translation
157 | pts = np.concatenate([pts, np.ones([len(pts), 1])], axis=-1)
158 | pts_normal = np.concatenate(
159 | [pts_normal, np.ones([len(pts_normal), 1])], axis=-1
160 | )
161 | self.surface_points[link.name] = (
162 | torch.from_numpy(pts)
163 | .to(device)
164 | .float()
165 | .unsqueeze(0)
166 | .repeat(batch_size, 1, 1)
167 | )
168 | self.surface_points_normal[link.name] = (
169 | torch.from_numpy(pts_normal)
170 | .to(device)
171 | .float()
172 | .unsqueeze(0)
173 | .repeat(batch_size, 1, 1)
174 | )
175 |
176 | # visualization mesh
177 | self.mesh_verts[link.name] = np.array(mesh.vertices) * scale
178 | if robot_name == 'shadowhand':
179 | self.mesh_verts[link.name] = self.mesh_verts[link.name][:, [0, 2, 1]]
180 | self.mesh_verts[link.name][:, 1] *= -1
181 | self.mesh_verts[link.name] = (
182 | np.matmul(rotation, self.mesh_verts[link.name].T).T + translation
183 | )
184 | self.mesh_faces[link.name] = np.array(mesh.faces)
185 |
186 | # contact point
187 | if link.name in self.contact_point_dict:
188 | cpb = np.array(self.contact_point_dict[link.name])
189 | if len(cpb.shape) > 1:
190 | cpb = cpb[np.random.randint(cpb.shape[0], size=1)][0]
191 |
192 | cp_basis = mesh.vertices[cpb] * scale
193 | if robot_name == 'shadowhand':
194 | cp_basis = cp_basis[:, [0, 2, 1]]
195 | cp_basis[:, 1] *= -1
196 | cp_basis = np.matmul(rotation, cp_basis.T).T + translation
197 | cp_basis = torch.cat(
198 | [
199 | torch.from_numpy(cp_basis).to(device).float(),
200 | torch.ones([4, 1]).to(device).float(),
201 | ],
202 | dim=-1,
203 | )
204 | self.contact_point_basis[link.name] = cp_basis.unsqueeze(0).repeat(
205 | batch_size, 1, 1
206 | )
207 | v1 = cp_basis[1, :3] - cp_basis[0, :3]
208 | v2 = cp_basis[2, :3] - cp_basis[0, :3]
209 | v1 = v1 / torch.norm(v1)
210 | v2 = v2 / torch.norm(v2)
211 | self.contact_normals[link.name] = torch.cross(v1, v2).view([1, 3])
212 | self.contact_normals[link.name] = (
213 | self.contact_normals[link.name]
214 | .unsqueeze(0)
215 | .repeat(batch_size, 1, 1)
216 | )
217 |
218 | self.scale = hand_scale
219 |
220 | # new 2.1
221 | self.revolute_joints = []
222 | for i, _ in enumerate(self.robot_full.joints):
223 | if self.robot_full.joints[i].joint_type == 'revolute':
224 | self.revolute_joints.append(self.robot_full.joints[i])
225 | self.revolute_joints_q_mid = []
226 | self.revolute_joints_q_var = []
227 | self.revolute_joints_q_upper = []
228 | self.revolute_joints_q_lower = []
229 | for i, _ in enumerate(self.robot.get_joint_parameter_names()):
230 | for j, _ in enumerate(self.revolute_joints):
231 | if (
232 | self.revolute_joints[j].name
233 | == self.robot.get_joint_parameter_names()[i]
234 | ):
235 | joint = self.revolute_joints[j]
236 | assert joint.name == self.robot.get_joint_parameter_names()[i]
237 | self.revolute_joints_q_mid.append(
238 | (joint.limit.lower + joint.limit.upper) / 2
239 | )
240 | self.revolute_joints_q_var.append(
241 | ((joint.limit.upper - joint.limit.lower) / 2) ** 2
242 | )
243 | self.revolute_joints_q_lower.append(joint.limit.lower)
244 | self.revolute_joints_q_upper.append(joint.limit.upper)
245 |
246 | joint_lower = np.array(self.revolute_joints_q_lower)
247 | joint_upper = np.array(self.revolute_joints_q_upper)
248 | joint_mid = (joint_lower + joint_upper) / 2
249 | joints_q = (joint_mid + joint_lower) / 2
250 | self.rest_pose = (
251 | torch.from_numpy(
252 | np.concatenate([np.array([0, 0, 0, 1, 0, 0, 0, 1, 0]), joints_q])
253 | )
254 | .unsqueeze(0)
255 | .to(device)
256 | .float()
257 | )
258 |
259 | self.revolute_joints_q_lower = (
260 | torch.Tensor(self.revolute_joints_q_lower)
261 | .repeat([self.batch_size, 1])
262 | .to(device)
263 | )
264 | self.revolute_joints_q_upper = (
265 | torch.Tensor(self.revolute_joints_q_upper)
266 | .repeat([self.batch_size, 1])
267 | .to(device)
268 | )
269 |
270 | self.rest_pose = self.rest_pose.repeat([self.batch_size, 1])
271 |
272 | self.current_status = None
273 | self.canonical_keypoints = self.get_canonical_keypoints().to(device)
274 |
275 | def update_kinematics(self, q):
276 | self.global_translation = q[:, :3]
277 |
278 | self.global_rotation = (
279 | math_utils.robust_compute_rotation_matrix_from_ortho6d(q[:, 3:9])
280 | )
281 | self.current_status = self.robot.forward_kinematics(q[:, 9:])
282 |
283 | def get_surface_points(self, q=None, downsample=False):
284 | """Returns surface points on the end-effector on a given pose."""
285 |
286 | if q is not None:
287 | self.update_kinematics(q)
288 | surface_points = []
289 | for link_name in self.surface_points:
290 | if self.robot_name == 'robotiq_3finger' and link_name == 'gripper_palm':
291 | continue
292 | if (
293 | self.robot_name == 'robotiq_3finger_real_robot'
294 | and link_name == 'palm'
295 | ):
296 | continue
297 | trans_matrix = self.current_status[link_name].get_matrix()
298 | surface_points.append(
299 | torch.matmul(
300 | trans_matrix, self.surface_points[link_name].transpose(1, 2)
301 | ).transpose(1, 2)[..., :3]
302 | )
303 | surface_points = torch.cat(surface_points, 1)
304 | surface_points = torch.matmul(
305 | self.global_rotation, surface_points.transpose(1, 2)
306 | ).transpose(1, 2) + self.global_translation.unsqueeze(1)
307 | if downsample:
308 | surface_points = surface_points[
309 | :, torch.randperm(surface_points.shape[1])
310 | ][:, :1000]
311 | return surface_points * self.scale
312 |
313 | def get_canonical_keypoints(self):
314 | """Returns canonical keypoints aka the N user-selected keypoints."""
315 |
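# The keypoints in robot_keypoints.json are specified in the rest pose.
# They are mapped back into each link's local frame (by undoing the global
# transform and the link's forward-kinematics transform) so that they can
# later be re-posed under an arbitrary joint configuration q.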
316 | self.update_kinematics(self.rest_pose)
317 | key_points = np.array([
318 | np.array(keypoint[str(i)][0])
319 | for i, keypoint in enumerate(self.keypoints)
320 | ])
321 | key_points = torch.tensor(key_points).unsqueeze(0).float().to(self.device)
322 | key_points = key_points.repeat(self.batch_size, 1, 1)
323 | key_points -= self.global_translation.unsqueeze(1)
324 | key_points = torch.matmul(
325 | torch.inverse(self.global_rotation), key_points.transpose(1, 2)
326 | ).transpose(1, 2)
327 | new_key_points = []
328 |
329 | for i, keypoint in enumerate(self.keypoints):
330 | curr_keypoint = key_points[0, i, :]
331 | curr_keypoint_link_name = keypoint[str(i)][1]
332 | curr_keypoint = torch.cat(
333 | (curr_keypoint.clone().detach(), torch.tensor([1.0]).to(self.device))
334 | ).float()
335 |
336 | trans_matrix = self.current_status[curr_keypoint_link_name].get_matrix()
337 | # Address batch size if present.
338 | if trans_matrix.shape[0] != self.batch_size:
339 | trans_matrix = trans_matrix.repeat(self.batch_size, 1, 1)
340 |
341 | self.key_point_idx_dict[curr_keypoint_link_name] = []
342 | new_key_points.append(
343 | torch.matmul(
344 | torch.inverse(trans_matrix),
345 | curr_keypoint[None, None].transpose(1, 2),
346 | ).transpose(1, 2)[..., :3]
347 | )
348 | new_key_points = torch.cat(new_key_points, 1)
349 |
350 | return new_key_points
351 |
352 | def get_static_key_points(self, q, surface_pt_sample=None):
353 | """Returns the canonical keypoints when in a given end-effector pose."""
354 |
355 | final_key_points = []
356 | final_key_points_idx = []
357 |
358 | self.update_kinematics(q)
359 |
360 | for i in range(self.canonical_keypoints.shape[1]):
361 | curr_keypoint = self.canonical_keypoints[0, i, :]
362 | curr_keypoint_link_name = self.keypoints[i][str(i)][1]
363 | curr_keypoint = torch.cat(
364 | (curr_keypoint.clone().detach(), torch.tensor([1.0]).to(self.device))
365 | ).float()
366 |
367 | if surface_pt_sample is not None:
368 | self.key_point_idx_dict[curr_keypoint_link_name] = []
369 |
370 | trans_matrix = self.current_status[curr_keypoint_link_name].get_matrix()
371 | final_key_points.append(
372 | torch.matmul(
373 | trans_matrix, curr_keypoint[None, None].transpose(1, 2)
374 | ).transpose(1, 2)[..., :3]
375 | )
376 |
377 | final_key_points = torch.cat(final_key_points, 1)
378 | final_key_points = torch.matmul(
379 | self.global_rotation, final_key_points.transpose(1, 2)
380 | ).transpose(1, 2) + self.global_translation.unsqueeze(1)
381 | final_key_points = (final_key_points * self.scale).squeeze(0)
382 |
383 | if surface_pt_sample is not None:
384 | for i, final_kp in enumerate(final_key_points):
385 | curr_keypoint_link_name = self.keypoints[i][str(i)][1]
386 | closest_vert_idx = np.argsort(
387 | np.linalg.norm(
388 | np.array(self.mesh_verts[curr_keypoint_link_name])
389 | - np.array(final_kp).reshape((1, 3)),
390 | axis=-1,
391 | )
392 | )[0]
393 | self.key_point_idx_dict[curr_keypoint_link_name].append(
394 | torch.tensor(closest_vert_idx)
395 | )
396 | closest_surface_sample_idx = np.argsort(
397 | np.linalg.norm(
398 | np.array(surface_pt_sample)
399 | - np.array(final_kp).reshape((1, 3)),
400 | axis=-1,
401 | )
402 | )[0]
403 | final_key_points_idx.append(closest_surface_sample_idx)
404 |
405 | for k in self.key_point_idx_dict:
406 | self.key_point_idx_dict[k] = torch.tensor(self.key_point_idx_dict[k])
407 |
408 | return (
409 | final_key_points,
410 | self.key_point_idx_dict,
411 | torch.Tensor(final_key_points_idx),
412 | )
413 |
414 | def get_key_points_from_indices(self, key_point_idx_dict, q=None):
415 | """Returns keypoints from a set of indices when in a given pose."""
416 |
417 | if q is not None:
418 | self.update_kinematics(q)
419 |
420 | key_points = []
421 | for link_name in key_point_idx_dict:
422 | trans_matrix = self.current_status[link_name].get_matrix()
423 | pts = np.concatenate(
424 | [
425 | self.mesh_verts[link_name],
426 | np.ones([len(self.mesh_verts[link_name]), 1]),
427 | ],
428 | axis=-1,
429 | )
430 | surface_points = (
431 | torch.from_numpy(pts).float().unsqueeze(0).repeat(1, 1, 1)
432 | )
433 |
434 | key_point_idx = key_point_idx_dict[link_name]
435 | key_points.append(
436 | torch.matmul(
437 | trans_matrix,
438 | surface_points[:, key_point_idx.long(), :].transpose(1, 2),
439 | ).transpose(1, 2)
440 | )
441 |
442 | key_points = torch.cat(key_points, dim=1)[..., :3]
443 | key_points = torch.matmul(
444 | self.global_rotation, key_points.transpose(1, 2)
445 | ).transpose(1, 2) + self.global_translation.unsqueeze(1)
446 | return (key_points * self.scale).squeeze(0)
447 |
448 | def get_meshes_from_q(self, q=None, i=0):
449 | """Returns gripper meshes in a given pose."""
450 |
451 | data = []
452 | if q is not None:
453 | self.update_kinematics(q)
454 | for _, link_name in enumerate(self.mesh_verts):
455 | trans_matrix = self.current_status[link_name].get_matrix()
456 | trans_matrix = (
457 | trans_matrix[min(len(trans_matrix) - 1, i)].detach().cpu().numpy()
458 | )
459 | v = self.mesh_verts[link_name]
460 | transformed_v = np.concatenate([v, np.ones([len(v), 1])], axis=-1)
461 | transformed_v = np.matmul(trans_matrix, transformed_v.T).T[..., :3]
462 | transformed_v = np.matmul(
463 | self.global_rotation[i].detach().cpu().numpy(), transformed_v.T
464 | ).T + np.expand_dims(self.global_translation[i].detach().cpu().numpy(), 0)
465 | transformed_v = transformed_v * self.scale
466 | f = self.mesh_faces[link_name]
467 | data.append(tm.Trimesh(vertices=transformed_v, faces=f))
468 | return data
469 |
470 | def get_plotly_data(self, q=None, i=0, color='lightblue', opacity=1.0):
471 | """Returns plot data for the gripper in a given pose."""
472 |
473 | data = []
474 | if q is not None:
475 | self.update_kinematics(q)
476 | for _, link_name in enumerate(self.mesh_verts):
477 | trans_matrix = self.current_status[link_name].get_matrix()
478 | trans_matrix = (
479 | trans_matrix[min(len(trans_matrix) - 1, i)].detach().cpu().numpy()
480 | )
481 | v = self.mesh_verts[link_name]
482 | transformed_v = np.concatenate([v, np.ones([len(v), 1])], axis=-1)
483 | transformed_v = np.matmul(trans_matrix, transformed_v.T).T[..., :3]
484 | transformed_v = np.matmul(
485 | self.global_rotation[i].detach().cpu().numpy(), transformed_v.T
486 | ).T + np.expand_dims(self.global_translation[i].detach().cpu().numpy(), 0)
487 | transformed_v = transformed_v * self.scale
488 | f = self.mesh_faces[link_name]
489 | data.append(
490 | go.Mesh3d(
491 | x=transformed_v[:, 0],
492 | y=transformed_v[:, 1],
493 | z=transformed_v[:, 2],
494 | i=f[:, 0],
495 | j=f[:, 1],
496 | k=f[:, 2],
497 | color=color,
498 | opacity=opacity,
499 | )
500 | )
501 | return data
502 |
--------------------------------------------------------------------------------
/utils/math_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Math utilities for rotations and vector algebra."""
17 |
18 | import numpy as np
19 | import torch
20 | import transforms3d
21 |
22 |
23 | def get_rot6d_from_rot3d(rot3d):
24 | global_rotation = np.array(
25 | transforms3d.euler.euler2mat(rot3d[0], rot3d[1], rot3d[2])
26 | )
27 | return global_rotation.T.reshape(9)[:6]
28 |
29 |
30 | def robust_compute_rotation_matrix_from_ortho6d(poses):
31 | """TODO(jmattarian): Code from: XXXXXXX."""
32 |
33 | x_raw = poses[:, 0:3]
34 | y_raw = poses[:, 3:6]
35 |
36 | x = normalize_vector(x_raw)
37 | y = normalize_vector(y_raw)
38 | middle = normalize_vector(x + y)
39 | orthmid = normalize_vector(x - y)
40 | x = normalize_vector(middle + orthmid)
41 | y = normalize_vector(middle - orthmid)
42 | z = normalize_vector(cross_product(x, y))
43 |
44 | x = x.view(-1, 3, 1)
45 | y = y.view(-1, 3, 1)
46 | z = z.view(-1, 3, 1)
47 | matrix = torch.cat((x, y, z), 2)
48 | return matrix
49 |
50 |
51 | def normalize_vector(v):
52 | batch = v.shape[0]
53 | v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
54 | v_mag = torch.max(v_mag, v.new([1e-8]))
55 | v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
56 | v = v / v_mag
57 | return v
58 |
59 |
60 | def cross_product(u, v):
61 | batch = u.shape[0]
62 | i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
63 | j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
64 | k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
65 |
66 | out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)
67 |
68 | return out
69 |
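
# A minimal, self-contained sanity check (illustrative only): it builds a 6D
# rotation representation from Euler angles and verifies that the recovered
# matrix is orthonormal (R^T R ~= I). The angles below are arbitrary.
if __name__ == '__main__':
  rot6d = torch.tensor(get_rot6d_from_rot3d(np.array([0.1, -0.3, 0.7])))[None]
  rot_mat = robust_compute_rotation_matrix_from_ortho6d(rot6d.float())
  print(torch.allclose(rot_mat[0].T @ rot_mat[0], torch.eye(3), atol=1e-5))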
--------------------------------------------------------------------------------
/utils_data/augmentors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Augmentation classes for point clouds."""
17 |
18 | import numpy as np
19 | import torch
20 |
21 |
22 | def angle_axis(angle: float, axis: np.ndarray):
23 | """Returns a 4x4 rotation matrix that performs a rotation around axis by angle."""
24 |
25 | u = axis / np.linalg.norm(axis)
26 | cosval, sinval = np.cos(angle), np.sin(angle)
27 |
28 | cross_prod_mat = np.array(
29 | [[0.0, -u[2], u[1]], [u[2], 0.0, -u[0]], [-u[1], u[0], 0.0]]
30 | )
31 |
32 | r = torch.from_numpy(
33 | cosval * np.eye(3)
34 | + sinval * cross_prod_mat
35 | + (1.0 - cosval) * np.outer(u, u)
36 | )
37 | return r.float()
38 |
39 |
40 | class PointcloudJitter(object):
41 | """Adds jitter with a given std to a point cloud."""
42 |
43 | def __init__(self, std=0.005, clip=0.025):
44 | self.std, self.clip = std, clip
45 |
46 | def __call__(self, points):
47 | jittered_data = (
48 | points.new(points.size(0), 3)
49 | .normal_(mean=0.0, std=self.std)
50 | .clamp_(-self.clip, self.clip)
51 | )
52 | points[:, 0:3] += jittered_data
53 | return points
54 |
55 |
56 | class PointcloudScale(object):
57 |
58 | def __init__(self, lo=0.8, hi=1.25):
59 | self.lo, self.hi = lo, hi
60 |
61 | def __call__(self, points):
62 | scaler = np.random.uniform(self.lo, self.hi)
63 | points[:, 0:3] *= scaler
64 | return points
65 |
66 |
67 | class PointcloudTranslate(object):
68 |
69 | def __init__(self, translate_range=0.1):
70 | self.translate_range = translate_range
71 |
72 | def __call__(self, points):
73 | translation = np.random.uniform(-self.translate_range, self.translate_range)
74 | points[:, 0:3] += translation
75 | return points
76 |
77 |
78 | class PointcloudRotatePerturbation(object):
79 | """Applies a random rotation to a point cloud."""
80 |
81 | def __init__(self, angle_sigma=0.06, angle_clip=0.18):
82 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip
83 |
84 | def _get_angles(self):
85 | angles = np.clip(
86 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip
87 | )
88 |
89 | return angles
90 |
91 | def __call__(self, points):
92 | angles = self._get_angles()
93 | rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0]))
94 | ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0]))
95 | rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0]))
96 |
97 | rotation_matrix = torch.matmul(torch.matmul(rz, ry), rx)
98 |
99 | normals = points.size(1) > 3
100 | if not normals:
101 | return torch.matmul(points, rotation_matrix.t())
102 | else:
103 | pc_xyz = points[:, 0:3]
104 | pc_normals = points[:, 3:]
105 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
106 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())
107 |
108 | return points
109 |
--------------------------------------------------------------------------------
/utils_data/gnn_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implements Pytorch dataloader class needed for GeoMatch."""
17 |
18 | import json
19 | import os
20 | import torch
21 | from torch.utils import data
22 | from utils_data import augmentors
23 |
24 |
25 | class GNNDataset(data.Dataset):
26 | """Dataloader class for GeoMatch."""
27 |
28 | def __init__(
29 | self,
30 | dataset_basedir,
31 | object_npts=2048,
32 | device='cuda' if torch.cuda.is_available() else 'cpu',
33 | mode='train',
34 | robot_name_list=None,
35 | ):
36 | self.device = device
37 | self.dataset_basedir = dataset_basedir
38 | self.object_npts = object_npts
39 |
40 | if not robot_name_list:
41 | self.robot_name_list = [
42 | 'ezgripper',
43 | 'barrett',
44 | 'robotiq_3finger',
45 | 'allegro',
46 | 'shadowhand',
47 | ]
48 | else:
49 | self.robot_name_list = robot_name_list
50 |
51 | print('loading object point clouds and adjacency matrices....')
52 | self.object_pc_adj = torch.load(
53 | os.path.join(dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt')
54 | )
55 |
56 | print('loading robot point clouds and adjacency matrices...')
57 | self.robot_pc_adj = torch.load(
58 | os.path.join(dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt')
59 | )
60 |
61 | print('loading object/robot cmaps....')
62 | cmap_dataset = torch.load(
63 | os.path.join(dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt')
64 | )['metadata']
65 |
66 | self.metadata = cmap_dataset
67 |
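# Each metadata entry is a positionally indexed tuple; as used in the split
# filters below and in __getitem__, index 6 holds the object name and
# index 7 the end-effector name.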
68 | if mode == 'train':
69 | self.object_list = json.load(
70 | open(
71 | os.path.join(
72 | dataset_basedir,
73 | 'CMapDataset-sqrt_align/split_train_validate_objects.json',
74 | ),
75 | 'rb',
76 | )
77 | )[mode]
78 | self.metadata = [
79 | t
80 | for t in self.metadata
81 | if t[6] in self.object_list and t[7] in self.robot_name_list
82 | ]
83 | elif mode == 'validate':
84 | self.object_list = json.load(
85 | open(
86 | os.path.join(
87 | dataset_basedir,
88 | 'CMapDataset-sqrt_align/split_train_validate_objects.json',
89 | ),
90 | 'rb',
91 | )
92 | )[mode]
93 | self.metadata = [
94 | t
95 | for t in self.metadata
96 | if t[6] in self.object_list and t[7] in self.robot_name_list
97 | ]
98 | elif mode == 'full':
99 | self.object_list = (
100 | json.load(
101 | open(
102 | os.path.join(
103 | dataset_basedir,
104 | 'CMapDataset-sqrt_align/split_train_validate_objects.json',
105 | ),
106 | 'rb',
107 | )
108 | )['train']
109 | + json.load(
110 | open(
111 | os.path.join(
112 | dataset_basedir, 'CMapDataset-sqrt_align/split_train_validate_objects.json'
113 | ),
114 | 'rb',
115 | )
116 | )['validate']
117 | )
118 | self.metadata = [
119 | t
120 | for t in self.metadata
121 | if t[6] in self.object_list and t[7] in self.robot_name_list
122 | ]
123 | else:
124 | raise NotImplementedError()
125 | print(f'object selection: {self.object_list}')
126 |
127 | self.mode = mode
128 |
129 | self.datasize = len(self.metadata)
130 | print('Finished loading dataset.')
131 |
132 | self.rotate = augmentors.PointcloudRotatePerturbation(
133 | angle_sigma=0.03, angle_clip=0.1
134 | )
135 | self.translate = augmentors.PointcloudTranslate(translate_range=0.01)
136 | self.jitter = augmentors.PointcloudJitter(std=0.04, clip=0.1)
137 |
138 | def __len__(self):
139 | return self.datasize
140 |
141 | def __getitem__(self, item):
142 | object_name = self.metadata[item][6]
143 | robot_name = self.metadata[item][7]
144 |
145 | obj_adj = self.object_pc_adj[object_name][0]
146 | obj_contacts = self.metadata[item][0]
147 | obj_features = self.object_pc_adj[object_name][1]
148 |
149 | if self.mode in ['train']:
150 | obj_features = self.rotate(obj_features)
151 |
152 | robot_adj = self.robot_pc_adj[robot_name][0]
153 | robot_features = self.robot_pc_adj[robot_name][1]
154 | robot_key_point_idx = self.robot_pc_adj[robot_name][3]
155 | assert robot_key_point_idx.shape[0] == 6
156 | robot_contacts = self.metadata[item][1]
157 | top_obj_contact_kps = self.metadata[item][2]
158 | assert top_obj_contact_kps.shape[0] == 6
159 | top_obj_contact_verts = self.metadata[item][3]
160 | assert top_obj_contact_verts.shape[0] == 6
161 | full_obj_contact_map = self.metadata[item][4]
162 |
163 | if self.mode in ['train']:
164 | robot_features = self.rotate(robot_features)
165 |
166 | return (
167 | obj_adj,
168 | obj_features,
169 | obj_contacts,
170 | robot_adj,
171 | robot_features,
172 | robot_key_point_idx,
173 | robot_contacts,
174 | top_obj_contact_kps,
175 | top_obj_contact_verts,
176 | full_obj_contact_map,
177 | object_name,
178 | robot_name,
179 | )
180 |
--------------------------------------------------------------------------------
/visualize_geomatch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Visualization helped for GeoMatch predictions."""
17 |
18 | import argparse
19 | import itertools
20 | import json
21 | import os
22 | import random
23 | import config
24 | from models.geomatch import GeoMatch
25 | from plotly.subplots import make_subplots
26 | import torch
27 | from torch import nn
28 | from utils.general_utils import get_handmodel
29 | from utils.gnn_utils import plot_point_cloud
30 |
31 |
32 | def return_random_obj_ee_pair(
33 | obj_name, rbt_name, obj_data, rbt_data, contact_map_data_gt
34 | ):
35 | """Returns a random object-gripper pair to use for visualization."""
36 |
37 | obj_adj = obj_data[obj_name][0]
38 | obj_pc = obj_data[obj_name][1]
39 | robot_adj = rbt_data[rbt_name][0]
40 | robot_pc = rbt_data[rbt_name][1]
41 | rest_pose = rbt_data[rbt_name][2]
42 | keypoints_idx = rbt_data[rbt_name][3]
43 | keypoints_idx_dict = rbt_data[rbt_name][4]
44 |
45 | # Collect every grasp example for this object/gripper pair so that
46 | # random.choice below can genuinely sample among them; raise if the
47 | # requested combination does not exist in the ground-truth metadata.
48 | idx_list = []
49 | for i, data in enumerate(contact_map_data_gt['metadata']):
50 | if data[6] == obj_name and data[7] == rbt_name:
51 | idx_list.append(i)
52 |
53 | if not idx_list:
54 | raise ValueError('Did not find a matching combination, try again!')
55 |
56 | rand_idx = random.choice(idx_list)
57 | obj_cmap = contact_map_data_gt['metadata'][rand_idx][0]
58 | robot_cmap = contact_map_data_gt['metadata'][rand_idx][1]
59 | top_obj_contact_kps = contact_map_data_gt['metadata'][rand_idx][2]
60 | top_obj_contact_verts = contact_map_data_gt['metadata'][rand_idx][3]
61 | q = contact_map_data_gt['metadata'][rand_idx][4]
62 |
63 | return (
64 | obj_adj,
65 | obj_pc,
66 | robot_adj,
67 | robot_pc,
68 | rest_pose,
69 | keypoints_idx,
70 | obj_cmap,
71 | robot_cmap,
72 | top_obj_contact_kps,
73 | top_obj_contact_verts,
74 | q,
75 | keypoints_idx_dict,
76 | )
77 |
78 |
79 | def autoregressive_inference(
80 | contact_map_pred, match_model, obj_pc, robot_embed, obj_embed, top_k=0
81 | ):
82 | """Performs the autoregressive inference of GeoMatch."""
83 |
84 | with torch.no_grad():
85 | max_per_kp = torch.topk(contact_map_pred, k=3, dim=1)
86 | pred_curr = None
87 | grasp_points = []
88 | contact_or_not = []
89 |
90 | obj_proj_embed = match_model.obj_proj(obj_embed)
91 | robot_proj_embed = match_model.robot_proj(robot_embed)
92 |
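# Keypoints are selected sequentially: the first comes from the top-k
# entries of the dense contact-map prediction, and each autoregressive
# head then scores the next keypoint conditioned on the grasp points
# chosen so far, snapping the choice to the most likely object point.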
93 | for i_prev in range(config.keypoint_n - 1):
94 | model_kp = match_model.kp_ar_model_1
95 |
96 | if i_prev == 1:
97 | model_kp = match_model.kp_ar_model_2
98 | elif i_prev == 2:
99 | model_kp = match_model.kp_ar_model_3
100 | elif i_prev == 3:
101 | model_kp = match_model.kp_ar_model_4
102 | elif i_prev == 4:
103 | model_kp = match_model.kp_ar_model_5
104 |
105 | xyz_prev = torch.gather(
106 | obj_pc[None],
107 | 1,
108 | max_per_kp.indices[:, top_k, i_prev, :].repeat(1, 1, 3),
109 | )
110 |
111 | if i_prev == 0:
112 | grasp_points.append(xyz_prev.squeeze())
113 | contact_or_not.append(torch.tensor(1))
114 | else:
115 | xyz_prev = torch.stack(grasp_points, dim=0)[None]
116 |
117 | pred_curr = model_kp(
118 | obj_proj_embed, obj_pc[None], robot_proj_embed, xyz_prev
119 | )
120 | pred_prob = nn.Sigmoid()(pred_curr)
121 | vert_pred = torch.max(pred_prob[..., 0], dim=-1)
122 | min_idx = vert_pred.indices[0]
123 | contact_or_not.append(torch.tensor(int(vert_pred.values[0] >= 0.5)))
124 |
125 | # Projected on object
126 | pred_curr = obj_pc[min_idx]
127 | grasp_points.append(pred_curr)
128 |
129 | grasp_points = torch.stack(grasp_points, dim=0)
130 | contact_or_not = torch.stack(contact_or_not, dim=0)
131 |
132 | return torch.cat((grasp_points, contact_or_not[..., None]), dim=-1)
133 |
134 |
135 | def plot_side_by_side(
136 | point_cloud,
137 | contact_map,
138 | hand_data,
139 | i_keypoint,
140 | save_dir,
141 | save_plot,
142 | gt_contact_map=None,
143 | pred_points=None,
144 | top_obj_contact_kps=None,
145 | ):
146 | """Side-by-side plots of the object point cloud with the predicted keypoints, gripper with the canonical keypoints and a GT sample for comparison.
147 |
148 | Each keypoint will generate a new plot. The current keypoint is depicted with
149 | a different color.
150 |
151 | Args:
152 | point_cloud: the object point cloud
153 | contact_map: the predicted contact map
154 | hand_data: gripper data to plot - mesh, point cloud etc.
155 | i_keypoint: i-th keypoint to plot data for
156 | save_dir: directory to save plots in
157 | save_plot: bool, whether to save plots
158 | gt_contact_map: ground truth contact map for comparison
159 | pred_points: predicted keypoints to plot
160 | top_obj_contact_kps: ground-truth contact points to plot
161 | """
162 | fig = make_subplots(
163 | rows=1,
164 | cols=3,
165 | specs=[
166 | [{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]
167 | ],
168 | )
169 | fig.add_trace(
170 | plot_point_cloud(point_cloud, contact_map.squeeze()), row=1, col=1
171 | )
172 |
173 | if pred_points is not None:
174 | pred_points = pred_points.detach().numpy()
175 |
176 | for i in range(pred_points.shape[0]):
177 | c = 'black'
178 | if pred_points[i, 3] == 1.0:
179 | c = 'red'
180 | fig.add_trace(
181 | plot_point_cloud(pred_points[i, :3][None], color=c, size=5),
182 | row=1,
183 | col=1,
184 | )
185 | fig.add_trace(
186 | plot_point_cloud(
187 | pred_points[i_keypoint][None], color='magenta', size=5
188 | ),
189 | row=1,
190 | col=1,
191 | )
192 |
193 | for d in hand_data:
194 | fig.add_trace(d, row=1, col=2)
195 |
196 | if gt_contact_map is not None:
197 | fig.add_trace(
198 | plot_point_cloud(point_cloud, gt_contact_map.squeeze()), row=1, col=3
199 | )
200 |
201 | if top_obj_contact_kps is not None:
202 | fig.add_trace(
203 | plot_point_cloud(
204 | top_obj_contact_kps[i_keypoint, :][None], color='red'
205 | ),
206 | row=1,
207 | col=3,
208 | )
209 |
210 | fig.update_layout(
211 | height=800,
212 | width=1800,
213 | title_text=(
214 | f'Prediction on keypoint {i_keypoint} and one GT grasp for'
215 | ' comparison.'
216 | ),
217 | )
218 | if not save_plot:
219 | fig.show()
220 | else:
221 | fig.write_image(
222 | os.path.join(save_dir, f'prediction+rand_gt_keypoint_{i_keypoint}'),
223 | 'jpg',
224 | scale=2,
225 | )
226 |
227 |
228 | def plot_predicted_keypoints(
229 | obj_name,
230 | rbt_name,
231 | obj_data,
232 | rbt_data,
233 | contact_map_data_gt,
234 | geomatch_model,
235 | save_dir,
236 | save_plot,
237 | data_basedir,
238 | top_k=0,
239 | ):
240 | """Generates plots for a predicted grasp for a given object-gripper pair."""
241 | (
242 | obj_adj,
243 | obj_pc,
244 | robot_adj,
245 | robot_pc,
246 | rest_pose,
247 | keypoints_idx,
248 | obj_cmap,
249 | robot_cmap,
250 | top_obj_contact_kps,
251 | _,
252 | _,
253 | _,
254 | ) = return_random_obj_ee_pair(
255 | obj_name, rbt_name, obj_data, rbt_data, contact_map_data_gt
256 | )
257 |
258 | with torch.no_grad():
259 | obj_embed = geomatch_model.encode_embed(
260 | geomatch_model.obj_encoder, obj_pc[None], obj_adj[None]
261 | )
262 | robot_embed = geomatch_model.encode_embed(
263 | geomatch_model.robot_encoder, robot_pc[None], robot_adj[None]
264 | )
265 |
266 | robot_feat_size = robot_embed.shape[2]
267 | keypoint_feat = torch.gather(
268 | robot_embed,
269 | 1,
270 | keypoints_idx[..., None].long().repeat(1, 1, robot_feat_size),
271 | )
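# The dense contact-map logits are the inner products between every object
# point embedding and each gathered keypoint embedding, i.e. one
# (obj_pc_n, keypoint_n) score map with a trailing singleton channel.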
272 | contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[
273 | ..., None
274 | ]
275 | gt_contact_map = (
276 | (obj_cmap * robot_cmap.repeat(1, config.obj_pc_n))
277 | .transpose(1, 0)[..., None]
278 | .contiguous()
279 | )
280 |
281 | top_obj_contact_kps_pred = autoregressive_inference(
282 | contact_map_pred, geomatch_model, obj_pc, robot_embed, obj_embed, top_k
283 | )
284 | pred_points = top_obj_contact_kps_pred
285 |
286 | hand_model = get_handmodel(rbt_name, 1, 'cpu', 1.0, data_dir=data_basedir)
287 | print('PREDICTION: ', pred_points)
288 |
289 | for i in range(contact_map_pred.shape[2]):
290 | gt_contact_map_i = gt_contact_map[:, i, :]
291 |
292 | obj_kp_cmap = contact_map_pred[:, :, i, :]
293 | obj_kp_cmap_labels = torch.nn.Sigmoid()(obj_kp_cmap)
294 |
295 | selected_kp = robot_pc[keypoints_idx[i].long(), :][None]
296 | vis_data = hand_model.get_plotly_data(q=rest_pose, opacity=0.5)
297 | vis_data += [
298 | plot_point_cloud(robot_pc[keypoints_idx.long(), :].cpu(), color='black')
299 | ]
300 | vis_data += [plot_point_cloud(selected_kp.cpu(), color='red')]
301 |
302 | plot_side_by_side(
303 | obj_pc,
304 | obj_kp_cmap_labels.detach().numpy(),
305 | vis_data,
306 | i,
307 | save_dir,
308 | save_plot,
309 | gt_contact_map_i,
310 | pred_points,
311 | top_obj_contact_kps,
312 | )
313 |
314 |
315 | if __name__ == '__main__':
316 | parser = argparse.ArgumentParser()
317 | parser.add_argument('--object_name', type=str, default='')
318 | parser.add_argument('--robot_name', type=str, default='')
319 | parser.add_argument('--random_example', default=True, action='store_true')
320 | parser.add_argument('--dataset_dir', type=str, default='/data/grasp_gnn')
321 | parser.add_argument('--save_plots', default=False, action='store_true')
322 | parser.add_argument('--top_k_idx', type=int, default=0)
323 | parser.add_argument(
324 | '--saved_model_dir',
325 | type=str,
326 | default='logs_train/exp-pos_weight_500_200_6_kps-1683055568.7644374/',
327 | )
328 | args = parser.parse_args()
329 |
330 | dataset_basedir = args.dataset_dir
331 | saved_model_dir = args.saved_model_dir
332 | top_k_idx = args.top_k_idx
333 |
334 |
335 | save_plot_dir = os.path.join(saved_model_dir, 'plots')
336 |
337 | if not os.path.exists(save_plot_dir):
338 | os.mkdir(save_plot_dir)
339 |
340 | device = 'cpu'
341 | object_data = torch.load(
342 | os.path.join(dataset_basedir, 'gnn_obj_adj_point_clouds_new.pt')
343 | )
344 | robot_data = torch.load(
345 | os.path.join(dataset_basedir, 'gnn_robot_adj_point_clouds_new.pt')
346 | )
347 | cmap_data_gt = torch.load(
348 | os.path.join(dataset_basedir, 'gnn_obj_cmap_robot_cmap_adj_new.pt'),
349 | map_location=torch.device('cpu'),
350 | )
351 |
352 | eval_object_list = json.load(
353 | open(
354 | os.path.join(
355 | dataset_basedir,
356 | 'CMapDataset-sqrt_align/split_train_validate_objects.json',
357 | ),
358 | 'rb',
359 | )
360 | )['validate']
361 | robot_name_list = [
362 | 'ezgripper',
363 | 'barrett',
364 | 'robotiq_3finger',
365 | 'allegro',
366 | 'shadowhand',
367 | ]
368 |
369 | obj_robot_pairs = list(itertools.product(eval_object_list, robot_name_list))
370 |
371 | model = GeoMatch(config)
372 | model.load_state_dict(
373 | torch.load(
374 | os.path.join(saved_model_dir, 'weights/grasp_gnn.pth'),
375 | map_location=torch.device('cpu'),
376 | )
377 | )
378 |
379 | model.eval()
380 |
381 | object_name = args.object_name
382 | robot_name = args.robot_name
383 |
384 | plot_predicted_keypoints(
385 | object_name,
386 | robot_name,
387 | object_data,
388 | robot_data,
389 | cmap_data_gt,
390 | model,
391 | save_plot_dir,
392 | args.save_plots,
393 | dataset_basedir,
394 | top_k_idx,
395 | )
396 |
--------------------------------------------------------------------------------