├── .gitignore ├── LICENSE.txt ├── README.md ├── configs ├── dataset │ ├── cmap_dataset.yaml │ └── pretrain_dataset.yaml ├── log.yaml ├── model.yaml ├── pretrain.yaml ├── train.yaml └── validate.yaml ├── data_utils ├── CMapDataset.py ├── PretrainDataset.py ├── filter_dataset.py ├── generate_pc.py └── removed_links.json ├── model ├── encoder.py ├── latent_encoder.py ├── mlp.py ├── module.py ├── network.py └── transformer.py ├── pipeline.jpg ├── pretrain.py ├── requirements.txt ├── scripts ├── download_ckpt.sh ├── download_data.sh ├── example_isaac.py ├── example_pretrain.py └── pretrain_order.py ├── train.py ├── utils ├── controller.py ├── func_utils.py ├── hand_model.py ├── mesh_utils.py ├── multilateration.py ├── optimization.py ├── pretrain_utils.py ├── rotation.py ├── se3_transform.py └── vis_utils.py ├── validate.py ├── validation ├── __init__.py ├── asset_info.py ├── isaac_main.py ├── isaac_validator.py └── validate_utils.py └── visualization ├── vis_controller.py ├── vis_dataset.py ├── vis_hand_joint.py ├── vis_hand_link.py ├── vis_optimization.py ├── vis_pretrain.py └── vis_validation.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | .idea/ 3 | logs/ 4 | lightning_logs/ 5 | 6 | ckpt/ 7 | data/ 8 | output/ 9 | outputs/ 10 | tmp/ 11 | validate_output/ 12 | vis_info/ 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhenyu Wei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # $\mathcal{D(R,O)}$ Grasp 2 | 3 | Official Code Repository for **$\mathcal{D(R,O)}$ Grasp: A Unified Representation of Robot and Object Interaction for Cross-Embodiment Dexterous Grasping**. 4 | 5 | [Zhenyu Wei](https://zhenyuwei2003.github.io/)1,2\*, [Zhixuan Xu](https://ariszxxu.github.io/)1\*, [Jingxiang Guo](https://borisguo6.github.io)1, [Yiwen Hou](https://houyiwen.github.io/)1, [Chongkai Gao](https://chongkaigao.com/)1, Zhehao Cai1, Jiayu Luo1, [Lin Shao](https://linsats.github.io/)1 6 | 7 | 1National University of Singapore, 2Shanghai Jiao Tong University 8 | 9 | * denotes equal contribution 10 | 11 |
,, ), ...] 78 | } 79 | """ 80 | 81 | if with_heatmap: 82 | hands = {} 83 | object_pcs = {} 84 | object_normals = {} 85 | 86 | info = {} 87 | metadata = [] 88 | for robot_name in ['allegro', 'barrett', 'ezgripper', 'robotiq_3finger', 'shadowhand']: 89 | num_total = 0 90 | num_upper_object = 0 91 | num_per_object = {} 92 | 93 | metadata_dir = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/{robot_name}') 94 | object_names = os.listdir(metadata_dir) 95 | for file_name in tqdm(sorted(object_names)): 96 | object_name, _, success_num = file_name.rpartition('_') 97 | success_num = int(success_num[:-3]) # remove '.pt' 98 | 99 | num_total += success_num 100 | if success_num > num_upper_object: 101 | num_upper_object = success_num 102 | num_per_object[object_name] = success_num 103 | 104 | q = torch.load(os.path.join(metadata_dir, file_name)) 105 | for q_idx in range(q.shape[0]): 106 | if with_heatmap: # compute heatmap use GenDexGrasp method to keep consistency 107 | if robot_name in hands: 108 | hand = hands[robot_name] 109 | else: 110 | hand = create_hand_model(robot_name) 111 | hands[robot_name] = hand 112 | robot_pc = hand.get_transformed_links_pc(q[q_idx])[:, :3] 113 | 114 | if object_name not in object_pcs: 115 | name = object_name.split('+') 116 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 117 | mesh = trimesh.load_mesh(object_path) 118 | object_pc, face_indices = mesh.sample(2048, return_index=True) 119 | object_pc = torch.tensor(object_pc, dtype=torch.float32) 120 | object_normal = torch.tensor(mesh.face_normals[face_indices], dtype=torch.float32) 121 | object_pcs[object_name] = object_pc 122 | object_normals[object_name] = object_normal 123 | else: 124 | object_pc = object_pcs[object_name] 125 | object_normal = object_normals[object_name] 126 | 127 | n_robot = robot_pc.shape[0] 128 | n_object = object_pc.shape[0] 129 | 130 | robot_pc = robot_pc.unsqueeze(0).repeat(n_object, 1, 1) 131 | object_pc = object_pc.unsqueeze(0).repeat(n_robot, 1, 1).transpose(0, 1) 132 | object_normal = object_normal.unsqueeze(0).repeat(n_robot, 1, 1).transpose(0, 1) 133 | 134 | object_hand_dist = (robot_pc - object_pc).norm(dim=2) 135 | object_hand_align = ((robot_pc - object_pc) * object_normal).sum(dim=2) 136 | object_hand_align /= (object_hand_dist + 1e-5) 137 | 138 | object_hand_align_dist = object_hand_dist * torch.exp(1 - object_hand_align) 139 | contact_dist = torch.sqrt(object_hand_align_dist.min(dim=1)[0]) 140 | contact_value_current = 1 - 2 * (torch.sigmoid(10 * contact_dist) - 0.5) 141 | heapmap = contact_value_current.unsqueeze(-1) 142 | 143 | metadata.append((heapmap, q[q_idx], object_name, robot_name)) 144 | else: 145 | metadata.append((q[q_idx], object_name, robot_name)) 146 | 147 | info[robot_name] = { 148 | 'robot_name': robot_name, 149 | 'num_total': num_total, 150 | 'num_upper_object': num_upper_object, 151 | 'num_per_object': num_per_object 152 | } 153 | 154 | dataset = { 155 | 'info': info, 156 | 'metadata': metadata 157 | } 158 | if with_heatmap: 159 | torch.save(dataset, os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset_heatmap.pt')) 160 | torch.save(object_pcs, os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/object_point_clouds.pt')) 161 | else: 162 | torch.save(dataset, os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt')) 163 | 164 | print("Post process done!") 165 | 166 | 167 | if __name__ == '__main__': 168 | warnings.simplefilter(action='ignore', category=FutureWarning) 169 | 170 | parser = 
argparse.ArgumentParser() 171 | parser.add_argument('--gpu_list', # input format like '--gpu_list 0,1,2,3,4,5,6,7' 172 | default=['0', '1', '2', '3', '4', '5', '6', '7'], 173 | type=lambda string: string.split(',')) 174 | parser.add_argument('--print_info', action='store_true') 175 | parser.add_argument('--post_process', action='store_true') 176 | parser.add_argument('--with_heatmap', action='store_true') 177 | args = parser.parse_args() 178 | 179 | assert not (args.print_info and args.post_process) 180 | if args.print_info: 181 | # dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 182 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 183 | info = torch.load(dataset_path, map_location=torch.device('cpu'))['info'] 184 | if 'cmap_func' in info: 185 | print(f"cmap_func: {info['cmap_func']}") 186 | del info['cmap_func'] 187 | 188 | for robot_name in info.keys(): 189 | print(f"********************************") 190 | print(f"robot_name: {info[robot_name]['robot_name']}") 191 | print(f"num_total: {info[robot_name]['num_total']}") 192 | print(f"num_upper_object: {info[robot_name]['num_upper_object']}") 193 | print(f"num_per_object: {len(info[robot_name]['num_per_object'])}") 194 | for object_name in sorted(info[robot_name]['num_per_object'].keys()): 195 | print(f" {object_name}: {info[robot_name]['num_per_object'][object_name]}") 196 | print(f"********************************") 197 | elif args.post_process: 198 | post_process(args.with_heatmap) 199 | else: 200 | filter_dataset(args.gpu_list) 201 | -------------------------------------------------------------------------------- /data_utils/generate_pc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import time 5 | import viser 6 | import torch 7 | import trimesh 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.hand_model import create_hand_model 13 | 14 | 15 | def generate_object_pc(args): 16 | """ object/{contactdb, ycb}/ .pt: (num_points, 6), point xyz + normal """ 17 | for dataset_type in ['contactdb', 'ycb']: 18 | input_dir = str(os.path.join(ROOT_DIR, args.object_source_path, dataset_type)) 19 | output_dir = str(os.path.join(ROOT_DIR, args.save_path, 'object', dataset_type)) 20 | os.makedirs(output_dir, exist_ok=True) 21 | 22 | for object_name in os.listdir(input_dir): 23 | if not os.path.isdir(os.path.join(input_dir, object_name)): # skip json file 24 | continue 25 | print(f'Processing {dataset_type}/{object_name}...') 26 | mesh_path = os.path.join(input_dir, object_name, f'{object_name}.stl') 27 | mesh = trimesh.load_mesh(mesh_path) 28 | object_pc, face_indices = mesh.sample(args.num_points, return_index=True) 29 | object_pc = torch.tensor(object_pc, dtype=torch.float32) 30 | normals = torch.tensor(mesh.face_normals[face_indices], dtype=torch.float32) 31 | object_pc_normals = torch.cat([object_pc, normals], dim=-1) 32 | torch.save(object_pc_normals, os.path.join(output_dir, f'{object_name}.pt')) 33 | 34 | print("\nGenerating object point cloud finished.") 35 | 36 | 37 | def generate_robot_pc(args): 38 | output_dir = str(os.path.join(ROOT_DIR, args.save_path, 'robot')) 39 | output_path = str(os.path.join(output_dir, f'{args.robot_name}.pt')) 40 | os.makedirs(output_dir, exist_ok=True) 41 | 42 | hand = create_hand_model(args.robot_name, torch.device('cpu'), args.num_points) 43 | links_pc = hand.vertices 44 | 
sampled_pc, sampled_pc_index = hand.get_sampled_pc(num_points=args.num_points) 45 | 46 | filtered_links_pc = {} 47 | for link_index, (link_name, points) in enumerate(links_pc.items()): 48 | mask = [i % args.num_points for i in sampled_pc_index 49 | if link_index * args.num_points <= i < (link_index + 1) * args.num_points] 50 | links_pc[link_name] = torch.tensor(points, dtype=torch.float32) 51 | filtered_links_pc[link_name] = torch.tensor(points[mask], dtype=torch.float32) 52 | print(f"[{link_name}] original shape: {links_pc[link_name].shape}, filtered shape: {filtered_links_pc[link_name].shape}") 53 | 54 | data = { 55 | 'original': links_pc, 56 | 'filtered': filtered_links_pc 57 | } 58 | torch.save(data, output_path) 59 | print("\nGenerating robot point cloud finished.") 60 | 61 | server = viser.ViserServer(host='127.0.0.1', port=8080) 62 | server.scene.add_point_cloud( 63 | 'point cloud', 64 | sampled_pc[:, :3].numpy(), 65 | point_size=0.001, 66 | point_shape="circle", 67 | colors=(0, 0, 200) 68 | ) 69 | while True: 70 | time.sleep(1) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--type', default='robot', type=str) 76 | parser.add_argument('--save_path', default='data/PointCloud/', type=str) 77 | parser.add_argument('--num_points', default=512, type=int) 78 | # for object pc generation 79 | parser.add_argument('--object_source_path', default='data/data_urdf/object', type=str) 80 | # for robot pc generation 81 | parser.add_argument('--robot_name', default='shadowhand', type=str) 82 | args = parser.parse_args() 83 | 84 | if args.type == 'robot': 85 | generate_robot_pc(args) 86 | elif args.type == 'object': 87 | generate_object_pc(args) 88 | else: 89 | raise NotImplementedError 90 | -------------------------------------------------------------------------------- /data_utils/removed_links.json: -------------------------------------------------------------------------------- 1 | { 2 | "allegro": [], 3 | "barrett": [], 4 | "ezgripper": ["base_link"], 5 | "robotiq_3finger": [], 6 | "shadowhand": ["forearm", "wrist", "ffknuckle", "mfknuckle", "rfknuckle", "lfknuckle", "thbase"], 7 | "leaphand": [] 8 | } 9 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def knn(x, k): 7 | inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x) 8 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 9 | pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous() 10 | 11 | idx = pairwise_distance.topk(k=k, dim=-1)[1] 12 | return idx 13 | 14 | 15 | def get_graph_feature(x, k=20): 16 | idx = knn(x, k=k) 17 | batch_size, num_points, _ = idx.size() 18 | _, num_dims, _ = x.size() 19 | 20 | idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points 21 | idx = idx + idx_base 22 | idx = idx.view(-1) 23 | 24 | x = x.transpose(2, 1).contiguous() 25 | feature = x.view(batch_size * num_points, -1)[idx, :] 26 | feature = feature.view(batch_size, num_points, k, num_dims) 27 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 28 | 29 | feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2).contiguous() 30 | 31 | return feature 32 | 33 | 34 | class Encoder(nn.Module): 35 | """ 36 | The implementation is based on the DGCNN model 37 | 
(https://github.com/WangYueFt/dgcnn/blob/f765b469a67730658ba554e97dc11723a7bab628/pytorch/model.py#L88), 38 | and https://github.com/r-pad/taxpose/blob/0c4298fa0486fd09e63bf24d618a579b66ba0f18/third_party/dcp/model.py#L282. 39 | 40 | Further explanation can be found in Appendix F.1 of https://arxiv.org/pdf/2410.01702. 41 | """ 42 | 43 | def __init__(self, emb_dim=512): 44 | super(Encoder, self).__init__() 45 | 46 | self.conv1 = nn.Sequential( 47 | nn.Conv2d(6, 64, kernel_size=1, bias=False), 48 | nn.BatchNorm2d(64), 49 | nn.LeakyReLU(negative_slope=0.2) 50 | ) 51 | self.conv2 = nn.Sequential( 52 | nn.Conv2d(64, 64, kernel_size=1, bias=False), 53 | nn.BatchNorm2d(64), 54 | nn.LeakyReLU(negative_slope=0.2) 55 | ) 56 | self.conv3 = nn.Sequential( 57 | nn.Conv2d(64, 128, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(128), 59 | nn.LeakyReLU(negative_slope=0.2) 60 | ) 61 | self.conv4 = nn.Sequential( 62 | nn.Conv2d(128, 256, kernel_size=1, bias=False), 63 | nn.BatchNorm2d(256), 64 | nn.LeakyReLU(negative_slope=0.2) 65 | ) 66 | self.conv5 = nn.Sequential( 67 | nn.Conv2d(256, 512, kernel_size=1, bias=False), 68 | nn.BatchNorm2d(512), 69 | nn.LeakyReLU(negative_slope=0.2) 70 | ) 71 | self.conv6 = nn.Sequential( 72 | nn.Conv1d(1536, emb_dim, kernel_size=1, bias=False), 73 | nn.BatchNorm1d(emb_dim), 74 | nn.LeakyReLU(negative_slope=0.2) 75 | ) 76 | 77 | def forward(self, x): 78 | x = x.permute(0, 2, 1) # (B, N, 3) -> (B, 3, N) 79 | B, _, N = x.size() 80 | 81 | x = get_graph_feature(x, k=32) # (B, 6, N, K) 82 | 83 | x = self.conv1(x) # (B, 64, N, K) 84 | x1 = x.max(dim=-1, keepdim=False)[0] # (B, 64, N) 85 | 86 | x = self.conv2(x) # (B, 64, N, K) 87 | x2 = x.max(dim=-1, keepdim=False)[0] # (B, 64, N) 88 | 89 | x = self.conv3(x) # (B, 128, N, K) 90 | x3 = x.max(dim=-1, keepdim=False)[0] # (B, 128, N) 91 | 92 | x = self.conv4(x) # (B, 256, N, K) 93 | x4 = x.max(dim=-1, keepdim=False)[0] # (B, 256, N) 94 | 95 | x = self.conv5(x) # (B, 512, N, K) 96 | x5 = x.max(dim=-1, keepdim=False)[0] # (B, 512, N) 97 | 98 | global_feat = x5.mean(dim=-1, keepdim=True).repeat(1, 1, N) # (B, 512, 1) -> (B, 512, N) 99 | 100 | x = torch.cat((x1, x2, x3, x4, x5, global_feat), dim=1) # (B, 1536, N) 101 | x = self.conv6(x).view(B, -1, N) # (B, 512, N) 102 | 103 | return x.permute(0, 2, 1) # (B, D, N) -> (B, N, D) 104 | 105 | 106 | class CvaeEncoder(nn.Module): 107 | """ 108 | The implementation is based on the DGCNN model 109 | (https://github.com/WangYueFt/dgcnn/blob/f765b469a67730658ba554e97dc11723a7bab628/pytorch/model.py#L88). 110 | 111 | The only modification made is to enable the input to include additional features. 
112 | """ 113 | 114 | def __init__(self, emb_dims, output_channels, feat_dim=0): 115 | super(CvaeEncoder, self).__init__() 116 | self.feat_dim = feat_dim 117 | 118 | self.bn1 = nn.BatchNorm2d(64) 119 | self.bn2 = nn.BatchNorm2d(64) 120 | self.bn3 = nn.BatchNorm2d(128) 121 | self.bn4 = nn.BatchNorm2d(256) 122 | self.bn5 = nn.BatchNorm1d(emb_dims) 123 | 124 | self.conv1 = nn.Sequential( 125 | nn.Conv2d(6 + feat_dim, 64, kernel_size=1, bias=False), 126 | self.bn1, 127 | nn.LeakyReLU(negative_slope=0.2) 128 | ) 129 | self.conv2 = nn.Sequential( 130 | nn.Conv2d(64, 64, kernel_size=1, bias=False), 131 | self.bn2, 132 | nn.LeakyReLU(negative_slope=0.2) 133 | ) 134 | self.conv3 = nn.Sequential( 135 | nn.Conv2d(64,128, kernel_size=1, bias=False), 136 | self.bn3, 137 | nn.LeakyReLU(negative_slope=0.2) 138 | ) 139 | self.conv4 = nn.Sequential( 140 | nn.Conv2d(128, 256, kernel_size=1, bias=False), 141 | self.bn4, 142 | nn.LeakyReLU(negative_slope=0.2) 143 | ) 144 | self.conv5 = nn.Sequential( 145 | nn.Conv1d(512, emb_dims, kernel_size=1, bias=False), 146 | self.bn5, 147 | nn.LeakyReLU(negative_slope=0.2) 148 | ) 149 | self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False) 150 | self.bn6 = nn.BatchNorm1d(512) 151 | self.dp1 = nn.Dropout(p=0.5) 152 | self.linear2 = nn.Linear(512, 256) 153 | self.bn7 = nn.BatchNorm1d(256) 154 | self.dp2 = nn.Dropout(p=0.5) 155 | self.linear3 = nn.Linear(256, output_channels) 156 | 157 | def forward(self, x): 158 | x = x.permute(0, 2, 1) 159 | B, D, N = x.size() 160 | x_k = get_graph_feature(x[:, :3, :]) # B, 6, N, K 161 | x_feat = x[:, 3:, :].unsqueeze(-1).repeat(1, 1, 1, 20) if self.feat_dim != 0 else None # K = 20 162 | x = torch.cat([x_k, x_feat], dim=1) if self.feat_dim != 0 else x_k # (B, 6 + feat_dim, N, K) 163 | 164 | x = self.conv1(x) 165 | x1 = x.max(dim=-1, keepdim=True)[0] 166 | 167 | x = self.conv2(x) 168 | x2 = x.max(dim=-1, keepdim=True)[0] 169 | 170 | x = self.conv3(x) 171 | x3 = x.max(dim=-1, keepdim=True)[0] 172 | 173 | x = self.conv4(x) 174 | x4 = x.max(dim=-1, keepdim=True)[0] 175 | 176 | x = torch.cat((x1, x2, x3, x4), dim=1)[..., 0] 177 | 178 | x = self.conv5(x) 179 | x1 = F.adaptive_max_pool1d(x, 1).view(B, -1) 180 | x2 = F.adaptive_avg_pool1d(x, 1).view(B, -1) 181 | x = torch.cat((x1, x2), 1) 182 | 183 | x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) 184 | x = self.dp1(x) 185 | x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) 186 | x = self.dp2(x) 187 | x = self.linear3(x) 188 | 189 | return x # (B, output_channels) 190 | -------------------------------------------------------------------------------- /model/latent_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResnetBlockFC(nn.Module): 7 | """ 8 | Fully connected ResNet Block class. 
9 | Args: 10 | size_in (int): input dimension 11 | size_out (int): output dimension 12 | size_h (int): hidden dimension 13 | """ 14 | def __init__(self, size_in, size_out=None, size_h=None): 15 | super().__init__() 16 | if size_out is None: 17 | size_out = size_in 18 | 19 | if size_h is None: 20 | size_h = min(size_in, size_out) 21 | 22 | self.size_in = size_in 23 | self.size_h = size_h 24 | self.size_out = size_out 25 | 26 | self.fc_0 = nn.Linear(size_in, size_h) 27 | self.fc_1 = nn.Linear(size_h, size_out) 28 | self.actvn = nn.ReLU() 29 | 30 | if size_in == size_out: 31 | self.shortcut = None 32 | else: 33 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 34 | nn.init.zeros_(self.fc_1.weight) 35 | 36 | def forward(self, x, final_nl=False): 37 | net = self.fc_0(self.actvn(x)) 38 | dx = self.fc_1(self.actvn(net)) 39 | if self.shortcut is not None: 40 | x_s = self.shortcut(x) 41 | else: 42 | x_s = x 43 | x_out = x_s + dx 44 | if final_nl: 45 | return F.leaky_relu(x_out, negative_slope=0.2) 46 | return x_out 47 | 48 | 49 | class LatentEncoder(nn.Module): 50 | def __init__(self, in_dim, dim, out_dim): 51 | super().__init__() 52 | self.block = ResnetBlockFC(size_in=in_dim, size_out=dim, size_h=dim) 53 | self.fc_mu = nn.Linear(dim, out_dim) 54 | self.fc_logvar = nn.Linear(dim, out_dim) 55 | 56 | def forward(self, x): 57 | x = self.block(x, final_nl=True) 58 | return self.fc_mu(x), self.fc_logvar(x) 59 | -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced from https://github.com/r-pad/taxpose. 3 | 4 | The only modification made is adjusting the relative imports to enhance the clarity of the file structure. 5 | """ 6 | 7 | from typing import Callable, List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class MLPKernel(nn.Module): 15 | def __init__(self, feature_dim): 16 | super().__init__() 17 | self.feature_dim = feature_dim 18 | self.mlp = MLP(2 * feature_dim, [300, 100, 1]) 19 | 20 | def forward(self, x1, x2): 21 | v1 = self.mlp(torch.cat([x1, x2], dim=-1)) 22 | v2 = self.mlp(torch.cat([x2, x1], dim=-1)) 23 | return F.softplus((v1 + v2) / 2) 24 | 25 | 26 | class MLP(nn.Sequential): 27 | """ 28 | This block implements the multi-layer perceptron (MLP) module. 29 | 30 | Args: 31 | in_channels (int): Number of channels of the input 32 | hidden_channels (List[int]): List of the hidden channel dimensions 33 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` 34 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` 35 | inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. 36 | Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. 37 | bias (bool): Whether to use bias in the linear layer. Default ``True`` 38 | dropout (float): The probability for the dropout layer. 
Default: 0.0 39 | """ 40 | 41 | def __init__( 42 | self, 43 | in_channels: int, 44 | hidden_channels: List[int], 45 | norm_layer: Optional[Callable[..., torch.nn.Module]] = None, 46 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, 47 | inplace: Optional[bool] = None, 48 | bias: bool = True, 49 | dropout: float = 0.0, 50 | ): 51 | # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal: 52 | # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py 53 | params = {} if inplace is None else {"inplace": inplace} 54 | 55 | layers: List[nn.Module] = [] 56 | in_dim = in_channels 57 | for hidden_dim in hidden_channels[:-1]: 58 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) 59 | if norm_layer is not None: 60 | layers.append(norm_layer(hidden_dim)) 61 | if activation_layer is not None: 62 | layers.append(activation_layer(**params)) 63 | layers.append(torch.nn.Dropout(dropout, **params)) 64 | in_dim = hidden_dim 65 | 66 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) 67 | layers.append(torch.nn.Dropout(dropout, **params)) 68 | 69 | super().__init__(*layers) 70 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | import pytorch_lightning as pl 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.se3_transform import compute_link_pose 13 | from utils.multilateration import multilateration 14 | from utils.func_utils import calculate_depth 15 | from utils.pretrain_utils import dist2weight, infonce_loss, mean_order 16 | 17 | 18 | class TrainingModule(pl.LightningModule): 19 | def __init__(self, cfg, network, epoch_idx): 20 | super().__init__() 21 | self.cfg = cfg 22 | self.network = network 23 | self.epoch_idx = epoch_idx 24 | 25 | self.lr = cfg.lr 26 | 27 | os.makedirs(self.cfg.save_dir, exist_ok=True) 28 | 29 | def ddp_print(self, *args, **kwargs): 30 | if self.global_rank == 0: 31 | print(*args, **kwargs) 32 | 33 | def training_step(self, batch, batch_idx): 34 | object_name = batch['object_name'] 35 | robot_links_pc = batch['robot_links_pc'] 36 | robot_pc_initial = batch['robot_pc_initial'] 37 | robot_pc_target = batch['robot_pc_target'] 38 | object_pc = batch['object_pc'] 39 | dro_gt = batch['dro_gt'] 40 | 41 | network_output = self.network( 42 | robot_pc_initial, 43 | object_pc, 44 | robot_pc_target 45 | ) 46 | 47 | dro = network_output['dro'] 48 | mu = network_output['mu'] 49 | logvar = network_output['logvar'] 50 | 51 | mlat_pc = multilateration(dro, object_pc) 52 | transforms, transformed_pc = compute_link_pose(robot_links_pc, mlat_pc) 53 | 54 | loss = 0. 
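# The total loss assembled below is a weighted sum of up to four optional terms;
# the switches (cfg.loss_*) and weights (cfg.loss_*_weight) come from the training
# config:
#   loss_kl    : CVAE KL divergence, -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
#                softened by sqrt(1 + x^2) - 1 before weighting
#   loss_r     : L1 error between the predicted D(R,O) matrix `dro` and `dro_gt`
#   loss_se3   : per-link pose error ||t - t_gt|| + acos((trace(R^T R_gt) - 1) / 2),
#                averaged over links, then over the batch
#   loss_depth : depth term computed by utils.func_utils.calculate_depth on the
#                transformed robot point cloud and the object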
55 | 56 | if self.cfg.loss_kl: 57 | loss_kl = - 0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp()) 58 | loss_kl = torch.sqrt(1 + loss_kl ** 2) - 1 59 | loss_kl = loss_kl * self.cfg.loss_kl_weight 60 | self.log('loss_kl', loss_kl, prog_bar=True) 61 | loss += loss_kl 62 | 63 | if self.cfg.loss_r: 64 | loss_r = nn.L1Loss()(dro, dro_gt) 65 | loss_r = loss_r * self.cfg.loss_r_weight 66 | self.log('loss_r', loss_r, prog_bar=True) 67 | loss += loss_r 68 | 69 | if self.cfg.loss_se3: 70 | transforms_gt, transformed_pc_gt = compute_link_pose(robot_links_pc, robot_pc_target) 71 | loss_se3 = 0. 72 | for idx in range(len(transforms)): # iteration over batch 73 | transform = transforms[idx] 74 | transform_gt = transforms_gt[idx] 75 | loss_se3_item = 0. 76 | for link_name in transform: 77 | rel_translation = transform[link_name][:3, 3] - transform_gt[link_name][:3, 3] 78 | rel_rotation = transform[link_name][:3, :3].mT @ transform_gt[link_name][:3, :3] 79 | rel_rotation_trace = torch.clamp(torch.trace(rel_rotation), -1, 3) 80 | rel_angle = torch.acos((rel_rotation_trace - 1) / 2) 81 | loss_se3_item += torch.mean(torch.norm(rel_translation, dim=-1) + rel_angle) 82 | loss_se3 += loss_se3_item / len(transform) 83 | loss_se3 = loss_se3 / len(transforms) * self.cfg.loss_se3_weight 84 | self.log('loss_se3', loss_se3, prog_bar=True) 85 | loss += loss_se3 86 | 87 | if self.cfg.loss_depth: 88 | loss_depth = calculate_depth(transformed_pc, object_name) 89 | loss_depth = loss_depth * self.cfg.loss_depth_weight 90 | self.log('loss_depth', loss_depth, prog_bar=True) 91 | loss += loss_depth 92 | 93 | self.log("loss", loss, prog_bar=True) 94 | return loss 95 | 96 | def on_after_backward(self): 97 | """ 98 | For unknown reasons, there is a small chance that the gradients in CVAE may become NaN during backpropagation. 99 | In such cases, skip the iteration. 
100 | """ 101 | for param in self.network.parameters(): 102 | if param.grad is not None and torch.isnan(param.grad).any(): 103 | param.grad = None 104 | 105 | def on_train_epoch_end(self): 106 | self.epoch_idx += 1 107 | self.ddp_print(f"Training epoch: {self.epoch_idx}") 108 | if self.epoch_idx % self.cfg.save_every_n_epoch == 0: 109 | self.ddp_print(f"Saving state_dict at epoch: {self.epoch_idx}") 110 | torch.save(self.network.state_dict(), f'{self.cfg.save_dir}/epoch_{self.epoch_idx}.pth') 111 | 112 | def configure_optimizers(self): 113 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 114 | return optimizer 115 | 116 | 117 | class PretrainingModule(pl.LightningModule): 118 | def __init__(self, cfg, encoder): 119 | super().__init__() 120 | self.cfg = cfg 121 | self.encoder = encoder 122 | 123 | self.lr = cfg.lr 124 | self.temperature = cfg.temperature 125 | 126 | self.epoch_idx = 0 127 | os.makedirs(self.cfg.save_dir, exist_ok=True) 128 | 129 | def ddp_print(self, *args, **kwargs): 130 | if self.global_rank == 0: 131 | print(*args, **kwargs) 132 | 133 | def training_step(self, batch, batch_idx): 134 | robot_pc_1 = batch['robot_pc_1'] 135 | robot_pc_2 = batch['robot_pc_2'] 136 | 137 | robot_pc_1 = robot_pc_1 - robot_pc_1.mean(dim=1, keepdims=True) 138 | robot_pc_2 = robot_pc_2 - robot_pc_2.mean(dim=1, keepdims=True) 139 | 140 | phi_1 = self.encoder(robot_pc_1) # (B, N, 3) -> (B, N, D) 141 | phi_2 = self.encoder(robot_pc_2) # (B, N, 3) -> (B, N, D) 142 | 143 | weights = dist2weight(robot_pc_1, func=lambda x: torch.tanh(10 * x)) 144 | loss, similarity = infonce_loss( 145 | phi_1, phi_2, weights=weights, temperature=self.temperature 146 | ) 147 | mean_order_error = mean_order(similarity) 148 | 149 | self.log("mean_order", mean_order_error) 150 | self.log("loss", loss, prog_bar=True) 151 | 152 | return loss 153 | 154 | def on_train_epoch_end(self): 155 | self.epoch_idx += 1 156 | self.ddp_print(f"Training epoch: {self.epoch_idx}") 157 | if self.epoch_idx % self.cfg.save_every_n_epoch == 0: 158 | self.ddp_print(f"Saving state_dict at epoch: {self.epoch_idx}") 159 | torch.save(self.encoder.state_dict(), f'{self.cfg.save_dir}/epoch_{self.epoch_idx}.pth') 160 | 161 | 162 | def configure_optimizers(self): 163 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 164 | return optimizer 165 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(ROOT_DIR) 8 | 9 | from model.encoder import Encoder, CvaeEncoder 10 | from model.transformer import Transformer 11 | from model.latent_encoder import LatentEncoder 12 | from model.mlp import MLPKernel 13 | 14 | 15 | def create_encoder_network(emb_dim, pretrain=None, device=torch.device('cpu')) -> nn.Module: 16 | encoder = Encoder(emb_dim=emb_dim) 17 | if pretrain is not None: 18 | print(f"******** Load embedding network pretrain from <{pretrain}> ********") 19 | encoder.load_state_dict( 20 | torch.load( 21 | os.path.join(ROOT_DIR, f"ckpt/pretrain/{pretrain}"), 22 | map_location=device 23 | ) 24 | ) 25 | return encoder 26 | 27 | 28 | class Network(nn.Module): 29 | def __init__(self, cfg, mode): 30 | super(Network, self).__init__() 31 | self.cfg = cfg 32 | self.mode = mode 33 | 34 | self.encoder_robot = create_encoder_network(emb_dim=cfg.emb_dim, 
pretrain=cfg.pretrain) 35 | self.encoder_object = create_encoder_network(emb_dim=cfg.emb_dim) 36 | 37 | self.transformer_robot = Transformer(emb_dim=cfg.emb_dim) 38 | self.transformer_object = Transformer(emb_dim=cfg.emb_dim) 39 | 40 | # CVAE encoder 41 | self.point_encoder = CvaeEncoder(emb_dims=cfg.emb_dim, output_channels=2 * cfg.latent_dim, feat_dim=cfg.emb_dim) 42 | self.latent_encoder = LatentEncoder(in_dim=2*cfg.latent_dim, dim=4*cfg.latent_dim, out_dim=cfg.latent_dim) 43 | 44 | self.kernel = MLPKernel(cfg.emb_dim + cfg.latent_dim) 45 | 46 | def forward(self, robot_pc, object_pc, target_pc=None): 47 | if self.cfg.center_pc: # zero-mean the robot point cloud 48 | robot_pc = robot_pc - robot_pc.mean(dim=1, keepdim=True) 49 | 50 | # point cloud encoder 51 | robot_embedding = self.encoder_robot(robot_pc) 52 | object_embedding = self.encoder_object(object_pc) 53 | 54 | if self.cfg.pretrain is not None: 55 | robot_embedding = robot_embedding.detach() 56 | 57 | # point cloud transformer 58 | transformer_robot_outputs = self.transformer_robot(robot_embedding, object_embedding) 59 | transformer_object_outputs = self.transformer_object(object_embedding, robot_embedding) 60 | robot_embedding_tf = robot_embedding + transformer_robot_outputs["src_embedding"] 61 | object_embedding_tf = object_embedding + transformer_object_outputs["src_embedding"] 62 | 63 | # CVAE encoder 64 | if self.mode == 'train': 65 | grasp_pc = torch.cat([target_pc, object_pc], dim=1) 66 | grasp_emb = torch.cat([robot_embedding_tf, object_embedding_tf], dim=1) 67 | latent = self.point_encoder(torch.cat([grasp_pc, grasp_emb], -1)) 68 | mu, logvar = self.latent_encoder(latent) 69 | z_dist = torch.distributions.normal.Normal(mu, torch.exp(0.5 * logvar)) 70 | z = z_dist.rsample() # (B, latent_dim) 71 | else: 72 | mu, logvar = None, None 73 | z = torch.randn(robot_pc.shape[0], self.cfg.latent_dim).to(robot_pc.device) 74 | z = z.unsqueeze(dim=1).repeat(1, robot_embedding_tf.shape[1], 1) # (B, N, latent_dim) 75 | 76 | Phi_A = torch.cat([robot_embedding_tf, z], dim=-1) # (B, N, emb_dim + latent_dim) 77 | Phi_B = torch.cat([object_embedding_tf, z], dim=-1) # (B, N, emb_dim + latent_dim) 78 | 79 | # Compute D(R,O) matrix 80 | if self.cfg.block_computing: # use matrix block computation to save GPU memory 81 | B, N, D = Phi_A.shape 82 | block_num = 4 # experimental result, reaching a balance between speed and GPU memory 83 | N_block = N // block_num 84 | assert N % N_block == 0, 'Unable to perform block computation.' 
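# Rough memory estimate motivating the block computation below, assuming for
# illustration N = 512 points per cloud and D = emb_dim + latent_dim = 512 + 64 = 576
# (the values used in data_utils/generate_pc.py and scripts/example_isaac.py), fp32:
#   full pairwise : two repeated (B*N*N, D) feature tensors, 512*512*576 ~ 1.5e8
#                   floats ~ 0.6 GB each per batch element
#   per block     : with block_num = 4 (N_block = 128), each kernel call only
#                   materializes 128*128*576 ~ 9.4e6 floats ~ 38 MB per tensor
# i.e. peak memory for the repeated features drops by ~block_num**2 = 16x, at the
# cost of block_num**2 = 16 kernel invocations per batch.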
85 | 86 | dro = torch.zeros([B, N, N], dtype=torch.float32, device=Phi_A.device) 87 | for A_i in range(block_num): 88 | Phi_A_block = Phi_A[:, A_i * N_block: (A_i + 1) * N_block, :] # (B, N_block, D) 89 | for B_i in range(block_num): 90 | Phi_B_block = Phi_B[:, B_i * N_block: (B_i + 1) * N_block, :] # (B, N_block, D) 91 | 92 | Phi_A_r = Phi_A_block.unsqueeze(2).repeat(1, 1, N_block, 1).reshape(B * N_block * N_block, D) 93 | Phi_B_r = Phi_B_block.unsqueeze(1).repeat(1, N_block, 1, 1).reshape(B * N_block * N_block, D) 94 | 95 | dro[:, A_i * N_block: (A_i + 1) * N_block, B_i * N_block: (B_i + 1) * N_block] \ 96 | = self.kernel(Phi_A_r, Phi_B_r).reshape(B, N_block, N_block) 97 | else: 98 | Phi_A_r = ( 99 | Phi_A.unsqueeze(2) 100 | .repeat(1, 1, Phi_A.shape[1], 1) 101 | .reshape(Phi_A.shape[0] * Phi_A.shape[1] * Phi_A.shape[1], Phi_A.shape[2]) 102 | ) 103 | Phi_B_r = ( 104 | Phi_B.unsqueeze(1) 105 | .repeat(1, Phi_B.shape[1], 1, 1) 106 | .reshape(Phi_B.shape[0] * Phi_B.shape[1] * Phi_B.shape[1], Phi_B.shape[2]) 107 | ) 108 | dro = self.kernel(Phi_A_r, Phi_B_r).reshape(Phi_A.shape[0], Phi_A.shape[1], Phi_B.shape[1]) 109 | 110 | outputs = { 111 | 'dro': dro, 112 | 'mu': mu, 113 | 'logvar': logvar, 114 | } 115 | return outputs 116 | 117 | 118 | def create_network(cfg, mode): 119 | network = Network( 120 | cfg=cfg, 121 | mode=mode 122 | ) 123 | return network 124 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced from https://github.com/r-pad/taxpose, which builds upon 3 | the transformer model from https://github.com/WangYueFt/dcp/blob/master/model.py. 4 | 5 | The only modification made is adjusting the relative imports to enhance the clarity of the file structure. 
6 | """ 7 | 8 | import math 9 | import copy 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class Transformer(nn.Module): 16 | def __init__( 17 | self, 18 | emb_dim=512, 19 | n_blocks=1, 20 | dropout=0.0, 21 | ff_dims=1024, 22 | n_heads=4, 23 | bidirectional=False, 24 | ): 25 | super(Transformer, self).__init__() 26 | self.emb_dim = emb_dim 27 | self.N = n_blocks 28 | self.dropout = dropout 29 | self.ff_dims = ff_dims 30 | self.n_heads = n_heads 31 | self.bidirectional = bidirectional 32 | c = copy.deepcopy 33 | attn = MultiHeadedAttention(self.n_heads, self.emb_dim) 34 | ff = PositionwiseFeedForward(self.emb_dim, self.ff_dims, self.dropout) 35 | self.model = EncoderDecoder( 36 | Encoder( 37 | EncoderLayer(self.emb_dim, c(attn), c(ff), self.dropout), 38 | self.N 39 | ), 40 | Decoder( 41 | DecoderLayer(self.emb_dim, c(attn), c(attn), c(ff), self.dropout), 42 | self.N, 43 | ), 44 | nn.Sequential(), 45 | nn.Sequential(), 46 | nn.Sequential(), 47 | ) 48 | 49 | def forward(self, *input): 50 | src = input[0] 51 | tgt = input[1] 52 | src_embedding = self.model(tgt, src, None, None) 53 | src_attn = self.model.decoder.layers[-1].src_attn.attn 54 | 55 | outputs = {"src_embedding": src_embedding, "src_attn": src_attn} 56 | 57 | if self.bidirectional: 58 | tgt_embedding = ( 59 | self.model(src, tgt, None, None) 60 | ) 61 | tgt_attn = self.model.decoder.layers[-1].src_attn.attn 62 | 63 | outputs = { 64 | **outputs, 65 | "tgt_embedding": tgt_embedding, 66 | "tgt_attn": tgt_attn, 67 | } 68 | 69 | return outputs 70 | 71 | def clones(module, N): 72 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 73 | 74 | def attention(query, key, value, mask=None, dropout=None): 75 | d_k = query.size(-1) 76 | scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / math.sqrt(d_k) 77 | if mask is not None: 78 | scores = scores.masked_fill(mask == 0, -1e9) 79 | p_attn = F.softmax(scores, dim=-1) 80 | return torch.matmul(p_attn, value), p_attn 81 | 82 | class LayerNorm(nn.Module): 83 | def __init__(self, features, eps=1e-6): 84 | super(LayerNorm, self).__init__() 85 | self.a_2 = nn.Parameter(torch.ones(features)) 86 | self.b_2 = nn.Parameter(torch.zeros(features)) 87 | self.eps = eps 88 | 89 | def forward(self, x): 90 | mean = x.mean(-1, keepdim=True) 91 | std = x.std(-1, keepdim=True) 92 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 93 | 94 | class EncoderDecoder(nn.Module): 95 | """ 96 | A standard Encoder-Decoder architecture. Base for this and many 97 | other models. 
98 | """ 99 | 100 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 101 | super(EncoderDecoder, self).__init__() 102 | self.encoder = encoder 103 | self.decoder = decoder 104 | self.src_embed = src_embed 105 | self.tgt_embed = tgt_embed 106 | self.generator = generator 107 | 108 | def forward(self, src, tgt, src_mask, tgt_mask): 109 | """Take in and process masked src and target sequences.""" 110 | return self.decode(self.encode(src, src_mask), src_mask, 111 | tgt, tgt_mask) 112 | 113 | def encode(self, src, src_mask): 114 | return self.encoder(self.src_embed(src), src_mask) 115 | 116 | def decode(self, memory, src_mask, tgt, tgt_mask): 117 | return self.generator(self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)) 118 | 119 | class Decoder(nn.Module): 120 | """Generic N layer decoder with masking.""" 121 | 122 | def __init__(self, layer, N): 123 | super(Decoder, self).__init__() 124 | self.layers = clones(layer, N) 125 | self.norm = LayerNorm(layer.size) 126 | 127 | def forward(self, x, memory, src_mask, tgt_mask): 128 | for layer in self.layers: 129 | x = layer(x, memory, src_mask, tgt_mask) 130 | return self.norm(x) 131 | 132 | class DecoderLayer(nn.Module): 133 | """Decoder is made of self-attn, src-attn, and feed forward (defined below)""" 134 | 135 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 136 | super(DecoderLayer, self).__init__() 137 | self.size = size 138 | self.self_attn = self_attn 139 | self.src_attn = src_attn 140 | self.feed_forward = feed_forward 141 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 142 | 143 | def forward(self, x, memory, src_mask, tgt_mask): 144 | """Follow Figure 1 (right) for connections.""" 145 | m = memory 146 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 147 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 148 | return self.sublayer[2](x, self.feed_forward) 149 | 150 | class SublayerConnection(nn.Module): 151 | def __init__(self, size, dropout=None): 152 | super(SublayerConnection, self).__init__() 153 | self.norm = LayerNorm(size) 154 | 155 | def forward(self, x, sublayer): 156 | return x + sublayer(self.norm(x)) 157 | 158 | class Encoder(nn.Module): 159 | def __init__(self, layer, N): 160 | super(Encoder, self).__init__() 161 | self.layers = clones(layer, N) 162 | self.norm = LayerNorm(layer.size) 163 | 164 | def forward(self, x, mask): 165 | for layer in self.layers: 166 | x = layer(x, mask) 167 | return self.norm(x) 168 | 169 | class EncoderLayer(nn.Module): 170 | def __init__(self, size, self_attn, feed_forward, dropout): 171 | super(EncoderLayer, self).__init__() 172 | self.self_attn = self_attn 173 | self.feed_forward = feed_forward 174 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 175 | self.size = size 176 | 177 | def forward(self, x, mask): 178 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 179 | return self.sublayer[1](x, self.feed_forward) 180 | 181 | class MultiHeadedAttention(nn.Module): 182 | def __init__(self, h, d_model, dropout=0.1): 183 | """Take in model size and number of heads.""" 184 | super(MultiHeadedAttention, self).__init__() 185 | assert d_model % h == 0 186 | # We assume d_v always equals d_k 187 | self.d_k = d_model // h 188 | self.h = h 189 | self.linears = clones(nn.Linear(d_model, d_model), 4) 190 | self.attn = None 191 | self.dropout = None 192 | 193 | def forward(self, query, key, value, mask=None): 194 | """Implements Figure 2""" 195 | if mask is not None: 
196 | # Same mask applied to all h heads. 197 | mask = mask.unsqueeze(1) 198 | nbatches = query.size(0) 199 | 200 | # 1) Do all the linear projections in batch from d_model => h x d_k 201 | query, key, value = \ 202 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous() 203 | for l, x in zip(self.linears, (query, key, value))] 204 | 205 | # 2) Apply attention on all the projected vectors in batch. 206 | x, self.attn = attention(query, key, value, mask=mask, 207 | dropout=self.dropout) 208 | 209 | # 3) "Concat" using a view and apply a final linear. 210 | x = x.transpose(1, 2).contiguous() \ 211 | .view(nbatches, -1, self.h * self.d_k) 212 | return self.linears[-1](x) 213 | 214 | 215 | class PositionwiseFeedForward(nn.Module): 216 | """Implements FFN equation.""" 217 | 218 | def __init__(self, d_model, d_ff, dropout=0.1): 219 | super(PositionwiseFeedForward, self).__init__() 220 | self.w_1 = nn.Linear(d_model, d_ff) 221 | self.norm = nn.Sequential() # nn.BatchNorm1d(d_ff) 222 | self.w_2 = nn.Linear(d_ff, d_model) 223 | self.dropout = None 224 | 225 | def forward(self, x): 226 | return self.w_2(self.norm(F.relu(self.w_1(x)).transpose(2, 1).contiguous()).transpose(2, 1).contiguous()) 227 | -------------------------------------------------------------------------------- /pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenyuwei2003/DRO-Grasp/b312055b4a20f73ddfeb3ffc8a1a6c80d48bbe31/pipeline.jpg -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | import hydra 5 | from omegaconf import OmegaConf 6 | import torch 7 | import pytorch_lightning as pl 8 | from pytorch_lightning.loggers import WandbLogger 9 | 10 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from model.module import PretrainingModule 14 | from model.network import create_encoder_network 15 | from data_utils.PretrainDataset import create_dataloader 16 | 17 | 18 | @hydra.main(version_base="1.2", config_path="configs", config_name="pretrain") 19 | def main(cfg): 20 | print("******************************** [Config] ********************************") 21 | print(OmegaConf.to_yaml(cfg)) 22 | print("******************************** [Config] ********************************") 23 | 24 | pl.seed_everything(cfg.seed) 25 | 26 | logger = WandbLogger( 27 | name=cfg.name, 28 | save_dir=cfg.wandb.save_dir, 29 | project=cfg.wandb.project 30 | ) 31 | trainer = pl.Trainer( 32 | logger=logger, 33 | accelerator='gpu', 34 | devices=cfg.gpu, 35 | log_every_n_steps=cfg.log_every_n_steps, 36 | max_epochs=cfg.training.max_epochs 37 | ) 38 | 39 | dataloader = create_dataloader(cfg.dataset) 40 | encoder = create_encoder_network(cfg.model.emb_dim) 41 | model = PretrainingModule( 42 | cfg=cfg.training, 43 | encoder=encoder 44 | ) 45 | model.train() 46 | 47 | trainer.fit(model, dataloader) 48 | 49 | 50 | if __name__ == "__main__": 51 | torch.set_float32_matmul_precision("high") 52 | torch.autograd.set_detect_anomaly(True) 53 | torch.cuda.empty_cache() 54 | torch.multiprocessing.set_sharing_strategy("file_system") 55 | warnings.simplefilter(action='ignore', category=FutureWarning) 56 | main() 57 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohappyeyeballs==2.4.0 3 | aiohttp==3.10.5 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | arm_pytorch_utilities==0.4.3 7 | async-timeout==4.0.3 8 | attrs==24.2.0 9 | beautifulsoup4==4.12.3 10 | cachetools==5.5.0 11 | certifi==2024.8.30 12 | charset-normalizer==3.3.2 13 | click==8.1.7 14 | contourpy==1.1.1 15 | cvxpy==1.5.2 16 | cvxpylayers==0.1.6 17 | cycler==0.12.1 18 | docker-pycreds==0.4.0 19 | filelock==3.15.4 20 | fonttools==4.53.1 21 | frozenlist==1.4.1 22 | fsspec==2024.6.1 23 | gdown==5.2.0 24 | gitdb==4.0.11 25 | GitPython==3.1.43 26 | h5py==3.11.0 27 | hydra-core==1.3.2 28 | idna==3.8 29 | importlib_resources==6.4.4 30 | Jinja2==3.1.4 31 | kiwisolver==1.4.5 32 | lightning-utilities==0.11.6 33 | lxml==5.3.0 34 | MarkupSafe==2.1.5 35 | matplotlib==3.7.5 36 | mpmath==1.3.0 37 | multidict==6.0.5 38 | networkx==3.1 39 | numpy==1.24.4 40 | nvidia-cublas-cu12==12.1.3.1 41 | nvidia-cuda-cupti-cu12==12.1.105 42 | nvidia-cuda-nvrtc-cu12==12.1.105 43 | nvidia-cuda-runtime-cu12==12.1.105 44 | nvidia-cudnn-cu12==9.1.0.70 45 | nvidia-cufft-cu12==11.0.2.54 46 | nvidia-curand-cu12==10.3.2.106 47 | nvidia-cusolver-cu12==11.4.5.107 48 | nvidia-cusparse-cu12==12.1.0.106 49 | nvidia-ml-py==12.535.161 50 | nvidia-nccl-cu12==2.20.5 51 | nvidia-nvjitlink-cu12==12.6.68 52 | nvidia-nvtx-cu12==12.1.105 53 | nvitop==1.3.2 54 | omegaconf==2.3.0 55 | packaging==24.1 56 | pillow==10.4.0 57 | platformdirs==4.2.2 58 | protobuf==5.28.0 59 | psutil==6.0.0 60 | pyparsing==3.1.4 61 | PySocks==1.7.1 62 | python-dateutil==2.9.0.post0 63 | pytorch-lightning==2.4.0 64 | pytorch-seed==0.2.0 65 | pytorch_kinematics==0.7.4 66 | PyYAML==6.0.2 67 | requests==2.32.3 68 | scipy==1.10.1 69 | sentry-sdk==2.13.0 70 | setproctitle==1.3.3 71 | six==1.16.0 72 | smmap==5.0.1 73 | soupsieve==2.6 74 | sympy==1.13.2 75 | termcolor==2.4.0 76 | torch==2.4.1 77 | torchmetrics==1.4.1 78 | torchvision==0.19.1 79 | tqdm==4.66.5 80 | trimesh==4.4.8 81 | triton==3.0.0 82 | typing_extensions==4.12.2 83 | urllib3==2.2.2 84 | viser==0.2.1 85 | wandb==0.17.8 86 | yarl==1.9.6 87 | zipp==3.20.1 88 | -------------------------------------------------------------------------------- /scripts/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ckpt 4 | 5 | cd ckpt 6 | wget https://github.com/zhenyuwei2003/DRO-Grasp/releases/download/v1.0/ckpt.zip 7 | unzip ckpt.zip 8 | rm ckpt.zip 9 | cd .. 10 | 11 | echo "Download checkpoint models finished!" 12 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data 4 | 5 | cd data 6 | wget https://github.com/zhenyuwei2003/DRO-Grasp/releases/download/v1.0/data.zip 7 | unzip data.zip 8 | rm data.zip 9 | cd .. 10 | 11 | echo "Download data finished!" 
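# How the two download scripts above fit with the example scripts that follow
# (a minimal sketch; target paths assumed from the scripts and from the checkpoint
# path used in scripts/example_isaac.py, and Isaac Gym is additionally required
# for validation):
#   bash scripts/download_ckpt.sh    -> unpacks pretrained weights into ./ckpt
#   bash scripts/download_data.sh    -> unpacks the datasets into ./data
#   python scripts/example_isaac.py  -> loads ckpt/model/model_3robots.pth and runs
#                                      grasp prediction plus Isaac Gym validation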
12 | -------------------------------------------------------------------------------- /scripts/example_isaac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import warnings 5 | import numpy as np 6 | from tqdm import tqdm 7 | from termcolor import cprint 8 | from types import SimpleNamespace 9 | 10 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from model.network import create_network 14 | from data_utils.CMapDataset import create_dataloader 15 | from utils.multilateration import multilateration 16 | from utils.se3_transform import compute_link_pose 17 | from utils.optimization import * 18 | from utils.hand_model import create_hand_model 19 | from validation.validate_utils import validate_isaac 20 | 21 | 22 | gpu = 0 23 | device = torch.device(f'cuda:{gpu}') 24 | ckpt_name = 'model_3robots' # 'model_3robots_partial', 'model_allegro', 'model_barrett', 'model_shadowhand' 25 | batch_size = 10 26 | 27 | 28 | def main(): 29 | network = create_network( 30 | SimpleNamespace(**{ 31 | 'emb_dim': 512, 32 | 'latent_dim': 64, 33 | 'pretrain': None, 34 | 'center_pc': True, 35 | 'block_computing': True 36 | }), 37 | mode='validate' 38 | ).to(device) 39 | network.load_state_dict(torch.load(f"ckpt/model/{ckpt_name}.pth", map_location=device)) 40 | network.eval() 41 | dataloader = create_dataloader( 42 | SimpleNamespace(**{ 43 | 'batch_size': batch_size, 44 | 'robot_names': ['barrett', 'allegro', 'shadowhand'], 45 | 'debug_object_names': None, 46 | 'object_pc_type': 'random' if ckpt_name != 'model_3robots_partial' else 'partial', 47 | 'num_workers': 16 48 | }), 49 | is_train=False 50 | ) 51 | 52 | global_robot_name = None 53 | hand = None 54 | all_success_q = [] 55 | time_list = [] 56 | success_num = 0 57 | total_num = 0 58 | for i, data in enumerate(dataloader): 59 | robot_name = data['robot_name'] 60 | object_name = data['object_name'] 61 | 62 | if robot_name != global_robot_name: 63 | if global_robot_name is not None: 64 | all_success_q = torch.cat(all_success_q, dim=0) 65 | diversity_std = torch.std(all_success_q, dim=0).mean() 66 | times = np.array(time_list) 67 | time_mean = np.mean(times) 68 | time_std = np.std(times) 69 | 70 | success_rate = success_num / total_num * 100 71 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 72 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 73 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 74 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 75 | 76 | all_success_q = [] 77 | time_list = [] 78 | success_num = 0 79 | total_num = 0 80 | hand = create_hand_model(robot_name, device) 81 | global_robot_name = robot_name 82 | 83 | predict_q_list = [] 84 | for data_idx in tqdm(range(batch_size)): 85 | initial_q = data['initial_q'][data_idx: data_idx + 1].to(device) 86 | robot_pc = data['robot_pc'][data_idx: data_idx + 1].to(device) 87 | object_pc = data['object_pc'][data_idx: data_idx + 1].to(device) 88 | 89 | with torch.no_grad(): 90 | dro = network(robot_pc, object_pc)['dro'].detach() 91 | 92 | mlat_pc = multilateration(dro, object_pc) 93 | transform, _ = compute_link_pose(hand.links_pc, mlat_pc, is_train=False) 94 | optim_transform = process_transform(hand.pk_chain, transform) 95 | 96 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 97 | start_time = time.time() 98 | predict_q = optimization(hand.pk_chain, layer, initial_q, 
optim_transform) 99 | end_time = time.time() 100 | # print(f"[{data_count}/{batch_size}] Optimization time: {end_time - start_time:.4f} s") 101 | time_list.append(end_time - start_time) 102 | 103 | predict_q_list.append(predict_q) 104 | 105 | predict_q_batch = torch.cat(predict_q_list, dim=0) 106 | 107 | success, isaac_q = validate_isaac(robot_name, object_name, predict_q_batch, gpu=gpu) 108 | succ_num = success.sum().item() if success is not None else -1 109 | success_q = predict_q_batch[success] 110 | all_success_q.append(success_q) 111 | 112 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ') 113 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green') 114 | success_num += succ_num 115 | total_num += batch_size 116 | 117 | all_success_q = torch.cat(all_success_q, dim=0) 118 | diversity_std = torch.std(all_success_q, dim=0).mean() 119 | 120 | times = np.array(time_list) 121 | time_mean = np.mean(times) 122 | time_std = np.std(times) 123 | 124 | success_rate = success_num / total_num * 100 125 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 126 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 127 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 128 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 129 | 130 | 131 | if __name__ == "__main__": 132 | warnings.simplefilter(action='ignore', category=FutureWarning) 133 | torch.set_num_threads(8) 134 | main() 135 | -------------------------------------------------------------------------------- /scripts/example_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from types import SimpleNamespace 4 | from tqdm import tqdm 5 | import torch 6 | 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | 10 | from model.network import create_encoder_network 11 | from data_utils.CMapDataset import create_dataloader 12 | from utils.pretrain_utils import dist2weight, infonce_loss 13 | 14 | 15 | pretrain_ckpt = "pretrain_3robots" # name of pretrain model 16 | robot_names = ['barrett', 'allegro', 'shadowhand'] 17 | verbose = False 18 | data_num = 200 19 | 20 | 21 | def main(): 22 | encoder = create_encoder_network(emb_dim=512) 23 | 24 | encoder.load_state_dict( 25 | torch.load( 26 | os.path.join(ROOT_DIR, f'ckpt/pretrain/{pretrain_ckpt}.pth'), 27 | map_location=torch.device('cpu') 28 | ) 29 | ) 30 | 31 | for robot_name in robot_names: 32 | print(f"Robot: {robot_name}") 33 | dataloader = create_dataloader( 34 | SimpleNamespace(**{ 35 | 'batch_size': 1, 36 | 'robot_names': [robot_name], 37 | 'debug_object_names': None, 38 | 'object_pc_type': 'random', 39 | 'num_workers': 4 40 | }), 41 | is_train=True 42 | ) 43 | 44 | orders = [] 45 | for data_idx, data in enumerate(tqdm(dataloader, total=data_num)): 46 | if data_idx == data_num: 47 | break 48 | 49 | pc_1 = data['robot_pc_initial'] 50 | pc_2 = data['robot_pc_target'] 51 | 52 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 53 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 54 | 55 | emb_1 = encoder(pc_1).detach() 56 | emb_2 = encoder(pc_2).detach() 57 | 58 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 59 | loss, similarity = infonce_loss( 60 | emb_1, emb_2, weights=weight, temperature=0.1 61 | ) 62 | 63 | order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1).float().mean() 64 | orders.append(order) 65 | if verbose: 66 | 
print("\torder:", order) 67 | 68 | print(f"Robot: {robot_name}, Mean Order: {sum(orders) / len(orders)}\n") 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /scripts/pretrain_order.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import argparse 6 | import warnings 7 | from types import SimpleNamespace 8 | from tqdm import tqdm 9 | import torch 10 | 11 | from model.network import create_encoder_network 12 | from data_utils.CMapDataset import create_dataloader 13 | from utils.pretrain_utils import dist2weight, infonce_loss 14 | 15 | 16 | def main(args): 17 | encoder = create_encoder_network(emb_dim=512) 18 | 19 | for epoch in args.epoch_list: 20 | print("****************************************************************") 21 | print(f"[Epoch {epoch}]") 22 | encoder.load_state_dict( 23 | torch.load( 24 | os.path.join(ROOT_DIR, f'output/{args.pretrain_ckpt}/state_dict/epoch_{epoch}.pth'), 25 | map_location=torch.device('cpu') 26 | ) 27 | ) 28 | 29 | for robot_name in args.robot_names: 30 | print(f"Robot: {robot_name}") 31 | dataloader = create_dataloader( 32 | SimpleNamespace(**{ 33 | 'batch_size': 1, 34 | 'robot_names': [robot_name], 35 | 'debug_object_names': None, 36 | 'object_pc_type': 'random', 37 | 'num_workers': 4 38 | }), 39 | is_train=True 40 | ) 41 | # print(len(dataloader)) 42 | 43 | orders = [] 44 | for data_idx, data in enumerate(tqdm(dataloader, total=args.data_num)): 45 | if data_idx == args.data_num: 46 | break 47 | 48 | pc_1 = data['robot_pc_initial'] 49 | pc_2 = data['robot_pc_target'] 50 | 51 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 52 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 53 | 54 | emb_1 = encoder(pc_1).detach() 55 | emb_2 = encoder(pc_2).detach() 56 | 57 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 58 | loss, similarity = infonce_loss( 59 | emb_1, emb_2, weights=weight, temperature=0.1 60 | ) 61 | 62 | order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1).float().mean() 63 | orders.append(order) 64 | if args.verbose: 65 | print("\torder:", order) 66 | 67 | print(f"Epoch: {epoch}, Robot: {robot_name}, Mean Order: {sum(orders) / len(orders)}\n") 68 | 69 | 70 | if __name__ == '__main__': 71 | warnings.simplefilter(action='ignore', category=FutureWarning) 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--pretrain_ckpt', type=str, default='pretrain_3robots') 74 | parser.add_argument('--data_num', type=int, default=200) 75 | parser.add_argument('--epoch_list', type=lambda string: string.split(','), 76 | default=['10', '20', '30', '40', '50', '60', '70', '80', '90', '100']) 77 | parser.add_argument('--robot_names', type=lambda string: string.split(','), 78 | default=['barrett', 'allegro', 'shadowhand']) 79 | parser.add_argument('--verbose', action='store_true') 80 | args = parser.parse_args() 81 | main(args) 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import hydra 4 | import warnings 5 | import torch 6 | import pytorch_lightning as pl 7 | from omegaconf import OmegaConf 8 | from pytorch_lightning.loggers import WandbLogger 9 | 10 | ROOT_DIR = 
os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(ROOT_DIR) 12 | 13 | from data_utils.CMapDataset import create_dataloader 14 | from model.network import create_network 15 | from model.module import TrainingModule 16 | 17 | 18 | @hydra.main(version_base="1.2", config_path="configs", config_name="train") 19 | def main(cfg): 20 | print("******************************** [Config] ********************************") 21 | print(OmegaConf.to_yaml(cfg)) 22 | print("******************************** [Config] ********************************") 23 | 24 | pl.seed_everything(cfg.seed) 25 | 26 | last_run_id = None 27 | last_epoch = 0 28 | last_ckpt_file = None 29 | if cfg.load_from_checkpoint: 30 | wandb_dir = f'output/{cfg.name}/log/{cfg.wandb.project}' 31 | last_run_id = os.listdir(wandb_dir)[0] 32 | ckpt_dir = f'{wandb_dir}/{last_run_id}/checkpoints' 33 | ckpt_files = os.listdir(ckpt_dir) 34 | for ckpt_file in ckpt_files: 35 | epoch = int(ckpt_file.split('-')[0].split('=')[1]) 36 | if epoch > last_epoch: 37 | last_epoch = epoch 38 | last_ckpt_file = os.path.join(ckpt_dir, ckpt_file) 39 | print("***************************************************") 40 | print(f"Loading checkpoint from run_id({last_run_id}): epoch {last_epoch}") 41 | print("***************************************************") 42 | 43 | logger = WandbLogger( 44 | name=cfg.name, 45 | save_dir=cfg.wandb.save_dir, 46 | id=last_run_id, 47 | project=cfg.wandb.project 48 | ) 49 | trainer = pl.Trainer( 50 | logger=logger, 51 | accelerator='gpu', 52 | strategy='ddp_find_unused_parameters_true' if (cfg.model.pretrain is not None) else 'auto', 53 | devices=cfg.gpu, 54 | log_every_n_steps=cfg.log_every_n_steps, 55 | max_epochs=cfg.training.max_epochs, 56 | gradient_clip_val=0.1 57 | ) 58 | 59 | dataloader = create_dataloader(cfg.dataset, is_train=True) 60 | 61 | network = create_network(cfg.model, mode='train') 62 | model = TrainingModule( 63 | cfg=cfg.training, 64 | network=network, 65 | epoch_idx=last_epoch 66 | ) 67 | model.train() 68 | 69 | trainer.fit(model, dataloader, ckpt_path=last_ckpt_file) 70 | torch.save(model.network.state_dict(), f'{cfg.training.save_dir}/epoch_{cfg.training.max_epochs}.pth') 71 | 72 | 73 | if __name__ == "__main__": 74 | torch.set_float32_matmul_precision("high") 75 | torch.autograd.set_detect_anomaly(True) 76 | torch.cuda.empty_cache() 77 | torch.multiprocessing.set_sharing_strategy("file_system") 78 | warnings.simplefilter(action='ignore', category=FutureWarning) 79 | main() 80 | -------------------------------------------------------------------------------- /utils/controller.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import json 5 | import trimesh 6 | import torch 7 | import viser 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.hand_model import create_hand_model 13 | from utils.rotation import q_rot6d_to_q_euler 14 | 15 | 16 | def get_link_dir(robot_name, joint_name): 17 | if joint_name.startswith('virtual'): 18 | return None 19 | 20 | if robot_name == 'allegro': 21 | if joint_name in ['joint_0.0', 'joint_4.0', 'joint_8.0', 'joint_13.0']: 22 | return None 23 | link_dir = torch.tensor([0, 0, 1], dtype=torch.float32) 24 | elif robot_name == 'barrett': 25 | if joint_name in ['bh_j11_joint', 'bh_j21_joint']: 26 | return None 27 | link_dir = torch.tensor([-1, 0, 0], dtype=torch.float32) 28 | elif robot_name == 'ezgripper': 29 | 
link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 30 | elif robot_name == 'robotiq_3finger': 31 | if joint_name in ['gripper_fingerB_knuckle', 'gripper_fingerC_knuckle']: 32 | return None 33 | link_dir = torch.tensor([0, 0, -1], dtype=torch.float32) 34 | elif robot_name == 'shadowhand': 35 | if joint_name in ['WRJ2', 'WRJ1']: 36 | return None 37 | if joint_name != 'THJ5': 38 | link_dir = torch.tensor([0, 0, 1], dtype=torch.float32) 39 | else: 40 | link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 41 | elif robot_name == 'leaphand': 42 | if joint_name in ['13']: 43 | return None 44 | if joint_name in ['0', '4', '8']: 45 | link_dir = torch.tensor([1, 0, 0], dtype=torch.float32) 46 | elif joint_name in ['1', '5', '9', '12', '14']: 47 | link_dir = torch.tensor([0, 1, 0], dtype=torch.float32) 48 | else: 49 | link_dir = torch.tensor([0, -1, 0], dtype=torch.float32) 50 | else: 51 | raise NotImplementedError(f"Unknown robot name: {robot_name}!") 52 | 53 | return link_dir 54 | 55 | 56 | def controller(robot_name, q_para): 57 | q_batch = torch.atleast_2d(q_para) 58 | 59 | hand = create_hand_model(robot_name, device=q_batch.device) 60 | joint_orders = hand.get_joint_orders() 61 | pk_chain = hand.pk_chain 62 | if q_batch.shape[-1] != len(pk_chain.get_joint_parameter_names()): 63 | q_batch = q_rot6d_to_q_euler(q_batch) 64 | status = pk_chain.forward_kinematics(q_batch) 65 | 66 | outer_q_batch = [] 67 | inner_q_batch = [] 68 | for batch_idx in range(q_batch.shape[0]): 69 | joint_dots = {} 70 | for frame_name in pk_chain.get_frame_names(): 71 | frame = pk_chain.find_frame(frame_name) 72 | joint = frame.joint 73 | link_dir = get_link_dir(robot_name, joint.name) 74 | if link_dir is None: 75 | continue 76 | 77 | frame_transform = status[frame_name].get_matrix()[batch_idx] 78 | axis_dir = frame_transform[:3, :3] @ joint.axis 79 | link_dir = frame_transform[:3, :3] @ link_dir 80 | normal_dir = torch.cross(axis_dir, link_dir, dim=0) 81 | axis_origin = frame_transform[:3, 3] 82 | origin_dir = -axis_origin / torch.norm(axis_origin) 83 | joint_dots[joint.name] = torch.dot(normal_dir, origin_dir) 84 | 85 | q = q_batch[batch_idx] 86 | lower_q, upper_q = hand.pk_chain.get_joint_limits() 87 | outer_q, inner_q = q.clone(), q.clone() 88 | for joint_name, dot in joint_dots.items(): 89 | idx = joint_orders.index(joint_name) 90 | if robot_name == 'robotiq_3finger': # open -> upper, close -> lower 91 | outer_q[idx] += 0.25 * ((outer_q[idx] - lower_q[idx]) if dot <= 0 else (outer_q[idx] - upper_q[idx])) 92 | inner_q[idx] += 0.15 * ((inner_q[idx] - upper_q[idx]) if dot <= 0 else (inner_q[idx] - lower_q[idx])) 93 | else: # open -> lower, close -> upper 94 | outer_q[idx] += 0.25 * ((lower_q[idx] - outer_q[idx]) if dot >= 0 else (upper_q[idx] - outer_q[idx])) 95 | inner_q[idx] += 0.15 * ((upper_q[idx] - inner_q[idx]) if dot >= 0 else (lower_q[idx] - inner_q[idx])) 96 | outer_q_batch.append(outer_q) 97 | inner_q_batch.append(inner_q) 98 | 99 | outer_q_batch = torch.stack(outer_q_batch, dim=0) 100 | inner_q_batch = torch.stack(inner_q_batch, dim=0) 101 | 102 | if q_para.ndim == 2: # batch 103 | return outer_q_batch.to(q_para.device), inner_q_batch.to(q_para.device) 104 | else: 105 | return outer_q_batch[0].to(q_para.device), inner_q_batch[0].to(q_para.device) 106 | -------------------------------------------------------------------------------- /utils/func_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | ROOT_DIR = 
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(ROOT_DIR) 7 | 8 | 9 | def calculate_depth(robot_pc, object_names): 10 | """ 11 | Calculate the average penetration depth of predicted pc into the object. 12 | 13 | :param robot_pc: (B, N, 3) 14 | :param object_names: list of str, len = B 15 | :return: average penetration depth over the batch, scalar 16 | """ 17 | object_pc_list = [] 18 | normals_list = [] 19 | for object_name in object_names: 20 | name = object_name.split('+') 21 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 22 | object_pc_normals = torch.load(object_path).to(robot_pc.device) 23 | object_pc_list.append(object_pc_normals[:, :3]) 24 | normals_list.append(object_pc_normals[:, 3:]) 25 | object_pc = torch.stack(object_pc_list, dim=0) 26 | normals = torch.stack(normals_list, dim=0) 27 | 28 | distance = torch.cdist(robot_pc, object_pc) 29 | distance, index = torch.min(distance, dim=-1) 30 | index = index.unsqueeze(-1).repeat(1, 1, 3) 31 | object_pc_indexed = torch.gather(object_pc, dim=1, index=index) 32 | normals_indexed = torch.gather(normals, dim=1, index=index) 33 | get_sign = torch.vmap(torch.vmap(lambda x, y: torch.where(torch.dot(x, y) >= 0, 1, -1))) 34 | signed_distance = distance * get_sign(robot_pc - object_pc_indexed, normals_indexed) 35 | signed_distance[signed_distance > 0] = 0 36 | return -torch.mean(signed_distance) 37 | 38 | 39 | def farthest_point_sampling(point_cloud, num_points=1024): 40 | """ 41 | :param point_cloud: (N, 3) or (N, 4), point cloud (with link index) 42 | :param num_points: int, number of sampled points 43 | :return: ((num_points, 3) or (num_points, 4), list), sampled point cloud & selected indices 44 | """ 45 | point_cloud_origin = point_cloud 46 | if point_cloud.shape[1] == 4: 47 | point_cloud = point_cloud[:, :3] 48 | 49 | selected_indices = [0] 50 | distances = torch.norm(point_cloud - point_cloud[selected_indices[-1]], dim=1) 51 | for _ in range(num_points - 1): 52 | farthest_point_idx = torch.argmax(distances) 53 | selected_indices.append(farthest_point_idx) 54 | new_distances = torch.norm(point_cloud - point_cloud[farthest_point_idx], dim=1) 55 | distances = torch.min(distances, new_distances) 56 | sampled_point_cloud = point_cloud_origin[selected_indices] 57 | 58 | return sampled_point_cloud, selected_indices 59 | -------------------------------------------------------------------------------- /utils/hand_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import math 5 | import random 6 | import numpy as np 7 | import torch 8 | import trimesh 9 | import pytorch_kinematics as pk 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(ROOT_DIR) 13 | 14 | from utils.func_utils import farthest_point_sampling 15 | from utils.mesh_utils import load_link_geometries 16 | from utils.rotation import * 17 | 18 | 19 | class HandModel: 20 | def __init__( 21 | self, 22 | robot_name, 23 | urdf_path, 24 | meshes_path, 25 | links_pc_path, 26 | device, 27 | link_num_points=512 28 | ): 29 | self.robot_name = robot_name 30 | self.urdf_path = urdf_path 31 | self.meshes_path = meshes_path 32 | self.device = device 33 | 34 | self.pk_chain = pk.build_chain_from_urdf(open(urdf_path).read()).to(dtype=torch.float32, device=device) 35 | self.dof = len(self.pk_chain.get_joint_parameter_names()) 36 | if os.path.exists(links_pc_path): # The file may not exist yet when generating the robot links pc.
37 | links_pc_data = torch.load(links_pc_path, map_location=device) 38 | self.links_pc = links_pc_data['filtered'] 39 | self.links_pc_original = links_pc_data['original'] 40 | else: 41 | self.links_pc = None 42 | self.links_pc_original = None 43 | 44 | self.meshes = load_link_geometries(robot_name, self.urdf_path, self.pk_chain.get_link_names()) 45 | 46 | self.vertices = {} 47 | removed_links = json.load(open(os.path.join(ROOT_DIR, 'data_utils/removed_links.json')))[robot_name] 48 | for link_name, link_mesh in self.meshes.items(): 49 | if link_name in removed_links: # remove links unrelated to contact 50 | continue 51 | v = link_mesh.sample(link_num_points) 52 | self.vertices[link_name] = v 53 | 54 | self.frame_status = None 55 | 56 | def get_joint_orders(self): 57 | return [joint.name for joint in self.pk_chain.get_joints()] 58 | 59 | def update_status(self, q): 60 | if q.shape[-1] != self.dof: 61 | q = q_rot6d_to_q_euler(q) 62 | self.frame_status = self.pk_chain.forward_kinematics(q.to(self.device)) 63 | 64 | def get_transformed_links_pc(self, q=None, links_pc=None): 65 | """ 66 | Use robot link pc & q value to get point cloud. 67 | 68 | :param q: (6 + DOF,), joint values (euler representation) 69 | :param links_pc: {link_name: (N_link, 3)}, robot links pc dict, not None only for get_sampled_pc() 70 | :return: point cloud: (N, 4), with link index 71 | """ 72 | if q is None: 73 | q = torch.zeros(self.dof, dtype=torch.float32, device=self.device) 74 | self.update_status(q) 75 | if links_pc is None: 76 | links_pc = self.links_pc 77 | 78 | all_pc_se3 = [] 79 | for link_index, (link_name, link_pc) in enumerate(links_pc.items()): 80 | if not torch.is_tensor(link_pc): 81 | link_pc = torch.tensor(link_pc, dtype=torch.float32, device=q.device) 82 | n_link = link_pc.shape[0] 83 | se3 = self.frame_status[link_name].get_matrix()[0].to(q.device) 84 | homogeneous_tensor = torch.ones(n_link, 1, device=q.device) 85 | link_pc_homogeneous = torch.cat([link_pc.to(q.device), homogeneous_tensor], dim=1) 86 | link_pc_se3 = (link_pc_homogeneous @ se3.T)[:, :3] 87 | index_tensor = torch.full([n_link, 1], float(link_index), device=q.device) 88 | link_pc_se3_index = torch.cat([link_pc_se3, index_tensor], dim=1) 89 | all_pc_se3.append(link_pc_se3_index) 90 | all_pc_se3 = torch.cat(all_pc_se3, dim=0) 91 | 92 | return all_pc_se3 93 | 94 | def get_sampled_pc(self, q=None, num_points=512): 95 | """ 96 | :param q: (9 + DOF,), joint values (rot6d representation) 97 | :param num_points: int, number of sampled points 98 | :return: ((N, 3), list), sampled point cloud (numpy) & index 99 | """ 100 | if q is None: 101 | q = self.get_canonical_q() 102 | 103 | sampled_pc = self.get_transformed_links_pc(q, self.vertices) 104 | return farthest_point_sampling(sampled_pc, num_points) 105 | 106 | def get_canonical_q(self): 107 | """ For visualization purposes only. """ 108 | lower, upper = self.pk_chain.get_joint_limits() 109 | canonical_q = torch.tensor(lower) * 0.75 + torch.tensor(upper) * 0.25 110 | canonical_q[:6] = 0 111 | return canonical_q 112 | 113 | def get_initial_q(self, q=None, max_angle: float = math.pi / 6): 114 | """ 115 | Compute the robot initial joint value q based on the target grasp. 116 | Root translation is not considered since the point cloud will be normalized to zero-mean. 
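        A minimal usage sketch (illustrative only; 'shadowhand' is one of the supported robots):
            hand = create_hand_model('shadowhand')
            q_init = hand.get_initial_q(q_target)  # q_target: any grasp q from the dataset
        Joint values are re-sampled toward their lower limits, and the root rotation is randomly
        perturbed by up to max_angle about an axis orthogonal to the root-to-origin direction.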
117 | 118 | :param q: (6 + DOF,) or (9 + DOF,), joint values (euler/rot6d representation) 119 | :param max_angle: float, maximum angle of the random rotation 120 | :return: initial q: (6 + DOF,), euler representation 121 | """ 122 | if q is None: # random sample root rotation and joint values 123 | q_initial = torch.zeros(self.dof, dtype=torch.float32, device=self.device) 124 | 125 | q_initial[3:6] = (torch.rand(3) * 2 - 1) * torch.pi 126 | q_initial[5] /= 2 127 | 128 | lower_joint_limits, upper_joint_limits = self.pk_chain.get_joint_limits() 129 | lower_joint_limits = torch.tensor(lower_joint_limits[6:], dtype=torch.float32) 130 | upper_joint_limits = torch.tensor(upper_joint_limits[6:], dtype=torch.float32) 131 | portion = random.uniform(0.65, 0.85) 132 | q_initial[6:] = lower_joint_limits * portion + upper_joint_limits * (1 - portion) 133 | else: 134 | if len(q) == self.dof: 135 | q = q_euler_to_q_rot6d(q) 136 | q_initial = q.clone() 137 | 138 | # compute random initial rotation 139 | direction = - q_initial[:3] / torch.norm(q_initial[:3]) 140 | angle = torch.tensor(random.uniform(0, max_angle), device=q.device) # sample rotation angle 141 | axis = torch.randn(3).to(q.device) # sample rotation axis 142 | axis -= torch.dot(axis, direction) * direction # ensure orthogonality 143 | axis = axis / torch.norm(axis) 144 | random_rotation = axisangle_to_matrix(axis, angle).to(q.device) 145 | rotation_matrix = random_rotation @ rot6d_to_matrix(q_initial[3:9]) 146 | q_initial[3:9] = matrix_to_rot6d(rotation_matrix) 147 | 148 | # compute random initial joint values 149 | lower_joint_limits, upper_joint_limits = self.pk_chain.get_joint_limits() 150 | lower_joint_limits = torch.tensor(lower_joint_limits[6:], dtype=torch.float32) 151 | upper_joint_limits = torch.tensor(upper_joint_limits[6:], dtype=torch.float32) 152 | portion = random.uniform(0.65, 0.85) 153 | q_initial[9:] = lower_joint_limits * portion + upper_joint_limits * (1 - portion) 154 | # q_initial[9:] = torch.zeros_like(q_initial[9:], dtype=q.dtype, device=q.device) 155 | 156 | q_initial = q_rot6d_to_q_euler(q_initial) 157 | 158 | return q_initial 159 | 160 | def get_trimesh_q(self, q): 161 | """ Return the hand trimesh object corresponding to the input joint value q. """ 162 | self.update_status(q) 163 | 164 | scene = trimesh.Scene() 165 | for link_name in self.vertices: 166 | mesh_transform_matrix = self.frame_status[link_name].get_matrix()[0].cpu().numpy() 167 | scene.add_geometry(self.meshes[link_name].copy().apply_transform(mesh_transform_matrix)) 168 | 169 | vertices = [] 170 | faces = [] 171 | vertex_offset = 0 172 | for geom in scene.geometry.values(): 173 | if isinstance(geom, trimesh.Trimesh): 174 | vertices.append(geom.vertices) 175 | faces.append(geom.faces + vertex_offset) 176 | vertex_offset += len(geom.vertices) 177 | all_vertices = np.vstack(vertices) 178 | all_faces = np.vstack(faces) 179 | 180 | parts = {} 181 | for link_name in self.meshes: 182 | mesh_transform_matrix = self.frame_status[link_name].get_matrix()[0].cpu().numpy() 183 | part_mesh = self.meshes[link_name].copy().apply_transform(mesh_transform_matrix) 184 | parts[link_name] = part_mesh 185 | 186 | return_dict = { 187 | 'visual': trimesh.Trimesh(vertices=all_vertices, faces=all_faces), 188 | 'parts': parts 189 | } 190 | return return_dict 191 | 192 | def get_trimesh_se3(self, transform, index): 193 | """ Return the hand trimesh object corresponding to the input transform. 
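        Typically called with a per-link SE(3) dict such as the one produced by compute_link_pose() in validate mode, where each value has shape (B, 4, 4); index selects which batch element to assemble.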
""" 194 | scene = trimesh.Scene() 195 | for link_name in transform: 196 | mesh_transform_matrix = transform[link_name][index].cpu().numpy() 197 | scene.add_geometry(self.meshes[link_name].copy().apply_transform(mesh_transform_matrix)) 198 | 199 | vertices = [] 200 | faces = [] 201 | vertex_offset = 0 202 | for geom in scene.geometry.values(): 203 | if isinstance(geom, trimesh.Trimesh): 204 | vertices.append(geom.vertices) 205 | faces.append(geom.faces + vertex_offset) 206 | vertex_offset += len(geom.vertices) 207 | all_vertices = np.vstack(vertices) 208 | all_faces = np.vstack(faces) 209 | 210 | return trimesh.Trimesh(vertices=all_vertices, faces=all_faces) 211 | 212 | 213 | def create_hand_model( 214 | robot_name, 215 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 216 | num_points=512 217 | ): 218 | json_path = os.path.join(ROOT_DIR, 'data/data_urdf/robot/urdf_assets_meta.json') 219 | urdf_assets_meta = json.load(open(json_path)) 220 | urdf_path = os.path.join(ROOT_DIR, urdf_assets_meta['urdf_path'][robot_name]) 221 | meshes_path = os.path.join(ROOT_DIR, urdf_assets_meta['meshes_path'][robot_name]) 222 | links_pc_path = os.path.join(ROOT_DIR, f'data/PointCloud/robot/{robot_name}.pt') 223 | hand_model = HandModel(robot_name, urdf_path, meshes_path, links_pc_path, device, num_points) 224 | return hand_model 225 | -------------------------------------------------------------------------------- /utils/mesh_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | !!! This code file is not organized, there may be relatively chaotic writing and inconsistent comment formats. !!! 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import trimesh 8 | import xml.etree.ElementTree as ET 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | 12 | def as_mesh(scene_or_mesh): 13 | """ 14 | Convert a possible scene to a mesh. 15 | 16 | If conversion occurs, the returned mesh has only vertex and face data. 
17 | """ 18 | if isinstance(scene_or_mesh, trimesh.Scene): 19 | if len(scene_or_mesh.geometry) == 0: 20 | mesh = None # empty scene 21 | else: 22 | # we lose texture information here 23 | mesh = trimesh.util.concatenate( 24 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)for g in scene_or_mesh.geometry.values()) 25 | ) 26 | else: 27 | assert isinstance(scene_or_mesh, trimesh.Trimesh) 28 | mesh = scene_or_mesh 29 | return mesh 30 | 31 | 32 | def extract_colors_from_urdf(urdf_path): 33 | tree = ET.parse(urdf_path) 34 | root = tree.getroot() 35 | 36 | global_materials = {} 37 | 38 | for material in root.findall("material"): 39 | name = material.attrib["name"] 40 | color_elem = material.find("color") 41 | if color_elem is not None and "rgba" in color_elem.attrib: 42 | rgba = [float(c) for c in color_elem.attrib["rgba"].split()] 43 | global_materials[name] = rgba 44 | 45 | link_colors = {} 46 | 47 | for link in root.iter("link"): 48 | link_name = link.attrib["name"] 49 | visual = link.find("./visual") 50 | if visual is not None: 51 | material = visual.find("./material") 52 | if material is not None: 53 | color = material.find("color") 54 | if color is not None and "rgba" in color.attrib: 55 | rgba = [float(c) for c in color.attrib["rgba"].split()] 56 | link_colors[link_name] = rgba 57 | elif "name" in material.attrib: 58 | material_name = material.attrib["name"] 59 | if material_name in global_materials: 60 | link_colors[link_name] = global_materials[material_name] 61 | 62 | return link_colors 63 | 64 | 65 | def parse_origin(element): 66 | """Parse the origin element for translation and rotation.""" 67 | origin = element.find("origin") 68 | xyz = np.zeros(3) 69 | rotation = np.eye(3) 70 | if origin is not None: 71 | xyz = np.fromstring(origin.attrib.get("xyz", "0 0 0"), sep=" ") 72 | rpy = np.fromstring(origin.attrib.get("rpy", "0 0 0"), sep=" ") 73 | rotation = R.from_euler("xyz", rpy).as_matrix() 74 | return xyz, rotation 75 | 76 | 77 | def apply_transform(mesh, translation, rotation): 78 | """Apply translation and rotation to a mesh.""" 79 | # mesh.apply_translation(-mesh.centroid) 80 | transform = np.eye(4) 81 | transform[:3, :3] = rotation 82 | transform[:3, 3] = translation 83 | mesh.apply_transform(transform) 84 | return mesh 85 | 86 | 87 | def create_primitive_mesh(geometry, translation, rotation): 88 | """Create a trimesh object from primitive geometry definitions with transformations.""" 89 | if geometry.tag.endswith("box"): 90 | size = np.fromstring(geometry.attrib["size"], sep=" ") 91 | mesh = trimesh.creation.box(extents=size) 92 | elif geometry.tag.endswith("sphere"): 93 | radius = float(geometry.attrib["radius"]) 94 | mesh = trimesh.creation.icosphere(radius=radius) 95 | elif geometry.tag.endswith("cylinder"): 96 | radius = float(geometry.attrib["radius"]) 97 | length = float(geometry.attrib["length"]) 98 | mesh = trimesh.creation.cylinder(radius=radius, height=length) 99 | else: 100 | raise ValueError(f"Unsupported geometry type: {geometry.tag}") 101 | return apply_transform(mesh, translation, rotation) 102 | 103 | 104 | def load_link_geometries(robot_name, urdf_path, link_names, collision=False): 105 | """Load geometries (trimesh objects) for specified links from a URDF file, considering origins.""" 106 | urdf_dir = os.path.dirname(urdf_path) 107 | tree = ET.parse(urdf_path) 108 | root = tree.getroot() 109 | 110 | link_geometries = {} 111 | link_colors_from_urdf = extract_colors_from_urdf(urdf_path) 112 | 113 | for link in root.findall("link"): 114 | link_name = 
link.attrib["name"] 115 | link_color = link_colors_from_urdf.get(link_name, None) 116 | if link_name in link_names: 117 | geom_index = "collision" if collision else "visual" 118 | link_mesh = [] 119 | for visual in link.findall(".//" + geom_index): 120 | geometry = visual.find("geometry") 121 | xyz, rotation = parse_origin(visual) 122 | try: 123 | if geometry[0].tag.endswith("mesh"): 124 | mesh_filename = geometry[0].attrib["filename"] 125 | full_mesh_path = os.path.join(urdf_dir, mesh_filename) 126 | mesh = as_mesh(trimesh.load(full_mesh_path)) 127 | scale = np.fromstring(geometry[0].attrib.get("scale", "1 1 1"), sep=" ") 128 | mesh.apply_scale(scale) 129 | mesh = apply_transform(mesh, xyz, rotation) 130 | link_mesh.append(mesh) 131 | else: # Handle primitive shapes 132 | mesh = create_primitive_mesh(geometry[0], xyz, rotation) 133 | scale = np.fromstring(geometry[0].attrib.get("scale", "1 1 1"), sep=" ") 134 | mesh.apply_scale(scale) 135 | link_mesh.append(mesh) 136 | except Exception as e: 137 | print(f"Failed to load geometry for {link_name}: {e}") 138 | if len(link_mesh) == 0: 139 | continue 140 | elif len(link_mesh) > 1: 141 | link_trimesh = as_mesh(trimesh.Scene(link_mesh)) 142 | elif len(link_mesh) == 1: 143 | link_trimesh = link_mesh[0] 144 | 145 | if link_color is not None: 146 | link_trimesh.visual.face_colors = np.array(link_color) 147 | link_geometries[link_name] = link_trimesh 148 | 149 | return link_geometries 150 | -------------------------------------------------------------------------------- /utils/multilateration.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | @typing.no_type_check 8 | def estimate_p( 9 | P: torch.FloatTensor, R: torch.FloatTensor, W: Optional[torch.FloatTensor] = None 10 | ) -> torch.FloatTensor: 11 | assert P.ndim == 3 # N x D x 1 12 | assert R.ndim == 1 # N 13 | assert P.shape[0] == R.shape[0] 14 | assert P.shape[1] in {2, 3} 15 | 16 | N, D, _ = P.shape 17 | 18 | if W is None: 19 | W = torch.ones(N, device=P.device) 20 | assert W.ndim == 1 # N 21 | W = W[:, None, None] 22 | 23 | # Shared stuff. 24 | Pt = P.permute(0, 2, 1) 25 | PPt = P @ Pt 26 | PtP = (Pt @ P).squeeze() 27 | I = torch.eye(D, device=P.device) 28 | NI = I[None].repeat(N, 1, 1) 29 | PtP_minus_r2 = (PtP - R**2)[:, None, None] 30 | 31 | # These are ripped straight from the paper, with weighting passed through. 32 | a = (W * (PtP_minus_r2 * P)).mean(dim=0) 33 | B = (W * (-2 * PPt - PtP_minus_r2 * NI)).mean(dim=0) 34 | c = (W * P).mean(dim=0) 35 | f = a + B @ c + 2 * c @ c.T @ c 36 | H = -2 * PPt.mean(dim=0) + 2 * c @ c.T 37 | q = -torch.linalg.inv(H) @ f 38 | p = q + c 39 | 40 | return p 41 | 42 | 43 | def multilateration(dro, fixed_pc): 44 | """ 45 | Compute the target point cloud described by D(R,O) matrix & fixed_pc 46 | 47 | :param dro: (B, N, N), point-wise relative distance matrix between target point cloud & fixed point cloud 48 | :param fixed_pc: (B, N, 3), point cloud as a reference for relative distance 49 | :return: (B, N, 3), the target point cloud 50 | """ 51 | assert dro.ndim == 3 and fixed_pc.ndim == 3, "multilateration() requires batch data."
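    # Shape sketch (illustrative, e.g. B grasps of N = 512 points):
    #   fixed_pc.unsqueeze(-1): (B, N, 3, 1) -- anchor points, shared across target points
    #   dro[b, i, :]:           (N,)         -- distances from target point i to every anchor
    # The inner vmap runs estimate_p() once per target point (in_dims=(None, 0) keeps the anchors
    # fixed while sweeping the rows of dro); the outer vmap maps over the batch, giving a
    # (B, N, 3, 1) result whose trailing singleton column is stripped by [..., 0] below.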
52 | v_est_p = torch.vmap(torch.vmap(estimate_p, in_dims=(None, 0))) 53 | target_pc = v_est_p(fixed_pc.unsqueeze(-1), dro)[..., 0] 54 | return target_pc -------------------------------------------------------------------------------- /utils/optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cvxpy as cp 4 | from cvxpylayers.torch import CvxpyLayer 5 | 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(ROOT_DIR) 8 | 9 | from utils.rotation import * 10 | 11 | 12 | def process_transform(pk_chain, transform, device=None): 13 | """ Compute extra link transform, and convert SE3 transform to only translation. """ 14 | new_transform = transform.copy() 15 | for name in pk_chain.get_frame_names(exclude_fixed=False): 16 | if name.startswith('extra'): 17 | frame = pk_chain.find_frame(name) 18 | parent_name = pk_chain.idx_to_frame[pk_chain.parents_indices[pk_chain.frame_to_idx[name]][-2].item()] 19 | new_transform[name] = new_transform[parent_name] @ frame.joint.offset.get_matrix()[0] 20 | for name, se3 in new_transform.items(): 21 | new_transform[name] = se3[:, :3, 3] 22 | if device is not None: 23 | new_transform[name] = new_transform[name].to(device) 24 | 25 | return new_transform 26 | 27 | 28 | def jacobian(pk_chain, q, frame_X_dict, frame_names): 29 | """ 30 | Calculate Jacobian (dX/dq) of all frames 31 | 32 | Notation: (similar as https://manipulation.csail.mit.edu/pick.html#monogram) 33 | J: jacobian, X: transform, R: rotation, p: position, v: velocity, w: angular velocity 34 | <>_BA_C: of frame A measured from frame B expressed in frame C 35 | W: world frame, J: joint frame, F: link frame 36 | 37 | :param pk_chain: get from pk.build_chain_from_urdf() 38 | :param q: (6 + DOF,) or (B, 6 + DOF), joint values (euler representation) 39 | :return: Jacobian: {frame_name: (B, 6, num_joints)} 40 | """ 41 | jacobian_dict = {} 42 | 43 | q = torch.atleast_2d(q) 44 | batch_size = q.shape[0] 45 | joint_names = pk_chain.get_joint_parameter_names() 46 | num_joints = len(joint_names) 47 | joint_name2idx = {name: idx for idx, name in enumerate(joint_names)} 48 | 49 | frames = [pk_chain.find_frame(name) for name in pk_chain.get_joint_parent_frame_names()] 50 | idx = lambda frame: joint_name2idx[frame.joint.name] 51 | 52 | transfer_X = {} 53 | for frame in frames: 54 | q_frame = q[:, idx(frame)] 55 | if frame.joint.joint_type == 'prismatic': 56 | q_frame = q_frame.unsqueeze(-1) 57 | transfer_X[idx(frame)] = frame.get_transform(q_frame).get_matrix() 58 | 59 | frame_X_dict = {f: frame_X_dict[f] for f in frame_X_dict if f in frame_names} 60 | 61 | for frame_name, frame_X in frame_X_dict.items(): 62 | jacobian = torch.zeros((batch_size, 6, num_joints), dtype=pk_chain.dtype, device=pk_chain.device) 63 | 64 | R_WF = frame_X.get_matrix()[:, :3, :3] 65 | X_JF = torch.eye(4, dtype=pk_chain.dtype, device=pk_chain.device).repeat(batch_size, 1, 1) 66 | for frame_idx in reversed(pk_chain.parents_indices[pk_chain.frame_to_idx[frame_name]].tolist()): 67 | frame = pk_chain.find_frame(pk_chain.idx_to_frame[frame_idx]) 68 | joint = frame.joint 69 | if joint.joint_type == 'fixed': 70 | if joint.offset is not None: 71 | X_JF = joint.offset.get_matrix() @ X_JF 72 | continue 73 | 74 | R_FJ = X_JF[:, :3, :3].mT 75 | R_WJ = R_WF @ R_FJ 76 | p_JF_J = X_JF[:, :3, 3][:, :, None] 77 | w_WJ_J = joint.axis[None, :, None].repeat(batch_size, 1, 1) 78 | if joint.joint_type == 'revolute': 79 | jacobian_v = R_WJ @ 
torch.cross(w_WJ_J, p_JF_J, dim=1) 80 | jacobian_w = R_WJ @ w_WJ_J 81 | elif joint.joint_type == 'prismatic': 82 | jacobian_v = R_WJ @ w_WJ_J 83 | jacobian_w = torch.zeros([batch_size, 3, 1], dtype=jacobian_v.dtype, device=jacobian_v.device) 84 | else: 85 | raise NotImplementedError(f"Unknown joint_type: {joint.joint_type}") 86 | 87 | joint_idx = joint_name2idx[joint.name] 88 | X_JF = transfer_X[joint_idx] @ X_JF 89 | jacobian[:, :, joint_idx] = torch.cat([jacobian_v[..., 0], jacobian_w[..., 0]], dim=1) 90 | 91 | jacobian_dict[frame_name] = jacobian 92 | return jacobian_dict 93 | 94 | 95 | def create_problem(pk_chain, frame_names): 96 | """ 97 | Only use all frame positions (ignore rotation) to optimize joint values. 98 | 99 | :param pk_chain: get from pk.build_chain_from_urdf() 100 | :param frame_names: list of frame names to optimize 101 | :return: CvxpyLayer() 102 | """ 103 | n_joint = len(pk_chain.get_joint_parameter_names()) 104 | 105 | delta_q = cp.Variable(n_joint) 106 | 107 | q = cp.Parameter(n_joint) 108 | jacobian = {} 109 | frame_xyz = {} 110 | target_frame_xyz = {} 111 | 112 | objective_expr = 0 113 | for link_name in frame_names: 114 | frame_xyz[link_name] = cp.Parameter(3) 115 | target_frame_xyz[link_name] = cp.Parameter(3) 116 | 117 | jacobian[link_name] = cp.Parameter((3, n_joint)) 118 | delta_frame_xyz = jacobian[link_name] @ delta_q 119 | 120 | predict_frame_xyz = frame_xyz[link_name] + delta_frame_xyz 121 | objective_expr += cp.norm2(predict_frame_xyz - target_frame_xyz[link_name]) 122 | objective = cp.Minimize(objective_expr) 123 | 124 | lower_joint_limits, upper_joint_limits = pk_chain.get_joint_limits() 125 | upper_limit = cp.minimum(0.5, upper_joint_limits - q) 126 | lower_limit = cp.maximum(-0.5, lower_joint_limits - q) 127 | constraints = [delta_q <= upper_limit, delta_q >= lower_limit] 128 | problem = cp.Problem(objective, constraints) 129 | 130 | layer = CvxpyLayer( 131 | problem, 132 | parameters=[ 133 | q, 134 | *frame_xyz.values(), 135 | *target_frame_xyz.values(), 136 | *jacobian.values() 137 | ], 138 | variables=[delta_q] 139 | ) 140 | return layer 141 | 142 | 143 | def optimization(pk_chain, layer, initial_q, transform, n_iter=64, verbose=False): 144 | if initial_q.shape[-1] != len(pk_chain.get_joint_parameter_names()): 145 | initial_q = q_rot6d_to_q_euler(initial_q) 146 | q = initial_q.clone() 147 | 148 | for i in range(n_iter): 149 | status = pk_chain.forward_kinematics(q) 150 | jacobians = jacobian(pk_chain, q, status, transform.keys()) 151 | 152 | frame_xyz = {} 153 | target_frame_xyz = {} 154 | jacobians_xyz = {} 155 | for link_name, link_jacobian in jacobians.items(): 156 | frame_xyz[link_name] = status[link_name].get_matrix()[:, :3, 3] 157 | target_frame_xyz[link_name] = transform[link_name] 158 | jacobians_xyz[link_name] = link_jacobian[:, :3, :] 159 | 160 | delta_q = layer( 161 | q, 162 | *list(frame_xyz.values()), 163 | *list(target_frame_xyz.values()), 164 | *list(jacobians_xyz.values()), 165 | ) 166 | q += delta_q[0] 167 | if verbose: 168 | print(f'[Step {i}], delta_q norm: {delta_q[0].norm()}') 169 | if delta_q[0].norm() < 0.3: 170 | if verbose: 171 | print("Converged at iteration:", i) 172 | break 173 | return q 174 | -------------------------------------------------------------------------------- /utils/pretrain_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is sourced with some modifications made, from 3 | 
https://github.com/r-pad/taxpose/blob/0c4298fa0486fd09e63bf24d618a579b66ba0f18/taxpose/utils/emb_losses.py. 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def dist2weight(xyz, func=None): 11 | d = (xyz.unsqueeze(1) - xyz.unsqueeze(2)).norm(dim=-1) 12 | if func is not None: 13 | d = func(d) 14 | w = d / d.max(dim=-1, keepdims=True)[0] 15 | w = w + torch.eye(d.shape[-1], device=d.device).unsqueeze(0).tile( 16 | [d.shape[0], 1, 1] 17 | ) 18 | return w 19 | 20 | 21 | def infonce_loss(phi_1, phi_2, weights=None, temperature=0.1): 22 | B, N, D = phi_1.shape 23 | 24 | # cosine similarity 25 | phi_1 = F.normalize(phi_1, dim=2) 26 | phi_2 = F.normalize(phi_2, dim=2) 27 | similarity = phi_1 @ phi_2.mT 28 | 29 | target = torch.arange(N, device=similarity.device).tile([B, 1]) 30 | if weights is None: 31 | weights = 1.0 32 | loss = F.cross_entropy(torch.log(weights) + (similarity / temperature), target) 33 | 34 | return loss, similarity 35 | 36 | 37 | def mean_order(similarity): 38 | order = (similarity > similarity.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)).sum(-1) 39 | return order.float().mean() / similarity.shape[-1] 40 | -------------------------------------------------------------------------------- /utils/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial.transform import Rotation 3 | 4 | def matrix_to_euler(matrix): 5 | device = matrix.device 6 | # forward_kinematics() requires intrinsic euler ('XYZ') 7 | euler = Rotation.from_matrix(matrix.cpu().numpy()).as_euler('XYZ') 8 | return torch.tensor(euler, dtype=torch.float32, device=device) 9 | 10 | def euler_to_matrix(euler): 11 | device = euler.device 12 | matrix = Rotation.from_euler('XYZ', euler.cpu().numpy()).as_matrix() 13 | return torch.tensor(matrix, dtype=torch.float32, device=device) 14 | 15 | def matrix_to_rot6d(matrix): 16 | return matrix.T.reshape(9)[:6] 17 | 18 | def rot6d_to_matrix(rot6d): 19 | x = normalize(rot6d[..., 0:3]) 20 | y = normalize(rot6d[..., 3:6]) 21 | a = normalize(x + y) 22 | b = normalize(x - y) 23 | x = normalize(a + b) 24 | y = normalize(a - b) 25 | z = normalize(torch.cross(x, y, dim=-1)) 26 | matrix = torch.stack([x, y, z], dim=-2).mT 27 | return matrix 28 | 29 | def euler_to_rot6d(euler): 30 | matrix = euler_to_matrix(euler) 31 | return matrix_to_rot6d(matrix) 32 | 33 | def rot6d_to_euler(rot6d): 34 | matrix = rot6d_to_matrix(rot6d) 35 | return matrix_to_euler(matrix) 36 | 37 | def axisangle_to_matrix(axis, angle): 38 | (x, y, z), c, s = axis, torch.cos(angle), torch.sin(angle) 39 | return torch.tensor([ 40 | [(1 - c) * x * x + c, (1 - c) * x * y - s * z, (1 - c) * x * z + s * y], 41 | [(1 - c) * x * y + s * z, (1 - c) * y * y + c, (1 - c) * y * z - s * x], 42 | [(1 - c) * x * z - s * y, (1 - c) * y * z + s * x, (1 - c) * z * z + c] 43 | ]) 44 | 45 | def euler_to_quaternion(euler): 46 | device = euler.device 47 | quaternion = Rotation.from_euler('XYZ', euler.cpu().numpy()).as_quat() 48 | return torch.tensor(quaternion, dtype=torch.float32, device=device) 49 | 50 | def normalize(v): 51 | return v / torch.norm(v, dim=-1, keepdim=True) 52 | 53 | def q_euler_to_q_rot6d(q_euler): 54 | return torch.cat([q_euler[..., :3], euler_to_rot6d(q_euler[..., 3:6]), q_euler[..., 6:]], dim=-1) 55 | 56 | def q_rot6d_to_q_euler(q_rot6d): 57 | return torch.cat([q_rot6d[..., :3], rot6d_to_euler(q_rot6d[..., 3:9]), q_rot6d[..., 9:]], dim=-1) 58 | 59 | 60 | if __name__ == '__main__': 61 | """ Test correctness of above 
functions, no need to compare euler angle due to singularity. """ 62 | test_euler = torch.rand(3) * 2 * torch.pi 63 | 64 | test_matrix = euler_to_matrix(test_euler) 65 | test_euler_prime = matrix_to_euler(test_matrix) 66 | test_matrix_prime = euler_to_matrix(test_euler_prime) 67 | assert torch.allclose(test_matrix, test_matrix_prime), \ 68 | f"Original Matrix: {test_matrix}, Converted Matrix: {test_matrix_prime}" 69 | 70 | test_rot6d = matrix_to_rot6d(test_matrix) 71 | test_matrix_prime = rot6d_to_matrix(test_rot6d) 72 | assert torch.allclose(test_matrix, test_matrix_prime),\ 73 | f"Original Matrix: {test_matrix}, Converted Matrix: {test_matrix_prime}" 74 | 75 | test_rot6d_prime = matrix_to_rot6d(test_matrix_prime) 76 | assert torch.allclose(test_rot6d, test_rot6d_prime), \ 77 | f"Original Rot6D: {test_rot6d}, Converted Rot6D: {test_rot6d_prime}" 78 | 79 | test_euler_prime = rot6d_to_euler(test_rot6d) 80 | test_rot6d_prime = euler_to_rot6d(test_euler_prime) 81 | assert torch.allclose(test_rot6d, test_rot6d_prime), \ 82 | f"Original Rot6D: {test_rot6d}, Converted Rot6D: {test_rot6d_prime}" 83 | 84 | print("All Tests Passed!") 85 | -------------------------------------------------------------------------------- /utils/se3_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_se3_transform(P, Q): 5 | """ 6 | Compute SE3 transform between two point clouds. 7 | 8 | :param P: (N, 3) or (B, N, 3), point cloud (w/ or w/o batch) 9 | :param Q: same as P 10 | :return: SE3 transform between P and Q, (4, 4) or (B, 4, 4) 11 | """ 12 | assert P.shape == Q.shape 13 | 14 | if P.ndim == 2: # (N, 3) 15 | P_mean = torch.mean(P, dim=0) 16 | Q_mean = torch.mean(Q, dim=0) 17 | P_prime = P - P_mean 18 | Q_prime = Q - Q_mean 19 | H = P_prime.T @ Q_prime 20 | U, _, Vt = torch.linalg.svd(H) 21 | V = Vt.T 22 | R = V @ U.T 23 | if torch.linalg.det(R) < 0: 24 | V[:, -1] *= -1 25 | R = V @ U.T 26 | t = Q_mean - R @ P_mean 27 | 28 | T = torch.eye(4).to(P.device) 29 | T[:3, :3] = R 30 | T[:3, 3] = t 31 | elif P.ndim == 3: # (B, N, 3) 32 | P_mean = torch.mean(P, dim=1, keepdim=True) 33 | Q_mean = torch.mean(Q, dim=1, keepdim=True) 34 | P_prime = P - P_mean 35 | Q_prime = Q - Q_mean 36 | H = P_prime.transpose(1, 2) @ Q_prime 37 | U, _, Vt = torch.linalg.svd(H) 38 | V = Vt.transpose(1, 2) 39 | R = V @ U.transpose(1, 2) 40 | det_R = torch.linalg.det(R) 41 | VV = V.clone() 42 | VV[:, :, -1] *= torch.where(det_R < 0, -1.0, 1.0).unsqueeze(-1) 43 | RR = VV @ U.transpose(1, 2) 44 | t = Q_mean.squeeze(1) - (RR @ P_mean.transpose(1, 2)).squeeze(-1) 45 | 46 | T = torch.eye(4).repeat(P.shape[0], 1, 1).to(P.device) 47 | T[:, :3, :3] = RR 48 | T[:, :3, 3] = t 49 | else: 50 | raise RuntimeError('Unexpected point cloud shape!') 51 | 52 | return T 53 | 54 | 55 | def se3_transform_point_cloud(P, Transform): 56 | """ 57 | Apply SE3 transform on point cloud. 58 | 59 | :param P: (N, 3) or (B, N, 3), point cloud (w/ or w/o batch) 60 | :param Transform: SE3 transform (w/ or w/o batch) 61 | :return: Point Cloud after SE3 transform, (N, 3) or (B, N, 3) 62 | """ 63 | P_prime = torch.cat((P, torch.ones([*P.shape[:-1], 1], dtype=torch.float32, device=P.device)), dim=-1) 64 | P_transformed = P_prime @ Transform.mT 65 | return P_transformed[..., :3] 66 | 67 | 68 | def compute_link_pose(robot_links_pc, predict_pcs, is_train=True): 69 | """ 70 | Calculate link poses of the predicted pc. 
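    Each link pose is recovered independently by the SVD-based (Procrustes) fit in compute_se3_transform(),
    matching the link's canonical points against the corresponding contiguous slice of the predicted cloud;
    the slices follow the key order of robot_links_pc.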
71 | 72 | :param robot_links_pc: (train) [{link_name: (N_i, 3), ...}, ...], per link sampled points of batch robots 73 | (validate) {link_name: (N_i, 3), ...}, per link sampled points of the same robot 74 | :param predict_pcs: (B, N, 3), point cloud to calculate SE3 75 | :return: link transforms, [{link_name: (4, 4)}, ...]; 76 | transformed_pc, (B, N, 3) 77 | """ 78 | if is_train: 79 | assert predict_pcs.ndim == 3, "compute_link_pose() requires batch data during training." 80 | batch_transform = [] 81 | batch_transformed_pc = [] 82 | for idx in range(len(robot_links_pc)): 83 | links_pc = robot_links_pc[idx] 84 | predict_pc = predict_pcs[idx] 85 | 86 | global_index = 0 87 | transform = {} 88 | transformed_pc = [] 89 | for link_index, (link_name, link_pc) in enumerate(links_pc.items()): 90 | predict_pc_link = predict_pc[global_index: global_index + link_pc.shape[-2], :3] 91 | global_index += link_pc.shape[0] 92 | 93 | link_se3 = compute_se3_transform(link_pc.unsqueeze(0), predict_pc_link.unsqueeze(0))[0] # (4, 4) 94 | link_transformed_pc = se3_transform_point_cloud(link_pc, link_se3) # (N_link, 3) 95 | transform[link_name] = link_se3 96 | transformed_pc.append(link_transformed_pc) 97 | 98 | batch_transform.append(transform) 99 | batch_transformed_pc.append(torch.cat(transformed_pc, dim=0)) 100 | 101 | return batch_transform, torch.stack(batch_transformed_pc, dim=0) 102 | else: 103 | batch_transform = {} 104 | batch_transformed_pc = [] 105 | global_index = 0 106 | for link_index, (link_name, link_pc) in enumerate(robot_links_pc.items()): 107 | if predict_pcs.ndim == 3 and link_pc.ndim != 3: 108 | link_pc = link_pc.unsqueeze(0).repeat(predict_pcs.shape[0], 1, 1) 109 | predict_pc_link = predict_pcs[..., global_index: global_index + link_pc.shape[-2], :3] 110 | global_index += link_pc.shape[-2] 111 | batch_transform[link_name] = compute_se3_transform(link_pc, predict_pc_link) 112 | batch_transformed_pc.append(se3_transform_point_cloud(link_pc, batch_transform[link_name])) 113 | batch_transformed_pc = torch.cat(batch_transformed_pc, dim=-2) 114 | 115 | return batch_transform, batch_transformed_pc 116 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of the visualization code has not been encapsulated into functions; 3 | only the part for visualizing vectors is kept in this file, and the comment format is not consistent. 4 | """ 5 | 6 | import trimesh 7 | import numpy as np 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | 11 | def normalize(x): 12 | """ 13 | Normalize the input vector. If the magnitude of the vector is zero, a small value is added to prevent division by zero. 14 | 15 | Parameters: 16 | - x (np.ndarray): Input vector to be normalized. 17 | 18 | Returns: 19 | - np.ndarray: Normalized vector. 20 | """ 21 | if len(x.shape) == 1: 22 | mag = np.linalg.norm(x) 23 | if mag == 0: 24 | mag = mag + 1e-10 25 | return x / mag 26 | else: 27 | norms = np.linalg.norm(x, axis=1, keepdims=True) 28 | norms = np.where(norms == 0, 1e-10, norms) 29 | return x / norms 30 | 31 | 32 | def sample_transform_w_normals( 33 | new_palm_center, 34 | new_face_vector, 35 | sample_roll, 36 | ori_face_vector=np.array([1.0, 0.0, 0.0]), 37 | ): 38 | """ 39 | Compute the transformation matrix from the original palm pose to a new palm pose. 40 | 41 | Parameters: 42 | - new_palm_center (np.ndarray): The point of the palm center [x, y, z]. 
43 | - new_face_vector (np.ndarray): The direction vector representing the new palm facing direction. 44 | - sample_roll (float): The roll angle in range [0, 2*pi). 45 | - ori_face_vector (np.ndarray): The original direction vector representing the palm facing direction. Default is [1.0, 0.0, 0.0]. 46 | 47 | Returns: 48 | - rst_transform (np.ndarray): A 4x4 transformation matrix. 49 | """ 50 | 51 | rot_axis = np.cross(ori_face_vector, normalize(new_face_vector)) 52 | rot_axis = rot_axis / (np.linalg.norm(rot_axis) + 1e-16) 53 | rot_ang = np.arccos(np.clip(np.dot(ori_face_vector, new_face_vector), -1.0, 1.0)) 54 | 55 | if rot_ang > 3.1415 or rot_ang < -3.1415: 56 | rot_axis = ( 57 | np.array([1.0, 0.0, 0.0]) 58 | if not np.isclose(ori_face_vector, np.array([1.0, 0.0, 0.0])).all() 59 | else np.array([0.0, 1.0, 0.0]) 60 | ) 61 | 62 | rot = R.from_rotvec(rot_ang * rot_axis).as_matrix() 63 | roll_rot = R.from_rotvec(sample_roll * new_face_vector).as_matrix() 64 | 65 | final_rot = roll_rot @ rot 66 | rst_transform = np.eye(4) 67 | rst_transform[:3, :3] = final_rot 68 | rst_transform[:3, 3] = new_palm_center 69 | return rst_transform 70 | 71 | 72 | def vis_vector( 73 | start_point, 74 | vector, 75 | length=0.1, 76 | cyliner_r=0.003, 77 | color=[255, 255, 100, 245], 78 | no_arrow=False, 79 | ): 80 | """ 81 | start_points: np.ndarray, shape=(3,) 82 | vectors: np.ndarray, shape=(3,) 83 | length: cylinder length 84 | """ 85 | normalized_vector = normalize(vector) 86 | end_point = start_point + length * normalized_vector 87 | 88 | # create a mesh for the force 89 | force_cylinder = trimesh.creation.cylinder( 90 | radius=cyliner_r, segment=np.array([start_point, end_point]) 91 | ) 92 | 93 | # create a mesh for the arrowhead 94 | cone_transform = sample_transform_w_normals( 95 | end_point, normalized_vector, 0, ori_face_vector=np.array([0.0, 0.0, 1.0]) 96 | ) 97 | arrowhead_cone = trimesh.creation.cone( 98 | radius=2 * cyliner_r, height=4 * cyliner_r, transform=cone_transform 99 | ) 100 | # combine the two meshes into one 101 | if not no_arrow: 102 | force_mesh = force_cylinder + arrowhead_cone 103 | else: 104 | force_mesh = force_cylinder 105 | force_mesh.visual.face_colors = color 106 | 107 | return force_mesh 108 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import warnings 5 | from termcolor import cprint 6 | import hydra 7 | import numpy as np 8 | 9 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from model.network import create_network 13 | from data_utils.CMapDataset import create_dataloader 14 | from utils.multilateration import multilateration 15 | from utils.se3_transform import compute_link_pose 16 | from utils.optimization import * 17 | from utils.hand_model import create_hand_model 18 | from validation.validate_utils import validate_isaac 19 | 20 | 21 | @hydra.main(version_base="1.2", config_path="configs", config_name="validate") 22 | def main(cfg): 23 | device = torch.device(f'cuda:{cfg.gpu}') 24 | batch_size = cfg.dataset.batch_size 25 | print(f"Device: {device}") 26 | print('Name:', cfg.name) 27 | 28 | os.makedirs(os.path.join(ROOT_DIR, 'validate_output'), exist_ok=True) 29 | log_file_name = os.path.join(ROOT_DIR, f'validate_output/{cfg.name}.log') 30 | print('Log file:', log_file_name) 31 | for validate_epoch in cfg.validate_epochs: 32 | 
print(f"************************ Validating epoch {validate_epoch} ************************") 33 | with open(log_file_name, 'a') as f: 34 | print(f"************************ Validating epoch {validate_epoch} ************************", file=f) 35 | 36 | network = create_network(cfg.model, mode='validate').to(device) 37 | network.load_state_dict(torch.load(f"output/{cfg.name}/state_dict/epoch_{validate_epoch}.pth", map_location=device)) 38 | network.eval() 39 | 40 | dataloader = create_dataloader(cfg.dataset, is_train=False) 41 | 42 | global_robot_name = None 43 | hand = None 44 | all_success_q = [] 45 | time_list = [] 46 | success_num = 0 47 | total_num = 0 48 | vis_info = [] 49 | for i, data in enumerate(dataloader): 50 | robot_name = data['robot_name'] 51 | object_name = data['object_name'] 52 | 53 | if robot_name != global_robot_name: 54 | if global_robot_name is not None: 55 | all_success_q = torch.cat(all_success_q, dim=0) 56 | diversity_std = torch.std(all_success_q, dim=0).mean() 57 | times = np.array(time_list) 58 | time_mean = np.mean(times) 59 | time_std = np.std(times) 60 | 61 | success_rate = success_num / total_num * 100 62 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 63 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 64 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 65 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 66 | with open(log_file_name, 'a') as f: 67 | cprint(f"[{global_robot_name}]", 'magenta', end=' ', file=f) 68 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ', file=f) 69 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ', file=f) 70 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue', file=f) 71 | 72 | all_success_q = [] 73 | time_list = [] 74 | success_num = 0 75 | total_num = 0 76 | hand = create_hand_model(robot_name, device) 77 | global_robot_name = robot_name 78 | 79 | initial_q_list = [] 80 | predict_q_list = [] 81 | object_pc_list = [] 82 | mlat_pc_list = [] 83 | transform_list = [] 84 | data_count = 0 85 | while data_count != batch_size: 86 | split_num = min(batch_size - data_count, cfg.split_batch_size) 87 | 88 | initial_q = data['initial_q'][data_count : data_count + split_num].to(device) 89 | robot_pc = data['robot_pc'][data_count : data_count + split_num].to(device) 90 | object_pc = data['object_pc'][data_count : data_count + split_num].to(device) 91 | 92 | data_count += split_num 93 | 94 | with torch.no_grad(): 95 | dro = network( 96 | robot_pc, 97 | object_pc 98 | )['dro'].detach() 99 | 100 | mlat_pc = multilateration(dro, object_pc) 101 | transform, _ = compute_link_pose(hand.links_pc, mlat_pc, is_train=False) 102 | optim_transform = process_transform(hand.pk_chain, transform) 103 | 104 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 105 | start_time = time.time() 106 | predict_q = optimization(hand.pk_chain, layer, initial_q, optim_transform) 107 | end_time = time.time() 108 | print(f"[{data_count}/{batch_size}] Optimization time: {end_time - start_time:.4f} s") 109 | time_list.append(end_time - start_time) 110 | 111 | initial_q_list.append(initial_q) 112 | predict_q_list.append(predict_q) 113 | object_pc_list.append(object_pc) 114 | mlat_pc_list.append(mlat_pc) 115 | transform_list.append(transform) 116 | 117 | initial_q_batch = torch.cat(initial_q_list, dim=0) 118 | predict_q_batch = torch.cat(predict_q_list, dim=0) 119 | object_pc_batch = torch.cat(object_pc_list, dim=0) 120 | 
mlat_pc_batch = torch.cat(mlat_pc_list, dim=0) 121 | transform_batch = {} 122 | for transform in transform_list: 123 | for k, v in transform.items(): 124 | transform_batch[k] = v if k not in transform_batch else torch.cat((transform_batch[k], v), dim=0) 125 | 126 | success, isaac_q = validate_isaac(robot_name, object_name, predict_q_batch, gpu=cfg.gpu) 127 | succ_num = success.sum().item() if success is not None else -1 128 | success_q = predict_q_batch[success] 129 | all_success_q.append(success_q) 130 | 131 | vis_info.append({ 132 | 'robot_name': robot_name, 133 | 'object_name': object_name, 134 | 'initial_q': initial_q_batch, 135 | 'predict_q': predict_q_batch, 136 | 'object_pc': object_pc_batch, 137 | 'mlat_pc': mlat_pc_batch, 138 | 'predict_transform': transform_batch, 139 | 'success': success, 140 | 'isaac_q': isaac_q 141 | }) 142 | 143 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ') 144 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green') 145 | with open(log_file_name, 'a') as f: 146 | cprint(f"[{robot_name}/{object_name}]", 'light_blue', end=' ', file=f) 147 | cprint(f"Result: {succ_num}/{batch_size}({succ_num / batch_size * 100:.2f}%)", 'green', file=f) 148 | success_num += succ_num 149 | total_num += batch_size 150 | 151 | all_success_q = torch.cat(all_success_q, dim=0) 152 | diversity_std = torch.std(all_success_q, dim=0).mean() 153 | 154 | times = np.array(time_list) 155 | time_mean = np.mean(times) 156 | time_std = np.std(times) 157 | 158 | success_rate = success_num / total_num * 100 159 | cprint(f"[{global_robot_name}]", 'magenta', end=' ') 160 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ') 161 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ') 162 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue') 163 | with open(log_file_name, 'a') as f: 164 | cprint(f"[{global_robot_name}]", 'magenta', end=' ', file=f) 165 | cprint(f"Result: {success_num}/{total_num}({success_rate:.2f}%)", 'yellow', end=' ', file=f) 166 | cprint(f"Std: {diversity_std:.3f}", 'cyan', end=' ', file=f) 167 | cprint(f"Time: (mean) {time_mean:.2f} s, (std) {time_std:.2f} s", 'blue', file=f) 168 | 169 | vis_info_file = f'{cfg.name}_epoch{validate_epoch}' 170 | os.makedirs(os.path.join(ROOT_DIR, 'vis_info'), exist_ok=True) 171 | torch.save(vis_info, os.path.join(ROOT_DIR, f'vis_info/{vis_info_file}.pt')) 172 | 173 | 174 | if __name__ == "__main__": 175 | warnings.simplefilter(action='ignore', category=FutureWarning) 176 | torch.set_num_threads(8) 177 | main() 178 | -------------------------------------------------------------------------------- /validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenyuwei2003/DRO-Grasp/b312055b4a20f73ddfeb3ffc8a1a6c80d48bbe31/validation/__init__.py -------------------------------------------------------------------------------- /validation/asset_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | 4 | NVIDIA CORPORATION and its licensors retain all intellectual property 5 | and proprietary rights in and to this software, related documentation 6 | and any modifications thereto. Any use, reproduction, disclosure or 7 | distribution of this software and related documentation without an express 8 | license agreement from NVIDIA CORPORATION is strictly prohibited. 
9 | 10 | 11 | Asset and Environment Information 12 | --------------------------------- 13 | Demonstrates introspection capabilities of the gym api at the asset and environment levels 14 | - Once an asset is loaded its properties can be queried 15 | - Assets in environments can be queried and their current states be retrieved 16 | """ 17 | 18 | import os 19 | from isaacgym import gymapi 20 | from isaacgym import gymutil 21 | 22 | 23 | def print_asset_info(gym, asset, name): 24 | print("======== Asset info %s: ========" % (name)) 25 | num_bodies = gym.get_asset_rigid_body_count(asset) 26 | num_joints = gym.get_asset_joint_count(asset) 27 | num_dofs = gym.get_asset_dof_count(asset) 28 | print("Got %d bodies, %d joints, and %d DOFs" % 29 | (num_bodies, num_joints, num_dofs)) 30 | 31 | # Iterate through bodies 32 | print("Bodies:") 33 | for i in range(num_bodies): 34 | name = gym.get_asset_rigid_body_name(asset, i) 35 | print(" %2d: '%s'" % (i, name)) 36 | 37 | # Iterate through joints 38 | print("Joints:") 39 | for i in range(num_joints): 40 | name = gym.get_asset_joint_name(asset, i) 41 | type = gym.get_asset_joint_type(asset, i) 42 | type_name = gym.get_joint_type_string(type) 43 | print(" %2d: '%s' (%s)" % (i, name, type_name)) 44 | 45 | # iterate through degrees of freedom (DOFs) 46 | print("DOFs:") 47 | for i in range(num_dofs): 48 | name = gym.get_asset_dof_name(asset, i) 49 | type = gym.get_asset_dof_type(asset, i) 50 | type_name = gym.get_dof_type_string(type) 51 | print(" %2d: '%s' (%s)" % (i, name, type_name)) 52 | 53 | 54 | def print_actor_info(gym, env, actor_handle): 55 | 56 | name = gym.get_actor_name(env, actor_handle) 57 | 58 | body_names = gym.get_actor_rigid_body_names(env, actor_handle) 59 | body_dict = gym.get_actor_rigid_body_dict(env, actor_handle) 60 | 61 | joint_names = gym.get_actor_joint_names(env, actor_handle) 62 | joint_dict = gym.get_actor_joint_dict(env, actor_handle) 63 | 64 | dof_names = gym.get_actor_dof_names(env, actor_handle) 65 | dof_dict = gym.get_actor_dof_dict(env, actor_handle) 66 | 67 | print() 68 | print("===== Actor: %s =======================================" % name) 69 | 70 | print("\nBodies") 71 | print(body_names) 72 | print(body_dict) 73 | 74 | print("\nJoints") 75 | print(joint_names) 76 | print(joint_dict) 77 | 78 | print("\n Degrees Of Freedom (DOFs)") 79 | print(dof_names) 80 | print(dof_dict) 81 | print() 82 | 83 | # Get body state information 84 | body_states = gym.get_actor_rigid_body_states( 85 | env, actor_handle, gymapi.STATE_ALL) 86 | 87 | # Print some state slices 88 | print("Poses from Body State:") 89 | print(body_states['pose']) # print just the poses 90 | 91 | print("\nVelocities from Body State:") 92 | print(body_states['vel']) # print just the velocities 93 | print() 94 | 95 | # iterate through bodies and print name and position 96 | body_positions = body_states['pose']['p'] 97 | for i in range(len(body_names)): 98 | print("Body '%s' has position" % body_names[i], body_positions[i]) 99 | 100 | print("\nDOF states:") 101 | 102 | # get DOF states 103 | dof_states = gym.get_actor_dof_states(env, actor_handle, gymapi.STATE_ALL) 104 | 105 | # print some state slices 106 | # Print all states for each degree of freedom 107 | print(dof_states) 108 | print() 109 | 110 | # iterate through DOFs and print name and position 111 | dof_positions = dof_states['pos'] 112 | for i in range(len(dof_names)): 113 | print("DOF '%s' has position" % dof_names[i], dof_positions[i]) 114 | 
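# Illustrative usage (a sketch; assumes a sim, an asset and an actor have already been created
# through the Isaac Gym API):
#   gym = gymapi.acquire_gym()
#   ...
#   print_asset_info(gym, asset, "robot")
#   print_actor_info(gym, env, actor_handle)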
-------------------------------------------------------------------------------- /validation/isaac_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import warnings 6 | from termcolor import cprint 7 | 8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(ROOT_DIR) 10 | 11 | from validation.isaac_validator import IsaacValidator # IsaacGym must be imported before PyTorch 12 | from utils.hand_model import create_hand_model 13 | from utils.rotation import q_rot6d_to_q_euler 14 | 15 | import torch 16 | 17 | 18 | def isaac_main( 19 | mode: str, 20 | robot_name: str, 21 | object_name: str, 22 | batch_size: int, 23 | q_batch: torch.Tensor = None, 24 | gpu: int = 0, 25 | use_gui: bool = False 26 | ): 27 | """ 28 | For filtering dataset and validating grasps. 29 | 30 | :param mode: str, 'filter' or 'validation' 31 | :param robot_name: str 32 | :param object_name: str 33 | :param batch_size: int, number of grasps in Isaac Gym simultaneously 34 | :param q_batch: torch.Tensor (validation only) 35 | :param gpu: int, specify the GPU device used by Isaac Gym 36 | :param use_gui: bool, whether to visualize Isaac Gym simulation process 37 | :return: success: (batch_size,), bool, whether each grasp is successful in Isaac Gym; 38 | q_isaac: (success_num, DOF), torch.float32, successful joint values after the grasp phase 39 | """ 40 | if mode == 'filter' and batch_size == 0: # special judge for num_per_object = 0 in dataset 41 | return 0, None 42 | if use_gui: # for unknown reason otherwise will segmentation fault :( 43 | gpu = 0 44 | 45 | data_urdf_path = os.path.join(ROOT_DIR, 'data/data_urdf') 46 | urdf_assets_meta = json.load(open(os.path.join(data_urdf_path, 'robot/urdf_assets_meta.json'))) 47 | robot_urdf_path = urdf_assets_meta['urdf_path'][robot_name] 48 | object_name_split = object_name.split('+') if object_name is not None else None 49 | # object_urdf_path = f'{object_name_split[0]}/{object_name_split[1]}/{object_name_split[1]}.urdf' 50 | object_urdf_path = f'{object_name_split[0]}/{object_name_split[1]}/coacd_decomposed_object_one_link.urdf' 51 | 52 | hand = create_hand_model(robot_name) 53 | joint_orders = hand.get_joint_orders() 54 | 55 | simulator = IsaacValidator( 56 | robot_name=robot_name, 57 | joint_orders=joint_orders, 58 | batch_size=batch_size, 59 | gpu=gpu, 60 | is_filter=(mode == 'filter'), 61 | use_gui=use_gui 62 | ) 63 | print("[Isaac] IsaacValidator is created.") 64 | 65 | simulator.set_asset( 66 | robot_path=os.path.join(data_urdf_path, 'robot'), 67 | robot_file=robot_urdf_path[21:], # ignore 'data/data_urdf/robot/' 68 | object_path=os.path.join(data_urdf_path, 'object'), 69 | object_file=object_urdf_path 70 | ) 71 | simulator.create_envs() 72 | print("[Isaac] IsaacValidator preparation is done.") 73 | 74 | if mode == 'filter': 75 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 76 | metadata = torch.load(dataset_path)['metadata'] 77 | q_batch = [m[1] for m in metadata if m[2] == object_name and m[3] == robot_name] 78 | q_batch = torch.stack(q_batch, dim=0).to(torch.device('cpu')) 79 | if q_batch.shape[-1] != len(joint_orders): 80 | q_batch = q_rot6d_to_q_euler(q_batch) 81 | 82 | simulator.set_actor_pose_dof(q_batch.to(torch.device('cpu'))) 83 | success, q_isaac = simulator.run_sim() 84 | simulator.destroy() 85 | 86 | return success, q_isaac 87 | 88 | 89 | # for Python scripts subprocess call to avoid Isaac 
Gym GPU memory leak problem 90 | if __name__ == '__main__': 91 | warnings.simplefilter(action='ignore', category=FutureWarning) 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--mode', type=str, required=True) 94 | parser.add_argument('--robot_name', type=str, required=True) 95 | parser.add_argument('--object_name', type=str, required=True) 96 | parser.add_argument('--batch_size', type=int, required=True) 97 | parser.add_argument('--q_file', type=str) 98 | parser.add_argument('--gpu', default=0, type=int) 99 | parser.add_argument('--use_gui', action='store_true') 100 | args = parser.parse_args() 101 | 102 | print(f'GPU: {args.gpu}') 103 | assert args.mode in ['filter', 'validation'], f"Unknown mode: {args.mode}!" 104 | q_batch = torch.load(args.q_file, map_location=f'cpu') if args.q_file is not None else None 105 | success, q_isaac = isaac_main( 106 | mode=args.mode, 107 | robot_name=args.robot_name, 108 | object_name=args.object_name, 109 | batch_size=args.batch_size, 110 | q_batch=q_batch, 111 | gpu=args.gpu, 112 | use_gui=args.use_gui 113 | ) 114 | 115 | success_num = success.sum().item() 116 | if args.mode == 'filter': 117 | print(f"<{args.robot_name}/{args.object_name}> before: {args.batch_size}, after: {success_num}") 118 | if success_num > 0: 119 | q_filtered = q_isaac[success] 120 | save_dir = str(os.path.join(ROOT_DIR, 'data/CMapDataset_filtered', args.robot_name)) 121 | os.makedirs(save_dir, exist_ok=True) 122 | torch.save(q_filtered, os.path.join(save_dir, f'{args.object_name}_{success_num}.pt')) 123 | elif args.mode == 'validation': 124 | cprint(f"[{args.robot_name}/{args.object_name}] Result: {success_num}/{args.batch_size}", 'green') 125 | save_data = { 126 | 'success': success, 127 | 'q_isaac': q_isaac 128 | } 129 | os.makedirs(os.path.join(ROOT_DIR, 'tmp'), exist_ok=True) 130 | torch.save(save_data, os.path.join(ROOT_DIR, f'tmp/isaac_main_ret_{args.gpu}.pt')) 131 | -------------------------------------------------------------------------------- /validation/isaac_validator.py: -------------------------------------------------------------------------------- 1 | from isaacgym import gymapi 2 | from isaacgym import gymtorch 3 | 4 | import os 5 | import sys 6 | import time 7 | import numpy as np 8 | import torch 9 | from scipy.spatial.transform import Rotation 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(ROOT_DIR) 13 | 14 | from utils.controller import controller 15 | 16 | 17 | class IsaacValidator: 18 | def __init__( 19 | self, 20 | robot_name, 21 | joint_orders, 22 | batch_size, 23 | gpu=0, 24 | is_filter=False, 25 | use_gui=False, 26 | robot_friction=3., 27 | object_friction=3., 28 | steps_per_sec=100, 29 | grasp_step=100, 30 | debug_interval=0.01 31 | ): 32 | self.gym = gymapi.acquire_gym() 33 | 34 | self.robot_name = robot_name 35 | self.joint_orders = joint_orders 36 | self.batch_size = batch_size 37 | self.gpu = gpu 38 | self.is_filter = is_filter 39 | self.robot_friction = robot_friction 40 | self.object_friction = object_friction 41 | self.steps_per_sec = steps_per_sec 42 | self.grasp_step = grasp_step 43 | self.debug_interval = debug_interval 44 | 45 | self.envs = [] 46 | self.robot_handles = [] 47 | self.object_handles = [] 48 | self.robot_asset = None 49 | self.object_asset = None 50 | self.rigid_body_num = None 51 | self.object_force = None 52 | self.urdf2isaac_order = None 53 | self.isaac2urdf_order = None 54 | 55 | self.sim_params = gymapi.SimParams() 56 | # set common parameters 57 | 
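        # The assignments below configure a 1/steps_per_sec physics step (10 ms at the
        # default 100 Hz) with gravity disabled, so the object stays in place unless the
        # probe forces applied in run_sim() break the grasp; PhysX runs on the GPU index
        # passed in as `gpu`.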
self.sim_params.dt = 1 / steps_per_sec 58 | self.sim_params.substeps = 2 59 | self.sim_params.gravity = gymapi.Vec3(0.0, 0.0, 0.0) 60 | #self.sim_params.use_gpu_pipeline = True 61 | # set PhysX-specific parameters 62 | self.sim_params.physx.use_gpu = True 63 | self.sim_params.physx.solver_type = 1 64 | self.sim_params.physx.num_position_iterations = 8 65 | self.sim_params.physx.num_velocity_iterations = 0 66 | self.sim_params.physx.contact_offset = 0.01 67 | self.sim_params.physx.rest_offset = 0.0 68 | 69 | self.sim = self.gym.create_sim(self.gpu, self.gpu, gymapi.SIM_PHYSX, self.sim_params) 70 | self._rigid_body_states = self.gym.acquire_rigid_body_state_tensor(self.sim) 71 | self._dof_states = self.gym.acquire_dof_state_tensor(self.sim) 72 | 73 | self.viewer = None 74 | if use_gui: 75 | self.has_viewer = True 76 | self.camera_props = gymapi.CameraProperties() 77 | self.camera_props.width = 1920 78 | self.camera_props.height = 1080 79 | self.camera_props.use_collision_geometry = True 80 | self.viewer = self.gym.create_viewer(self.sim, self.camera_props) 81 | self.gym.viewer_camera_look_at(self.viewer, None, gymapi.Vec3(1, 0, 0), gymapi.Vec3(0, 0, 0)) 82 | else: 83 | self.has_viewer = False 84 | 85 | self.robot_asset_options = gymapi.AssetOptions() 86 | self.robot_asset_options.disable_gravity = True 87 | self.robot_asset_options.fix_base_link = True 88 | self.robot_asset_options.collapse_fixed_joints = True 89 | 90 | self.object_asset_options = gymapi.AssetOptions() 91 | self.object_asset_options.override_com = True 92 | self.object_asset_options.override_inertia = True 93 | self.object_asset_options.density = 500 94 | 95 | def set_asset(self, robot_path, robot_file, object_path, object_file): 96 | self.robot_asset = self.gym.load_asset(self.sim, robot_path, robot_file, self.robot_asset_options) 97 | self.object_asset = self.gym.load_asset(self.sim, object_path, object_file, self.object_asset_options) 98 | self.rigid_body_num = (self.gym.get_asset_rigid_body_count(self.robot_asset) 99 | + self.gym.get_asset_rigid_body_count(self.object_asset)) 100 | # print_asset_info(gym, self.robot_asset, 'robot') 101 | # print_asset_info(gym, self.object_asset, 'object') 102 | 103 | def create_envs(self): 104 | for env_idx in range(self.batch_size): 105 | env = self.gym.create_env( 106 | self.sim, 107 | gymapi.Vec3(-1, -1, -1), 108 | gymapi.Vec3(1, 1, 1), 109 | int(self.batch_size ** 0.5) 110 | ) 111 | self.envs.append(env) 112 | 113 | # draw world frame 114 | if self.has_viewer: 115 | x_axis_dir = np.array([0, 0, 0, 1, 0, 0], dtype=np.float32) 116 | x_axis_color = np.array([1, 0, 0], dtype=np.float32) 117 | self.gym.add_lines(self.viewer, env, 1, x_axis_dir, x_axis_color) 118 | y_axis_dir = np.array([0, 0, 0, 0, 1, 0], dtype=np.float32) 119 | y_axis_color = np.array([0, 1, 0], dtype=np.float32) 120 | self.gym.add_lines(self.viewer, env, 1, y_axis_dir, y_axis_color) 121 | z_axis_dir = np.array([0, 0, 0, 0, 0, 1], dtype=np.float32) 122 | z_axis_color = np.array([0, 0, 1], dtype=np.float32) 123 | self.gym.add_lines(self.viewer, env, 1, z_axis_dir, z_axis_color) 124 | 125 | # object actor setting 126 | object_handle = self.gym.create_actor( 127 | env, 128 | self.object_asset, 129 | gymapi.Transform(), 130 | f'object_{env_idx}', 131 | env_idx 132 | ) 133 | self.object_handles.append(object_handle) 134 | 135 | object_shape_properties = self.gym.get_actor_rigid_shape_properties(env, object_handle) 136 | for i in range(len(object_shape_properties)): 137 | object_shape_properties[i].friction = 
self.object_friction 138 | self.gym.set_actor_rigid_shape_properties(env, object_handle, object_shape_properties) 139 | 140 | # robot actor setting 141 | robot_handle = self.gym.create_actor( 142 | env, 143 | self.robot_asset, 144 | gymapi.Transform(), 145 | f'robot_{env_idx}', 146 | env_idx 147 | ) 148 | self.robot_handles.append(robot_handle) 149 | 150 | robot_properties = self.gym.get_actor_dof_properties(env, robot_handle) 151 | robot_properties["driveMode"].fill(gymapi.DOF_MODE_POS) 152 | robot_properties["stiffness"].fill(1000) 153 | robot_properties["damping"].fill(200) 154 | self.gym.set_actor_dof_properties(env, robot_handle, robot_properties) 155 | 156 | object_shape_properties = self.gym.get_actor_rigid_shape_properties(env, robot_handle) 157 | for i in range(len(object_shape_properties)): 158 | object_shape_properties[i].friction = self.robot_friction 159 | self.gym.set_actor_rigid_shape_properties(env, robot_handle, object_shape_properties) 160 | 161 | # print_actor_info(self.gym, env, robot_handle) 162 | # print_actor_info(self.gym, env, object_handle) 163 | 164 | # assume robots & objects in the same batch are the same 165 | obj_property = self.gym.get_actor_rigid_body_properties(self.envs[0], self.object_handles[0]) 166 | object_mass = [obj_property[i].mass for i in range(len(obj_property))] 167 | object_mass = torch.tensor(object_mass) 168 | self.object_force = 0.5 * object_mass 169 | 170 | self.urdf2isaac_order = np.zeros(len(self.joint_orders), dtype=np.int32) 171 | self.isaac2urdf_order = np.zeros(len(self.joint_orders), dtype=np.int32) 172 | for urdf_idx, joint_name in enumerate(self.joint_orders): 173 | isaac_idx = self.gym.find_actor_dof_index(self.envs[0], self.robot_handles[0], joint_name, gymapi.DOMAIN_ACTOR) 174 | self.urdf2isaac_order[isaac_idx] = urdf_idx 175 | self.isaac2urdf_order[urdf_idx] = isaac_idx 176 | 177 | def set_actor_pose_dof(self, q): 178 | self.gym.prepare_sim(self.sim) 179 | 180 | # set all actors to origin 181 | _root_state = self.gym.acquire_actor_root_state_tensor(self.sim) 182 | root_state = gymtorch.wrap_tensor(_root_state) 183 | root_state[:] = torch.tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=torch.float32) 184 | self.gym.set_actor_root_state_tensor(self.sim, _root_state) 185 | 186 | outer_q, inner_q = controller(self.robot_name, q) 187 | 188 | for env_idx in range(len(self.envs)): 189 | env = self.envs[env_idx] 190 | robot_handle = self.robot_handles[env_idx] 191 | 192 | dof_states_initial = self.gym.get_actor_dof_states(env, robot_handle, gymapi.STATE_ALL).copy() 193 | dof_states_initial['pos'] = outer_q[env_idx, self.urdf2isaac_order] 194 | self.gym.set_actor_dof_states(env, robot_handle, dof_states_initial, gymapi.STATE_ALL) 195 | 196 | dof_states_target = self.gym.get_actor_dof_states(env, robot_handle, gymapi.STATE_ALL).copy() 197 | dof_states_target['pos'] = inner_q[env_idx, self.urdf2isaac_order] 198 | self.gym.set_actor_dof_position_targets(env, robot_handle, dof_states_target["pos"]) 199 | 200 | def run_sim(self): 201 | # controller phase 202 | for step in range(self.grasp_step): 203 | self.gym.simulate(self.sim) 204 | 205 | if self.has_viewer: 206 | if self.gym.query_viewer_has_closed(self.viewer): 207 | break 208 | t = time.time() 209 | while time.time() - t < self.debug_interval: 210 | self.gym.step_graphics(self.sim) 211 | self.gym.draw_viewer(self.viewer, self.sim, render_collision=True) 212 | 213 | self.gym.refresh_rigid_body_state_tensor(self.sim) 214 | start_pos = 
gymtorch.wrap_tensor(self._rigid_body_states)[::self.rigid_body_num, :3].clone() 215 | 216 | force_tensor = torch.zeros([len(self.envs), self.rigid_body_num, 3]) # env, rigid_body, xyz 217 | x_pos_force = force_tensor.clone() 218 | x_pos_force[:, 0, 0] = self.object_force 219 | x_neg_force = force_tensor.clone() 220 | x_neg_force[:, 0, 0] = -self.object_force 221 | y_pos_force = force_tensor.clone() 222 | y_pos_force[:, 0, 1] = self.object_force 223 | y_neg_force = force_tensor.clone() 224 | y_neg_force[:, 0, 1] = -self.object_force 225 | z_pos_force = force_tensor.clone() 226 | z_pos_force[:, 0, 2] = self.object_force 227 | z_neg_force = force_tensor.clone() 228 | z_neg_force[:, 0, 2] = -self.object_force 229 | force_list = [x_pos_force, y_pos_force, z_pos_force, x_neg_force, y_neg_force, z_neg_force] 230 | 231 | # force phase 232 | for step in range(self.steps_per_sec * 6): 233 | self.gym.apply_rigid_body_force_tensors(self.sim, 234 | gymtorch.unwrap_tensor(force_list[step // self.steps_per_sec]), 235 | None, 236 | gymapi.ENV_SPACE) 237 | self.gym.simulate(self.sim) 238 | self.gym.fetch_results(self.sim, True) 239 | 240 | if self.has_viewer: 241 | if self.gym.query_viewer_has_closed(self.viewer): 242 | break 243 | t = time.time() 244 | while time.time() - t < self.debug_interval: 245 | self.gym.step_graphics(self.sim) 246 | self.gym.draw_viewer(self.viewer, self.sim, render_collision=True) 247 | 248 | self.gym.refresh_rigid_body_state_tensor(self.sim) 249 | end_pos = gymtorch.wrap_tensor(self._rigid_body_states)[::self.rigid_body_num, :3].clone() 250 | 251 | distance = (end_pos - start_pos).norm(dim=-1) 252 | 253 | if self.is_filter: 254 | success = (distance <= 0.02) & (end_pos.norm(dim=-1) <= 0.05) 255 | else: 256 | success = (distance <= 0.02) 257 | 258 | # apply inverse object transform to robot to get new joint value 259 | self.gym.refresh_rigid_body_state_tensor(self.sim) 260 | object_pose = gymtorch.wrap_tensor(self._rigid_body_states).clone()[::self.rigid_body_num, :7] # batch_size, 7 (xyz + quat) 261 | object_transform = np.eye(4)[np.newaxis].repeat(self.batch_size, axis=0) 262 | object_transform[:, :3, 3] = object_pose[:, :3] 263 | object_transform[:, :3, :3] = Rotation.from_quat(object_pose[:, 3:7]).as_matrix() 264 | 265 | self.gym.refresh_dof_state_tensor(self.sim) 266 | dof_states = gymtorch.wrap_tensor(self._dof_states).clone().reshape(len(self.envs), -1, 2)[:, :, 0] # batch_size, DOF (xyz + euler + joint) 267 | robot_transform = np.eye(4)[np.newaxis].repeat(self.batch_size, axis=0) 268 | robot_transform[:, :3, 3] = dof_states[:, :3] 269 | robot_transform[:, :3, :3] = Rotation.from_euler('XYZ', dof_states[:, 3:6]).as_matrix() 270 | 271 | robot_transform = np.linalg.inv(object_transform) @ robot_transform 272 | dof_states[:, :3] = torch.tensor(robot_transform[:, :3, 3]) 273 | dof_states[:, 3:6] = torch.tensor(Rotation.from_matrix(robot_transform[:, :3, :3]).as_euler('XYZ')) 274 | q_isaac = dof_states[:, self.isaac2urdf_order].to(torch.device('cpu')) 275 | 276 | return success, q_isaac 277 | 278 | def reset_simulator(self): 279 | self.gym.destroy_sim(self.sim) 280 | if self.has_viewer: 281 | self.gym.destroy_viewer(self.viewer) 282 | self.viewer = self.gym.create_viewer(self.sim, self.camera_props) 283 | self.sim = self.gym.create_sim(self.gpu, self.gpu, gymapi.SIM_PHYSX, self.sim_params) 284 | for env in self.envs: 285 | self.gym.destroy_env(env) 286 | self.envs = [] 287 | self.robot_handles = [] 288 | self.object_handles = [] 289 | self.robot_asset = None 290 | 
self.object_asset = None 291 | 292 | def destroy(self): 293 | for env in self.envs: 294 | self.gym.destroy_env(env) 295 | self.gym.destroy_sim(self.sim) 296 | if self.has_viewer: 297 | self.gym.destroy_viewer(self.viewer) 298 | del self.gym 299 | -------------------------------------------------------------------------------- /validation/validate_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | from tqdm import tqdm 5 | from termcolor import cprint 6 | import torch 7 | import trimesh 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from utils.controller import controller 13 | 14 | 15 | def validate_depth(hand, object_name, q_list_validate, threshold=0.005, exact=True): 16 | """ 17 | Calculate the penetration depth of predicted grasps into the object. 18 | 19 | :param hand: HandModel() 20 | :param object_name: str 21 | :param q_list_validate: list, joint values to validate 22 | :param threshold: float, criteria for determining success in depth 23 | :param exact: bool, if false, use point cloud instead of mesh to compute (much faster with minor error) 24 | :param print_info: bool, whether to print progress information 25 | :return: (list , list ), success list & depth list 26 | """ 27 | name = object_name.split('+') 28 | if exact: 29 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 30 | object_mesh = trimesh.load_mesh(object_path) 31 | else: 32 | object_path = os.path.join(ROOT_DIR, f'data/PointCloud/object/{name[0]}/{name[1]}.pt') 33 | object_pc_normals = torch.load(object_path).to(hand.device) 34 | object_pc = object_pc_normals[:, :3] 35 | normals = object_pc_normals[:, 3:] 36 | 37 | result_list = [] 38 | depth_list = [] 39 | q_list_initial = [] 40 | for q in q_list_validate: 41 | initial_q, _ = controller(hand.robot_name, q) 42 | q_list_initial.append(initial_q) 43 | for q in tqdm(q_list_initial): 44 | robot_pc = hand.get_transformed_links_pc(q)[:, :3] 45 | if exact: 46 | robot_pc = robot_pc.cpu() 47 | _, distance, _ = trimesh.proximity.ProximityQuery(object_mesh).on_surface(robot_pc) 48 | distance = distance[object_mesh.contains(robot_pc)] 49 | depth = distance.max() if distance.size else 0. 50 | else: 51 | distance = torch.cdist(robot_pc, object_pc) 52 | distance, index = torch.min(distance, dim=-1) 53 | object_pc_indexed, normals_indexed = object_pc[index], normals[index] 54 | get_sign = torch.vmap(lambda x, y: torch.where(torch.dot(x, y) >= 0, 1, -1)) 55 | signed_distance = distance * get_sign(robot_pc - object_pc_indexed, normals_indexed) 56 | depth, _ = torch.min(signed_distance, dim=-1) 57 | depth = -depth.item() if depth.item() < 0 else 0. 58 | 59 | result_list.append(depth <= threshold) 60 | depth_list.append(round(depth * 1000, 2)) 61 | 62 | return result_list, depth_list 63 | 64 | 65 | def validate_isaac(robot_name, object_name, q_batch, gpu: int = 0): 66 | """ 67 | Wrap function for subprocess call (isaac_main.py) to avoid Isaac Gym GPU memory leak problem. 
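    The joint batch is written to tmp/q_list_validate_<gpu>.pt, isaac_main.py runs in a separate
    Python process, and its results are read back from tmp/isaac_main_ret_<gpu>.pt, so all Isaac Gym
    resources are released when the child process exits.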
68 | 69 | :param robot_name: str 70 | :param object_name: str 71 | :param q_batch: torch.Tensor, joint values to validate 72 | :param gpu: int 73 | :return: (list , list ), success list & info list 74 | """ 75 | os.makedirs(os.path.join(ROOT_DIR, 'tmp'), exist_ok=True) 76 | q_file_path = str(os.path.join(ROOT_DIR, f'tmp/q_list_validate_{gpu}.pt')) 77 | torch.save(q_batch, q_file_path) 78 | batch_size = q_batch.shape[0] 79 | args = [ 80 | 'python', 81 | os.path.join(ROOT_DIR, 'validation/isaac_main.py'), 82 | '--mode', 'validation', 83 | '--robot_name', robot_name, 84 | '--object_name', object_name, 85 | '--batch_size', str(batch_size), 86 | '--q_file', q_file_path, 87 | '--gpu', str(gpu), 88 | # '--use_gui' 89 | ] 90 | ret = subprocess.run(args, capture_output=True, text=True) 91 | try: 92 | ret_file_path = os.path.join(ROOT_DIR, f'tmp/isaac_main_ret_{gpu}.pt') 93 | save_data = torch.load(ret_file_path) 94 | success = save_data['success'] 95 | q_isaac = save_data['q_isaac'] 96 | os.remove(q_file_path) 97 | os.remove(ret_file_path) 98 | except FileNotFoundError as e: 99 | cprint(f"Caught a ValueError: {e}", 'yellow') 100 | cprint(ret.stdout.strip(), 'blue') 101 | cprint(ret.stderr.strip(), 'red') 102 | exit() 103 | return success, q_isaac 104 | -------------------------------------------------------------------------------- /visualization/vis_controller.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import time 6 | import argparse 7 | import trimesh 8 | import torch 9 | import viser 10 | 11 | from utils.hand_model import create_hand_model 12 | from utils.controller import controller, get_link_dir 13 | from utils.vis_utils import vis_vector 14 | 15 | 16 | def vis_controller_result(robot_name='shadowhand', object_name=None): 17 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 18 | metadata = torch.load(dataset_path)['metadata'] 19 | metadata = [m for m in metadata if (object_name is None or m[1] == object_name) and m[2] == robot_name] 20 | 21 | server = viser.ViserServer(host='127.0.0.1', port=8080) 22 | 23 | slider = server.gui.add_slider( 24 | label='robot', 25 | min=0, 26 | max=len(metadata) - 1, 27 | step=1, 28 | initial_value=0 29 | ) 30 | slider.on_update(lambda gui: on_update(gui.target.value)) 31 | 32 | hand = create_hand_model(robot_name) 33 | 34 | def on_update(idx): 35 | q, object_name, _ = metadata[idx] 36 | outer_q, inner_q = controller(robot_name, q) 37 | 38 | name = object_name.split('+') 39 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') 40 | object_trimesh = trimesh.load_mesh(object_path) 41 | server.scene.add_mesh_simple( 42 | 'object', 43 | object_trimesh.vertices, 44 | object_trimesh.faces, 45 | color=(239, 132, 167), 46 | opacity=0.75 47 | ) 48 | 49 | robot_trimesh = hand.get_trimesh_q(q)["visual"] 50 | server.scene.add_mesh_simple( 51 | 'origin', 52 | robot_trimesh.vertices, 53 | robot_trimesh.faces, 54 | color=(102, 192, 255), 55 | opacity=0.75 56 | ) 57 | robot_trimesh = hand.get_trimesh_q(outer_q)["visual"] 58 | server.scene.add_mesh_simple( 59 | 'outer', 60 | robot_trimesh.vertices, 61 | robot_trimesh.faces, 62 | color=(255, 149, 71), 63 | opacity=0.75 64 | ) 65 | robot_trimesh = hand.get_trimesh_q(inner_q)["visual"] 66 | server.scene.add_mesh_simple( 67 | 'inner', 68 | robot_trimesh.vertices, 69 | 
robot_trimesh.faces, 70 | color=(255, 111, 190), 71 | opacity=0.75 72 | ) 73 | 74 | while True: 75 | time.sleep(1) 76 | 77 | 78 | def vis_hand_direction(robot_name='shadowhand'): 79 | server = viser.ViserServer(host='127.0.0.1', port=8080) 80 | 81 | hand = create_hand_model(robot_name, device='cpu') 82 | q = hand.get_canonical_q() 83 | joint_orders = hand.get_joint_orders() 84 | lower, upper = hand.pk_chain.get_joint_limits() 85 | 86 | canonical_trimesh = hand.get_trimesh_q(q)["visual"] 87 | server.scene.add_mesh_simple( 88 | robot_name, 89 | canonical_trimesh.vertices, 90 | canonical_trimesh.faces, 91 | color=(102, 192, 255), 92 | opacity=0.8 93 | ) 94 | 95 | pk_chain = hand.pk_chain 96 | status = pk_chain.forward_kinematics(q) 97 | joint_dots = {} 98 | for frame_name in pk_chain.get_frame_names(): 99 | frame = pk_chain.find_frame(frame_name) 100 | joint = frame.joint 101 | link_dir = get_link_dir(robot_name, joint.name) 102 | if link_dir is None: 103 | continue 104 | 105 | frame_transform = status[frame_name].get_matrix()[0] 106 | axis_dir = frame_transform[:3, :3] @ joint.axis 107 | link_dir = frame_transform[:3, :3] @ link_dir 108 | normal_dir = torch.cross(axis_dir, link_dir, dim=0) 109 | axis_origin = frame_transform[:3, 3] 110 | origin_dir = -axis_origin / torch.norm(axis_origin) 111 | joint_dots[joint.name] = float(torch.dot(normal_dir, origin_dir)) 112 | 113 | print(joint.name, joint_orders.index(joint.name), joint_dots[joint.name]) 114 | vec_mesh = vis_vector( 115 | axis_origin.numpy(), 116 | vector=normal_dir.numpy(), 117 | length=0.03, 118 | cyliner_r=0.001, 119 | color=(0, 255, 0) 120 | ) 121 | server.scene.add_mesh_trimesh(joint.name, vec_mesh, visible=True) 122 | 123 | current_q = [0 if i < 6 else lower[i] * 0.75 + upper[i] * 0.25 for i in range(hand.dof)] 124 | 125 | def update(joint_idx, joint_q): 126 | current_q[joint_idx] = joint_q 127 | trimesh = hand.get_trimesh_q(torch.tensor(current_q))["visual"] 128 | server.scene.add_mesh_simple( 129 | robot_name, 130 | trimesh.vertices, 131 | trimesh.faces, 132 | color=(102, 192, 255), 133 | opacity=0.8 134 | ) 135 | 136 | for i, joint_name in enumerate(joint_orders): 137 | if joint_name in joint_dots.keys(): 138 | slider = server.gui.add_slider( 139 | label=joint_name, 140 | min=round(lower[i], 2), 141 | max=round(upper[i], 2), 142 | step=(upper[i] - lower[i]) / 100, 143 | initial_value=current_q[i], 144 | ) 145 | slider.on_update(lambda gui: update(gui.target.order - 1, gui.target.value)) 146 | else: 147 | slider = server.gui.add_slider(label=' ', min=0, max=1, step=1, initial_value=0) 148 | 149 | while True: 150 | time.sleep(1) 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--robot_name', default='shadowhand', type=str) 156 | parser.add_argument('--controller', action='store_true') 157 | args = parser.parse_args() 158 | 159 | if args.controller: 160 | vis_controller_result(args.robot_name) 161 | else: 162 | vis_hand_direction(args.robot_name) 163 | -------------------------------------------------------------------------------- /visualization/vis_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import time 6 | import trimesh 7 | import torch 8 | import viser 9 | from utils.hand_model import create_hand_model 10 | 11 | filtered = True 12 | 13 | robot_names = ['allegro', 'barrett', 
'ezgripper', 'robotiq_3finger', 'shadowhand'] 14 | object_names = [ 15 | 'contactdb+alarm_clock', 'contactdb+apple', 'contactdb+banana', 'contactdb+binoculars', 'contactdb+camera', 16 | 'contactdb+cell_phone', 'contactdb+cube_large', 'contactdb+cube_medium', 'contactdb+cube_small', 17 | 'contactdb+cylinder_large', 'contactdb+cylinder_medium', 'contactdb+cylinder_small', 'contactdb+door_knob', 18 | 'contactdb+elephant', 'contactdb+flashlight', 'contactdb+hammer', 'contactdb+light_bulb', 'contactdb+mouse', 19 | 'contactdb+piggy_bank', 'contactdb+ps_controller', 'contactdb+pyramid_large', 'contactdb+pyramid_medium', 20 | 'contactdb+pyramid_small', 'contactdb+rubber_duck', 'contactdb+stanford_bunny', 'contactdb+stapler', 21 | 'contactdb+toothpaste', 'contactdb+torus_large', 'contactdb+torus_medium', 'contactdb+torus_small', 22 | 'contactdb+train', 'contactdb+water_bottle', 'ycb+baseball', 'ycb+bleach_cleanser', 'ycb+cracker_box', 23 | 'ycb+foam_brick', 'ycb+gelatin_box', 'ycb+hammer', 'ycb+lemon', 'ycb+master_chef_can', 'ycb+mini_soccer_ball', 24 | 'ycb+mustard_bottle', 'ycb+orange', 'ycb+peach', 'ycb+pear', 'ycb+pitcher_base', 'ycb+plum', 'ycb+potted_meat_can', 25 | 'ycb+power_drill', 'ycb+pudding_box', 'ycb+rubiks_cube', 'ycb+sponge', 'ycb+strawberry', 'ycb+sugar_box', 26 | 'ycb+tomato_soup_can', 'ycb+toy_airplane', 'ycb+tuna_fish_can', 'ycb+wood_block' 27 | ] 28 | 29 | if filtered: 30 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 31 | else: 32 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset/cmap_dataset.pt') 33 | metadata = torch.load(dataset_path, map_location=torch.device('cpu'))['metadata'] 34 | 35 | def on_update(robot_idx, object_idx, grasp_idx): 36 | robot_name = robot_names[robot_idx] 37 | object_name = object_names[object_idx] 38 | if filtered: 39 | metadata_curr = [m[0] for m in metadata if m[1] == object_name and m[2] == robot_name] 40 | else: 41 | metadata_curr = [m[1] for m in metadata if m[2] == object_name and m[3] == robot_name] 42 | if len(metadata_curr) == 0: 43 | print('No metadata found!') 44 | return 45 | q = metadata_curr[grasp_idx % len(metadata_curr)] 46 | print(f"joint values: {q}") 47 | 48 | name = object_name.split('+') 49 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/{name[1]}.stl') # visual mesh 50 | # object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{name[0]}/{name[1]}/coacd_allinone.obj') # collision mesh 51 | object_trimesh = trimesh.load_mesh(object_path) 52 | server.scene.add_mesh_simple( 53 | 'object', 54 | object_trimesh.vertices, 55 | object_trimesh.faces, 56 | color=(239, 132, 167), 57 | opacity=1 58 | ) 59 | 60 | hand = create_hand_model(robot_name) 61 | robot_trimesh = hand.get_trimesh_q(q)["visual"] 62 | server.scene.add_mesh_simple( 63 | 'robot', 64 | robot_trimesh.vertices, 65 | robot_trimesh.faces, 66 | color=(102, 192, 255), 67 | opacity=0.8 68 | ) 69 | 70 | server = viser.ViserServer(host='127.0.0.1', port=8080) 71 | 72 | robot_slider = server.gui.add_slider( 73 | label='robot', 74 | min=0, 75 | max=len(robot_names) - 1, 76 | step=1, 77 | initial_value=0 78 | ) 79 | object_slider = server.gui.add_slider( 80 | label='object', 81 | min=0, 82 | max=len(object_names) - 1, 83 | step=1, 84 | initial_value=0 85 | ) 86 | grasp_slider = server.gui.add_slider( 87 | label='grasp', 88 | min=0, 89 | max=199, 90 | step=1, 91 | initial_value=0 92 | ) 93 | robot_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 94 | 
object_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 95 | grasp_slider.on_update(lambda _: on_update(robot_slider.value, object_slider.value, grasp_slider.value)) 96 | 97 | while True: 98 | time.sleep(1) 99 | -------------------------------------------------------------------------------- /visualization/vis_hand_joint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualizes hand joint motion within joint range (upper & lower limits). 3 | """ 4 | 5 | import os 6 | import sys 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | import time 10 | import argparse 11 | import torch 12 | import viser 13 | from utils.hand_model import create_hand_model 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--robot_name', type=str, default='shadowhand') 17 | args = parser.parse_args() 18 | robot_name = args.robot_name 19 | 20 | hand = create_hand_model(robot_name) 21 | pk_chain = hand.pk_chain 22 | lower, upper = pk_chain.get_joint_limits() 23 | 24 | server = viser.ViserServer(host='127.0.0.1', port=8080) 25 | 26 | canonical_trimesh = hand.get_trimesh_q(hand.get_canonical_q())["visual"] 27 | server.scene.add_mesh_simple( 28 | robot_name, 29 | canonical_trimesh.vertices, 30 | canonical_trimesh.faces, 31 | color=(102, 192, 255), 32 | opacity=0.8 33 | ) 34 | 35 | def update(q): 36 | trimesh = hand.get_trimesh_q(q)["visual"] 37 | server.scene.add_mesh_simple( 38 | robot_name, 39 | trimesh.vertices, 40 | trimesh.faces, 41 | color=(102, 192, 255), 42 | opacity=0.8 43 | ) 44 | 45 | gui_joints = [] 46 | for i, joint_name in enumerate(hand.get_joint_orders()): 47 | slider = server.gui.add_slider( 48 | label=joint_name, 49 | min=round(lower[i], 2), 50 | max=round(upper[i], 2), 51 | step=(upper[i] - lower[i]) / 100, 52 | initial_value=0 if i < 6 else lower[i] * 0.75 + upper[i] * 0.25, 53 | ) 54 | slider.on_update(lambda _: update(torch.tensor([gui.value for gui in gui_joints]))) 55 | gui_joints.append(slider) 56 | 57 | while True: 58 | time.sleep(1) 59 | -------------------------------------------------------------------------------- /visualization/vis_hand_link.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize hand links to remove abundant links in removed_links.json. 
3 | """ 4 | 5 | import os 6 | import sys 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | import time 10 | import argparse 11 | import viser 12 | from utils.hand_model import create_hand_model 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--robot_name', type=str, default='shadowhand') 16 | args = parser.parse_args() 17 | robot_name = args.robot_name 18 | 19 | hand = create_hand_model(robot_name) 20 | meshes = hand.get_trimesh_q(hand.get_canonical_q())['parts'] 21 | 22 | server = viser.ViserServer(host='127.0.0.1', port=8080) 23 | 24 | for name, mesh in meshes.items(): 25 | server.scene.add_mesh_simple( 26 | name, 27 | mesh.vertices, 28 | mesh.faces, 29 | color=(102, 192, 255), 30 | opacity=0.8 31 | ) 32 | 33 | while True: 34 | time.sleep(1) 35 | -------------------------------------------------------------------------------- /visualization/vis_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(ROOT_DIR) 5 | import warnings 6 | import time 7 | import random 8 | import argparse 9 | import viser 10 | import torch 11 | 12 | from utils.hand_model import create_hand_model 13 | from utils.optimization import * 14 | from utils.se3_transform import compute_link_pose 15 | 16 | 17 | def main(robot_name): 18 | dataset_path = os.path.join(ROOT_DIR, f'data/CMapDataset_filtered/cmap_dataset.pt') 19 | metadata = torch.load(dataset_path, map_location=torch.device('cpu'))['metadata'] 20 | metadata = [m for m in metadata if m[2] == robot_name] 21 | q = random.choice(metadata)[0] 22 | 23 | hand = create_hand_model(robot_name, device='cpu') 24 | initial_q = hand.get_initial_q(q) 25 | pc_initial = hand.get_transformed_links_pc(initial_q)[:, :3] 26 | pc_target = hand.get_transformed_links_pc(q)[:, :3] 27 | 28 | transform, _ = compute_link_pose(hand.links_pc, pc_target.unsqueeze(0), is_train=False) 29 | optim_transform = process_transform(hand.pk_chain, transform) 30 | layer = create_problem(hand.pk_chain, optim_transform.keys()) 31 | predict_q = optimization(hand.pk_chain, layer, initial_q.unsqueeze(0), optim_transform)[0] 32 | pc_optimize = hand.get_transformed_links_pc(predict_q)[:, :3] 33 | 34 | server = viser.ViserServer(host='127.0.0.1', port=8080) 35 | 36 | server.scene.add_point_cloud( 37 | 'pc_initial', 38 | pc_initial.numpy(), 39 | point_size=0.001, 40 | point_shape="circle", 41 | colors=(102, 192, 255), 42 | visible=False 43 | ) 44 | 45 | server.scene.add_point_cloud( 46 | 'pc_optimize', 47 | pc_optimize.numpy(), 48 | point_size=0.001, 49 | point_shape="circle", 50 | colors=(0, 0, 200) 51 | ) 52 | 53 | server.scene.add_point_cloud( 54 | 'pc_target', 55 | pc_target.numpy(), 56 | point_size=0.001, 57 | point_shape="circle", 58 | colors=(200, 0, 0) 59 | ) 60 | 61 | while True: 62 | time.sleep(1) 63 | 64 | 65 | if __name__ == '__main__': 66 | warnings.simplefilter(action='ignore', category=FutureWarning) 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--robot_name', default='shadowhand', type=str) 69 | args = parser.parse_args() 70 | 71 | main(args.robot_name) 72 | -------------------------------------------------------------------------------- /visualization/vis_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import time 5 | import viser 6 
| import matplotlib.pyplot as plt 7 | import torch 8 | 9 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(ROOT_DIR) 11 | 12 | from model.network import create_encoder_network 13 | from data_utils.CMapDataset import CMapDataset 14 | from utils.pretrain_utils import dist2weight, infonce_loss 15 | from utils.hand_model import create_hand_model 16 | 17 | 18 | def main(robot_name): 19 | encoder = create_encoder_network(emb_dim=512, pretrain='pretrain_3robots.pth') 20 | 21 | dataset = CMapDataset( 22 | batch_size=1, 23 | robot_names=[robot_name], 24 | is_train=True, 25 | debug_object_names=None 26 | ) 27 | data = dataset[0] 28 | q_1 = data['initial_q'][0] 29 | q_2 = data['target_q'][0] 30 | pc_1 = data['robot_pc_initial'] 31 | pc_2 = data['robot_pc_target'] 32 | 33 | pc_1 = pc_1 - pc_1.mean(dim=1, keepdims=True) 34 | pc_2 = pc_2 - pc_2.mean(dim=1, keepdims=True) 35 | 36 | emb_1 = encoder(pc_1).detach() 37 | emb_2 = encoder(pc_2).detach() 38 | 39 | weight = dist2weight(pc_1, func=lambda x: torch.tanh(10 * x)) 40 | loss, similarity = infonce_loss( 41 | emb_1, emb_2, weights=weight, temperature=0.1 42 | ) 43 | 44 | match_idx = torch.argmax(similarity[0], dim=0) 45 | 46 | # offset for clearer visualization result 47 | offset = torch.tensor([0, 0.3, 0]) 48 | vis_pc_1 = data['robot_pc_initial'][0] 49 | vis_pc_2 = data['robot_pc_target'][0] + offset 50 | q_2[:3] += offset 51 | 52 | # match_tgt = vis_pc_2[match_idx] 53 | # match_vec = match_tgt - vis_pc_1 54 | 55 | server = viser.ViserServer(host='127.0.0.1', port=8080) 56 | 57 | z_values = vis_pc_1[:, 1] 58 | z_normalized = (z_values - z_values.min()) / (z_values.max() - z_values.min()) 59 | cmap = plt.get_cmap('rainbow') 60 | initial_colors = cmap(z_normalized)[:, :3] 61 | target_colors = initial_colors[match_idx] 62 | 63 | server.scene.add_point_cloud( 64 | 'initial pc', 65 | vis_pc_1[:, :3].numpy(), 66 | point_size=0.002, 67 | point_shape="circle", 68 | colors=initial_colors 69 | ) 70 | 71 | server.scene.add_point_cloud( 72 | 'target pc', 73 | vis_pc_2[:, :3].numpy(), 74 | point_size=0.002, 75 | point_shape="circle", 76 | colors=target_colors 77 | ) 78 | 79 | hand = create_hand_model(robot_name) 80 | 81 | robot_trimesh = hand.get_trimesh_q(q_1)["visual"] 82 | server.scene.add_mesh_simple( 83 | 'robot_initial', 84 | robot_trimesh.vertices, 85 | robot_trimesh.faces, 86 | color=(102, 192, 255), 87 | opacity=0.2 88 | ) 89 | 90 | robot_trimesh = hand.get_trimesh_q(q_2)["visual"] 91 | server.scene.add_mesh_simple( 92 | 'robot_target', 93 | robot_trimesh.vertices, 94 | robot_trimesh.faces, 95 | color=(102, 192, 255), 96 | opacity=0.2 97 | ) 98 | 99 | while True: 100 | time.sleep(1) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--robot_name', type=str, default='shadowhand') 106 | args = parser.parse_args() 107 | 108 | main(args.robot_name) 109 | -------------------------------------------------------------------------------- /visualization/vis_validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation visualization results will be saved in the 'vis_info/' folder. 3 | This code is used to visualize the saved information. 
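A single grasp_idx slider indexes all saved grasps; on_update() maps the global index back to the
(robot, object) batch it belongs to and shows the object mesh, the predicted grasp, and the
post-simulation Isaac pose.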
4 | """ 5 | 6 | import os 7 | import sys 8 | import time 9 | import viser 10 | import trimesh 11 | import torch 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(ROOT_DIR) 15 | 16 | from utils.hand_model import create_hand_model 17 | 18 | 19 | def main(): 20 | # substitute your filename here, which should be automatically saved in vis_info/ by validation.py 21 | file_name = 'vis_info/3robots_epoch10.pt' 22 | vis_info = torch.load(os.path.join(ROOT_DIR, file_name), map_location='cpu') 23 | 24 | def on_update(idx): 25 | invalid = True 26 | for info in vis_info: 27 | if idx >= info['predict_q'].shape[0]: 28 | idx -= info['predict_q'].shape[0] 29 | else: 30 | invalid = False 31 | break 32 | if invalid: 33 | print('Invalid index!') 34 | return 35 | 36 | print(info['robot_name'], info['object_name'], idx) 37 | print('result:', info['success'][idx]) 38 | 39 | object_name = info['object_name'].split('+') 40 | object_path = os.path.join(ROOT_DIR, f'data/data_urdf/object/{object_name[0]}/{object_name[1]}/{object_name[1]}.stl') 41 | object_trimesh = trimesh.load_mesh(object_path) 42 | server.scene.add_mesh_simple( 43 | 'object', 44 | object_trimesh.vertices, 45 | object_trimesh.faces, 46 | color=(239, 132, 167), 47 | opacity=0.8 48 | ) 49 | 50 | server.scene.add_point_cloud( 51 | 'object_pc', 52 | info['object_pc'][idx].numpy(), 53 | point_size=0.0008, 54 | point_shape="circle", 55 | colors=(255, 0, 0), 56 | visible=False 57 | ) 58 | 59 | server.scene.add_point_cloud( 60 | 'mlat_pc', 61 | info['mlat_pc'][idx].numpy(), 62 | point_size=0.001, 63 | point_shape="circle", 64 | colors=(0, 0, 200), 65 | visible=False 66 | ) 67 | 68 | hand = create_hand_model(info['robot_name']) 69 | 70 | robot_transform_trimesh = hand.get_trimesh_se3(info['predict_transform'], idx) 71 | server.scene.add_mesh_trimesh('transform', robot_transform_trimesh, visible=False) 72 | 73 | robot_trimesh = hand.get_trimesh_q(info['predict_q'][idx])['visual'] 74 | server.scene.add_mesh_simple( 75 | 'robot_predict', 76 | robot_trimesh.vertices, 77 | robot_trimesh.faces, 78 | color=(102, 192, 255), 79 | opacity=0.8, 80 | visible=False 81 | ) 82 | 83 | robot_trimesh = hand.get_trimesh_q(info['isaac_q'][idx])['visual'] 84 | server.scene.add_mesh_simple( 85 | 'robot_isaac', 86 | robot_trimesh.vertices, 87 | robot_trimesh.faces, 88 | color=(102, 192, 255), 89 | opacity=0.8 90 | ) 91 | 92 | server = viser.ViserServer(host='127.0.0.1', port=8080) 93 | 94 | grasp_num = 0 95 | for info in vis_info: 96 | grasp_num += info['predict_q'].shape[0] 97 | 98 | slider = server.gui.add_slider( 99 | label='grasp_idx', 100 | min=0, 101 | max=grasp_num, 102 | step=1, 103 | initial_value=0 104 | ) 105 | slider.on_update(lambda _: on_update(slider.value)) 106 | 107 | while True: 108 | time.sleep(1) 109 | 110 | if __name__ == '__main__': 111 | main() 112 | --------------------------------------------------------------------------------