├── AutomaticWeightedLoss ├── AutomaticWeightedLoss.py ├── LICENSE ├── README.md └── __init__.py ├── README.md ├── checkpoints ├── data ├── dataset ├── Dataset.py ├── GPTsummary.py ├── LVDataset.py ├── __init__.py ├── category.py ├── config_to_object_id.py ├── generate_data.py ├── hardset.py ├── helpers.py ├── label_grasp.py ├── material_count.pkl ├── materials.json ├── object_count.pkl ├── prompt1.py ├── prompt2.py ├── stat.py └── utils.py ├── demo ├── __init__.py ├── demo_extract_language_feature.py ├── demo_objects ├── main.py └── view_graspnet_grasps.py ├── encoder ├── LLaMA27BEncoder.py ├── __init__.py └── language_encode.py ├── evaluation ├── __init__.py ├── affordance.py ├── demo_graspnet.py ├── generate_eval_table.py ├── main.py ├── main_multi.py ├── results ├── run.py ├── test_graspnet.py ├── test_hardset.py ├── test_to_train.py └── test_vgn.py ├── example.py ├── grasp ├── example.py ├── force_optimization.py ├── generate_grasp.py └── transform.py ├── model ├── __init__.py ├── data_utils.py ├── eval_ab_l.py ├── eval_ab_vg.py ├── eval_ab_vl.py ├── grasp_utils.py ├── model.py ├── model_ab_l.py ├── model_ab_vg.py ├── model_ab_vl.py ├── run.py ├── trainer.py ├── trainer_ab_l.py ├── trainer_ab_vg.py ├── trainer_ab_vl.py ├── utils.py └── visualize.py ├── scripts ├── data_exam.py ├── heatmap.py └── plot_affordance_map.py └── vision_encoder ├── __init__.py ├── demo_extract_vision_feature.py ├── extract_vision_feature.py ├── modelnet40_pointnext-s.yaml ├── modelnet40ply2048-train-pointnext-s-ngpus1-seed6848-model.encoder_args.width=64-20220525-145053-7tGhBV9xR9yQEBtN4GPcSc_ckpt_best.pth ├── shapenetpart-train-pointnext-s_c64-ngpus4-seed7798-20220822-024210-ZcJ8JwCgc7yysEBWzkyAaE_ckpt_best.pth ├── shapenetpart_pointnext-s.yaml ├── shapenetpart_pointnext-s_c160.yaml └── shapenetpart_pointnext-s_c64.yaml /AutomaticWeightedLoss/AutomaticWeightedLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class AutomaticWeightedLoss(nn.Module): 7 | """automatically weighted multi-task loss 8 | 9 | Params: 10 | num: int,the number of loss 11 | x: multi-task loss 12 | Examples: 13 | loss1=1 14 | loss2=2 15 | awl = AutomaticWeightedLoss(2) 16 | loss_sum = awl(loss1, loss2) 17 | """ 18 | def __init__(self, num=2): 19 | super(AutomaticWeightedLoss, self).__init__() 20 | params = torch.ones(num, requires_grad=True) 21 | self.params = torch.nn.Parameter(params) 22 | 23 | def forward(self, *x): 24 | loss_sum = 0 25 | for i, loss in enumerate(x): 26 | loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2) 27 | return loss_sum 28 | 29 | if __name__ == '__main__': 30 | awl = AutomaticWeightedLoss(2) 31 | print(awl.parameters()) -------------------------------------------------------------------------------- /AutomaticWeightedLoss/README.md: -------------------------------------------------------------------------------- 1 | # AutomaticWeightedLoss 2 | 3 | A PyTorch implementation of Liebel L, Körner M. [Auxiliary tasks in multi-task learning](https://arxiv.org/pdf/1805.06334)[J]. arXiv preprint arXiv:1805.06334, 2018. 4 | 5 | The above paper improves the paper "[Multi-task learning using uncertainty to weigh losses for scene geometry and semantics](http://openaccess.thecvf.com/content_cvpr_2018/html/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.html)" to avoid the loss of becoming negative during training. 
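As implemented in `AutomaticWeightedLoss.py` above, the combined objective for `num` losses is `loss_sum = sum_i( loss_i / (2 * p_i**2) + log(1 + p_i**2) )`, where the `p_i` are the learnable entries of `self.params`; the `log(1 + p_i**2)` regularizer (in place of the plain log-variance term of the original formulation) stays positive, which is what prevents the total loss from becoming negative.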
6 | 7 | ## Requirements 8 | 9 | * Python 10 | * PyTorch 11 | 12 | ## How to Train with Your Model 13 | 14 | * Clone the repository 15 | 16 | ``` bash 17 | git clone git@github.com:Mikoto10032/AutomaticWeightedLoss.git 18 | ``` 19 | 20 | * Create an AutomaticWeightedLoss module 21 | 22 | ```python 23 | from AutomaticWeightedLoss import AutomaticWeightedLoss 24 | 25 | awl = AutomaticWeightedLoss(2) # we have 2 losses 26 | loss1 = 1 27 | loss2 = 2 28 | loss_sum = awl(loss1, loss2) 29 | ``` 30 | 31 | * Create an optimizer to learn weight coefficients 32 | 33 | ```python 34 | from torch import optim 35 | 36 | model = Model() 37 | optimizer = optim.Adam([ 38 | {'params': model.parameters()}, 39 | {'params': awl.parameters(), 'weight_decay': 0} 40 | ]) 41 | ``` 42 | 43 | * A complete example 44 | 45 | ```python 46 | from torch import optim 47 | from AutomaticWeightedLoss import AutomaticWeightedLoss 48 | 49 | model = Model() 50 | 51 | awl = AutomaticWeightedLoss(2) # we have 2 losses 52 | loss_1 = ... 53 | loss_2 = ... 54 | 55 | # learnable parameters 56 | optimizer = optim.Adam([ 57 | {'params': model.parameters()}, 58 | {'params': awl.parameters(), 'weight_decay': 0} 59 | ]) 60 | 61 | for i in range(epoch): 62 | for data, label1, label2 in data_loader: 63 | # forward 64 | pred1, pred2 = Model(data) 65 | # calculate losses 66 | loss1 = loss_1(pred1, label1) 67 | loss2 = loss_2(pred2, label2) 68 | # weigh losses 69 | loss_sum = awl(loss1, loss2) 70 | # backward 71 | optimizer.zero_grad() 72 | loss_sum.backward() 73 | optimizer.step() 74 | ``` 75 | 76 | ## Something to Say 77 | 78 | Actually, it is not always effective, but I hope it can help you. -------------------------------------------------------------------------------- /AutomaticWeightedLoss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/AutomaticWeightedLoss/__init__.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | PhyGrasp: Generalizing Robotic Grasping with Physics-informed Large Multimodal Models
3 | 
7 | Dingkun Guo*, Yuqi Xiang*, Shuqi Zhao, Xinghao Zhu, Masayoshi Tomizuka, Mingyu Ding†, Wei Zhan
12 | *These authors contributed equally to this work. †Corresponding author and project lead.
15 | [Website] [Paper] [Presentation]
21 | 22 | # Citation 23 | Please cite this work if it helps your research: 24 | 25 | ``` 26 | @ARTICLE{Guo2024PhyGrasp, 27 | title={PhyGrasp: Generalizing Robotic Grasping with Physics-informed Large Multimodal Models}, 28 | author={Dingkun Guo, Yuqi Xiang, Shuqi Zhao, Xinghao Zhu, Masayoshi Tomizuka, Mingyu Ding, Wei Zhan}, 29 | year={2024}, 30 | eprint={2402.16836}, 31 | archivePrefix={arXiv}, 32 | primaryClass={cs.RO} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /checkpoints: -------------------------------------------------------------------------------- 1 | /home/gdk/Data/bimanual/checkpoints -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /home/gdk/Data/bimanual -------------------------------------------------------------------------------- /dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import trimesh 7 | import pickle 8 | import dataset.helpers as helpers 9 | 10 | class Dataset: 11 | def __init__(self, dataset_dir): 12 | self.dataset_dir = dataset_dir 13 | self.data_entries = {} 14 | 15 | def __getitem__(self, object_id): 16 | if object_id not in self.data_entries: 17 | self.data_entries[object_id] = ObjectEntry(object_id, self.dataset_dir) 18 | return self.data_entries[object_id] 19 | 20 | def get_object_ids(self): 21 | return os.listdir(self.dataset_dir) 22 | 23 | def load(self): 24 | for object_id in self.get_object_ids(): 25 | self[object_id].load() 26 | # for entry_name in os.listdir(os.path.join(self.dataset_dir, object_id)): 27 | # if entry_name.endswith('.pkl'): 28 | # config_id = entry_name.split('_')[1][0:-4] 29 | # self[object_id].load(config_id) 30 | 31 | def count_category(self): 32 | return helpers.count_category(self) 33 | 34 | def get_languages(self): 35 | return helpers.get_languages(self) 36 | 37 | class ObjectEntry: 38 | def __init__(self, object_id, dataset_dir): 39 | self.object_id = object_id 40 | self.dataset_dir = dataset_dir 41 | self.object_dir = os.path.join(self.dataset_dir, self.object_id) 42 | self.data = {} 43 | self.load_metadata() 44 | 45 | def __getitem__(self, config_id): 46 | if config_id not in self.data: 47 | self.data[config_id] = self.load(config_id) 48 | return self.data[config_id] 49 | 50 | def __setitem__(self, config_id, data_entry): 51 | self.data[config_id] = data_entry 52 | 53 | def load_metadata(self): 54 | parts_json = json.load(open(os.path.join(self.object_dir, 'parts.json'), 'r')) 55 | self.name = parts_json['name'] 56 | 57 | def load_meshes(self): 58 | parts_json = json.load(open(os.path.join(self.object_dir, 'parts.json'), 'r')) 59 | if 'name' not in parts_json: 60 | raise Exception('parts.json does not have name field') 61 | meshes_dir = os.path.join(self.object_dir, 'meshes') 62 | meshes = [] 63 | names = [] 64 | for part in parts_json.values(): 65 | name = part['name'] 66 | mesh = trimesh.load(os.path.join(meshes_dir, part['mesh_name'])) 67 | mesh.units = 'm' 68 | meshes.append(mesh) 69 | names.append(name) 70 | # print(names) 71 | return meshes 72 | else: 73 | self.name = parts_json['name'] 74 | meshes_dir = os.path.join(self.object_dir, 'objs') 75 | meshes = [] 76 | names = [] 77 | for part in parts_json['parts'].values(): 78 | names.append(part['name']) 79 | obj_names = part['objs'] 80 | obj_meshes = [] 81 | for 
obj_name in obj_names: 82 | mesh = trimesh.load(f'{meshes_dir}/{obj_name}.obj') 83 | mesh.units = 'm' 84 | obj_meshes.append(mesh) 85 | mesh = trimesh.util.concatenate(obj_meshes) 86 | meshes.append(mesh) 87 | # print(names) 88 | return meshes 89 | 90 | def save(self, version=""): 91 | for config_id, data_entry in self.data.items(): 92 | assert hasattr(data_entry, 'language_feature'), 'data_entry does not have language_feature' 93 | entry_name = f'{self.object_id}_{config_id}{version}.pkl' 94 | entry_path = os.path.join(self.object_dir, entry_name) 95 | with open(entry_path, 'wb') as f: 96 | pickle.dump(data_entry, f) 97 | 98 | def load_config(self, config_id, version=""): 99 | entry_name = f'{self.object_id}_{config_id}{version}.pkl' 100 | entry_path = os.path.join(self.object_dir, entry_name) 101 | if os.path.exists(entry_path): 102 | with open(entry_path, 'rb') as f: 103 | data_entry = pickle.load(f) 104 | self.__setitem__(config_id, data_entry) 105 | 106 | def load(self, version=""): 107 | for entry_name in os.listdir(self.object_dir): 108 | suffix = f'{version}.pkl' 109 | if entry_name.endswith(suffix): 110 | config_id = entry_name[:-len(suffix)].split('_')[1] 111 | self.load_config(config_id, version=version) 112 | 113 | def unload(self): 114 | for config_id in list(self.data.keys()): 115 | del self.data[config_id] 116 | 117 | 118 | class DataEntry: 119 | def __init__(self, config, pos_grasps, neg_grasps, grasp_map, language): 120 | self.config = config 121 | self.pos_grasps = pos_grasps 122 | self.neg_grasps = neg_grasps 123 | self.grasp_map = grasp_map 124 | self.language = language 125 | 126 | 127 | class Config: 128 | def __init__(self, config_id, materials, frictions, densities, grasp_likelihoods, fragilities, 129 | sample_probs, max_normal_forces, masses): 130 | self.id = config_id 131 | # self.meshes = meshes 132 | assert len(materials) == len(frictions) == len(densities) == len(fragilities) 133 | self.num_parts = len(materials) 134 | 135 | # defined config 136 | self.materials = materials 137 | self.frictions = frictions 138 | self.densities = densities 139 | self.grasp_likelihoods = grasp_likelihoods 140 | self.fragilities = fragilities 141 | 142 | # calculated config 143 | # self.sample_probs = self.grasp_likelihoods / sum(self.grasp_likelihoods) 144 | # self.max_normal_forces = np.power(10, self.fragilities) 145 | self.sample_probs = sample_probs 146 | self.max_normal_forces = max_normal_forces 147 | self.masses = masses 148 | 149 | 150 | if __name__ == '__main__': 151 | dataset = Dataset('./data/objects/') 152 | meshes = dataset['10314'].load_meshes() 153 | print(len(meshes)) 154 | -------------------------------------------------------------------------------- /dataset/GPTsummary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import argparse 4 | import json 5 | import numpy as np 6 | import random 7 | import dataset.prompt2 as PROMPT 8 | from dataset.Dataset import Dataset, Config, DataEntry, ObjectEntry 9 | import dataset.generate_data as generate_data 10 | from dataset.category import OBJECTS 11 | import time 12 | 13 | openai.api_type = "azure" 14 | openai.api_version = "2023-05-15" 15 | openai.api_key = os.getenv("AZURE_OPENAI_KEY") 16 | openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT") 17 | deployment_name='gpt35' #This will correspond to the custom name you chose for your deployment when you deployed a model. 
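# NOTE: the Azure credentials above are read from the environment, so
# AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT must be exported before running;
# deployment_name must match a gpt-3.5-turbo deployment on that endpoint.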
18 | 19 | PARTS_FILE = 'parts.json' 20 | MATERIALS_FILE = './dataset/materials.json' 21 | FRAGILIGY_LEVELS = {1: 'very fragile', 2: 'fragile', 3: 'normal', 4: 'tough', 5: 'very tough'} 22 | 23 | 24 | def get_parts_names(object: ObjectEntry): 25 | parts_path = os.path.join(object.object_dir, PARTS_FILE) 26 | with open(parts_path) as f: 27 | parts = json.load(f) 28 | object_name = parts['name'] 29 | parts_info = parts['parts'] 30 | parts_name = [part['name'] for part in parts_info.values()] 31 | return object_name, parts_name 32 | 33 | def get_materials(): 34 | materials_path = MATERIALS_FILE 35 | with open(materials_path) as f: 36 | materials = json.load(f) 37 | return materials 38 | 39 | def get_random_material(materials, n=1): 40 | raise NotImplementedError 41 | material = [] 42 | for i in range(n): 43 | material.append(random.choice(list(materials))) 44 | material_names = ', '.join([m['Material'] for m in material]) 45 | material_frictions = ', '.join([str(m['Friction']) for m in material]) 46 | material_density = ', '.join([str(m['Density']) for m in material]) 47 | material_fragility = ', '.join([str(m['Fragility']) for m in material]) 48 | material_grasp_prob = ', '.join([random.choice(['0.1', '0.5', '0.9']) for m in material]) 49 | return material_names, material_frictions, material_density, material_fragility 50 | 51 | def get_config_materials(config: Config, available_materials): 52 | material_names = ', '.join(m for m in config.materials) 53 | material_frictions = ', '.join([str(m) for m in config.frictions]) 54 | material_density = ', '.join([str(m) for m in config.densities]) 55 | material_fragility = ', '.join([FRAGILIGY_LEVELS[m] for m in config.fragilities]) 56 | sample_probs = " The grasping probabilities of each part are " + ', '.join([str(round(m, 2)) for m in config.sample_probs]) + '.' if config.grasp_likelihoods is not None else "" 57 | 58 | return material_names, material_frictions, material_density, material_fragility, sample_probs 59 | 60 | def summary(config: Config, object: ObjectEntry, available_materials, params=None): 61 | object_name, parts_name = get_parts_names(object) 62 | all_parts_name = ', '.join(parts_name) 63 | if available_materials is None: 64 | available_materials = get_materials(params) 65 | material_names, material_frictions, material_density, material_fragility, sample_probs = get_config_materials(config, available_materials) 66 | description = "There is an %s, it has several parts including %s. 
The materials of each part are %s, with friction: %s, density: %s, fragility: %s.%s " % (object_name, all_parts_name, material_names, material_frictions, material_density, material_fragility, sample_probs) 67 | 68 | completion = openai.ChatCompletion.create( 69 | engine=deployment_name, 70 | model="gpt-3.5-turbo", 71 | messages=[ 72 | {"role": "system", "content": PROMPT.ROLE}, 73 | 74 | {"role": "user", "content": PROMPT.EXAMPLES}, 75 | 76 | {"role": "user", "content": PROMPT.INSTRUCTION + description}, 77 | ], 78 | ) 79 | response = completion.choices[0].message['content'] 80 | 81 | if params is not None and params['debug']: 82 | # print(object_name, all_parts_name) 83 | # print(material_names, material_frictions, material_density, material_fragility) 84 | print("Before summary:\n", description, '\n') 85 | print("Attempt summary:\n", response, '\n') 86 | print('-' * 80 + '\n') 87 | 88 | return response 89 | 90 | if __name__ == '__main__': 91 | # usage: under folder DualArmManipulation, run python dataset/GPTsummary_demo.py 92 | args = argparse.ArgumentParser() 93 | args.add_argument('--id', type=str, default='38957') 94 | args.add_argument('--debug', default=False, action='store_true') 95 | args.add_argument('--data_folder', type=str, default='./data/objects') 96 | args.add_argument('--parts_file', type=str, default='parts.json') 97 | args.add_argument('--material_file', type=str, default='./dataset/materials.json') 98 | params = args.parse_args() 99 | params = vars(params) 100 | 101 | dataset = Dataset('./data/objects') 102 | available_materials = json.load(open('./dataset/materials.json', 'r')) 103 | 104 | subdirs = [x for x in os.listdir(params['data_folder']) if os.path.isdir(os.path.join(params['data_folder'], x))] 105 | object_hit = {} 106 | to_hit = len(OBJECTS) 107 | for obj in OBJECTS: 108 | object_hit[obj] = 0 109 | 110 | while to_hit > 0: 111 | subdir = random.choice(subdirs) 112 | params['id'] = subdir 113 | object_name = subdir 114 | with open(os.path.join(params['data_folder'], subdir, params['parts_file']), 'r') as f: 115 | parts = json.load(f) 116 | if object_hit[parts['name']] > 0: 117 | continue 118 | else: 119 | object_hit[parts['name']] += 1 120 | to_hit -= 1 121 | 122 | meshes = dataset[object_name].load_meshes() 123 | config = generate_data.generate_random_config(meshes, available_materials) 124 | start_time = time.time() 125 | summary(config, dataset[object_name], available_materials, params) 126 | print("Time used: ", time.time() - start_time) 127 | 128 | # object_name = '1536' 129 | # meshes = dataset[object_name].load_meshes() 130 | # config = generate_random_config(meshes, available_materials) 131 | # summary(config, dataset[object_name], available_materials, params) 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /dataset/LVDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pickle 4 | import random 5 | import numpy as np 6 | 7 | class LVDataset(torch.utils.data.Dataset): 8 | def __init__(self, version="", language_layer=20) -> None: 9 | self.entry_paths = [] 10 | self.version = version 11 | self.language_layer = language_layer 12 | 13 | def __len__(self): 14 | return len(self.entry_paths) 15 | 16 | def __getitem__(self, index): 17 | if torch.is_tensor(index): 18 | index = index.tolist() 19 | return self.get_entry(self.entry_paths[index]) 20 | 21 | def get_entry(self, entry_path): 22 | # return entry_path 23 | # 
print(dir(self)) 24 | with open(entry_path, 'rb') as f: 25 | entry = pickle.load(f) 26 | assert self.language_layer == 20 27 | if hasattr(self, 'language_layer') == False: 28 | assert False, "self.language_layer is not defined" 29 | self.language_layer = 0 30 | if self.language_layer == 15: 31 | language = entry.language_feature_15 32 | elif self.language_layer == 20: 33 | language = entry.language_feature_20 34 | elif self.language_layer == 25: 35 | language = entry.language_feature_25 36 | else: 37 | language = entry.language_feature.squeeze(0).cpu().numpy() 38 | point_cloud = entry.point_cloud 39 | vision_global = entry.global_feature 40 | vision_local = entry.local_feature.T 41 | grasp_map = entry.grasp_map 42 | pos_index = entry.pos_index 43 | neg_index = entry.neg_index 44 | pos_neg_num = np.array([pos_index.shape[0], neg_index.shape[0]]) 45 | pos_index = np.pad(pos_index, ((0, 200 - pos_index.shape[0]), (0, 0)), mode='constant', constant_values=0) 46 | neg_index = np.pad(neg_index, ((0, 200 - neg_index.shape[0]), (0, 0)), mode='constant', constant_values=0) 47 | 48 | # print(type(pos_index), pos_index.shape, type(neg_index), neg_index.shape, type(vision_local), vision_local.shape) 49 | # print(language.shape) 50 | return language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_path 51 | 52 | def test_attributes(self): 53 | cnt = 0 54 | random.shuffle(self.entry_paths) 55 | 56 | for id, entry_path in enumerate(self.entry_paths): 57 | with open(entry_path, 'rb') as f: 58 | entry = pickle.load(f) 59 | if hasattr(entry, 'global_feature') == False: 60 | print(dir(entry)) 61 | print(entry_path) 62 | cnt += 1 63 | if id % 100 == 0: 64 | print("{}/{}".format(cnt, id)) 65 | 66 | def test_data(self): 67 | print("test data") 68 | print(len(self.entry_paths)) 69 | random.shuffle(self.entry_paths) 70 | # self.entry_paths = self.entry_paths[0:2] 71 | 72 | pos = [] 73 | neg = [] 74 | for entry_path in self.entry_paths: 75 | with open(entry_path, 'rb') as f: 76 | entry = pickle.load(f) 77 | pos_grasps = entry.pos_grasps 78 | neg_grasps = entry.neg_grasps 79 | pos.append(pos_grasps.shape[0]) 80 | neg.append(neg_grasps.shape[0]) 81 | pos = np.array(pos) 82 | neg = np.array(neg) 83 | print("pos mean, std", np.mean(pos, axis=0), np.std(pos, axis=0)) 84 | print("neg mean, std", np.mean(neg, axis=0), np.std(neg, axis=0)) 85 | 86 | def small_dataset(self): 87 | random.shuffle(self.entry_paths) 88 | self.entry_paths = self.entry_paths[0:10000] 89 | 90 | def load(self, dataset_dir, version=""): 91 | for object_id in os.listdir(dataset_dir): 92 | object_dir = os.path.join(dataset_dir, object_id) 93 | for entry_name in os.listdir(object_dir): 94 | suffix = f'{version}.pkl' 95 | if entry_name.endswith(suffix): 96 | config_id = entry_name[:-len(suffix)].split('_')[1] 97 | entry_path = os.path.join(object_dir, entry_name) 98 | self.entry_paths.append(entry_path) 99 | 100 | # self.entry_paths = self.entry_paths[0:1000] 101 | 102 | if __name__ == '__main__': 103 | lvdataset = LVDataset() 104 | lvdataset.load('./data/objects/', version="_v1") 105 | entry_paths = lvdataset.entry_paths 106 | print(len(entry_paths)) 107 | random.shuffle(entry_paths) 108 | pickle.dump(entry_paths, open('./data/dataset/v1_random_1000.pkl', 'wb')) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/category.py: -------------------------------------------------------------------------------- 1 | CATEGORY = {'table': 9550, 'trash_can': 357, 'chair': 8097, 'pot': 579, 'faucet': 753, 'storage_furniture': 2201, 'door_set': 210, 'display': 933, 'lamp': 2967, 'keyboard': 103, 'refrigerator': 159, 'bottle': 474, 'hat': 119, 'earphone': 260, 'laptop': 223, 'bed': 181, 'cutting_instrument': 498, 'clock': 156, 'bowl': 104, 'bag': 83, 'scissors': 117, 'mug': 213, 'dishwasher': 190, 'microwave': 72} 2 | 3 | OBJECTS = ['table', 'trash_can', 'chair', 'pot', 'faucet', 'storage_furniture', 'door_set', 'display', 'lamp', 'keyboard', 'refrigerator', 'bottle', 'hat', 'earphone', 'laptop', 'bed', 'cutting_instrument', 'clock', 'bowl', 'bag', 'scissors', 'mug', 'dishwasher', 'microwave'] 4 | 5 | SCALES = {'table': 1.0, 'trash_can': 0.2, 'chair': 0.6, 'pot': 0.2, 'faucet': 0.1, 'storage_furniture': 1.5, 'door_set': 1.4, 'display': 0.3, 'lamp': 0.2, 'keyboard': 0.3, 'refrigerator': 1.2, 'bottle': 0.1, 'hat': 0.1, 'earphone': 0.01, 'laptop': 0.3, 'bed': 1.6, 'cutting_instrument': 0.1, 'clock': 0.2, 'bowl': 0.1, 'bag': 0.3, 'scissors': 0.1, 'mug': 0.1, 'dishwasher': 1.1, 'microwave': 0.7} 6 | 7 | SAMPLES = {'table': 1, 'trash_can': 29, 'chair': 1, 'pot': 17, 'faucet': 13, 'storage_furniture': 4, 'door_set': 49, 'display': 11, 'lamp': 3, 'keyboard': 101, 'refrigerator': 65, 'bottle': 21, 'hat': 87, 'earphone': 40, 'laptop': 46, 'bed': 57, 'cutting_instrument': 20, 'clock': 66, 'bowl': 100, 'bag': 125, 'scissors': 89, 'mug': 48, 'dishwasher': 54, 'microwave': 144} 8 | 9 | SAMPLES_TEST = {'table': 0, 'trash_can': 1, 'chair': 0, 'pot': 1, 'faucet': 1, 'storage_furniture': 1, 'door_set': 1, 'display': 1, 'lamp': 1, 'keyboard': 1, 'refrigerator': 1, 'bottle': 1, 'hat': 1, 'earphone': 1, 'laptop': 1, 'bed': 1, 'cutting_instrument': 1, 'clock': 1, 'bowl': 1, 'bag': 1, 'scissors': 1, 'mug': 1, 'dishwasher': 1, 'microwave': 1} 10 | 11 | 12 | 13 | if __name__ == '__main__': 14 | # total = 250000 15 | test_total = 0 16 | for k, v in SAMPLES_TEST.items(): 17 | test_total += v * CATEGORY[k] 18 | print(test_total) 19 | -------------------------------------------------------------------------------- /dataset/config_to_object_id.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import json 5 | import csv 6 | 7 | # import LVDataset 8 | 9 | 10 | if __name__ == '__main__': 11 | paths = [] 12 | for i in range(303): 13 | if os.path.exists('./checkpoints/maps_{}.pkl'.format(i)) == False: 14 | continue 15 | test_dataset = pickle.load(open('./checkpoints/maps_{}.pkl'.format(i), 'rb')) 16 | paths += test_dataset['entry_path'] 17 | 18 | # object_ids = [path.split('/')[-2] for path in paths] 19 | # object_ids = np.array(object_ids) 20 | # object_ids = np.unique(object_ids) 21 | # print(object_ids) 22 | # print(len(object_ids)) 23 | # # save json 24 | # json.dump(object_ids.tolist(), open('./data/dataset/test_object_ids.json', 'w')) 25 | 26 | # config_ids = [path.split('/')[-1].split('_')[1] for path in paths] 27 | # print(config_ids[0]) 28 | # print(len(config_ids)) 29 | # config_ids = np.unique(config_ids) 30 | # print(len(config_ids)) 31 | 32 | object_config_ids = [[path.split('/')[-2], path.split('/')[-1].split('_')[1]] for path in paths] 33 | 
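    # Entries saved without the '_v1' suffix parse as '<timestamp>.pkl' rather than a
    # bare timestamp, so any config_id still containing 'pkl' is counted and skipped below.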
34 | i = 0 35 | filtered_object_config_ids = [] 36 | for object_id, config_id in object_config_ids: 37 | if 'pkl' in config_id: 38 | i += 1 39 | continue 40 | filtered_object_config_ids.append([object_id, config_id]) 41 | print('not _v1: ', i) 42 | 43 | # with open('./data/dataset/test_object_config_ids.csv', 'w', newline='') as f: 44 | # writer = csv.writer(f) 45 | # writer.writerow(['object_id', 'config_id']) 46 | # writer.writerows(filtered_object_config_ids) 47 | print(len(filtered_object_config_ids)) 48 | 49 | 50 | # dataset_set = pickle.load(open('./data/dataset/test_dataset_v1.pkl', 'rb')) 51 | # object_ids = [entry[-1].split('/')[-2] for entry in dataset_set] 52 | # object_ids = np.array(object_ids) 53 | # object_ids = np.unique(object_ids) 54 | # print(object_ids) 55 | # print(len(object_ids)) 56 | 57 | -------------------------------------------------------------------------------- /dataset/generate_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import random 7 | import trimesh 8 | import traceback 9 | from multiprocessing import Pool 10 | 11 | from dataset.Dataset import Dataset, Config, DataEntry 12 | from dataset.utils import shift_mass_center 13 | from grasp.force_optimization import filter_contact_points_by_force 14 | from grasp.generate_grasp import find_contact_points_multi, vis_grasp 15 | import dataset.GPTsummary as GPTsummary 16 | import dataset.category as category 17 | 18 | 19 | def min_distances(query_points, reference_points): 20 | squared_diff = np.sum((query_points[:, np.newaxis] - reference_points) ** 2, axis=-1) 21 | return np.sqrt(np.min(squared_diff, axis=1)) 22 | 23 | 24 | def gaussian(x, mean, variance): 25 | sigma = np.sqrt(variance) 26 | return (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(- (x - mean) ** 2 / (2 * variance)) 27 | 28 | 29 | def generate_grasp_map(pos_grasps, mesh): 30 | contact_pairs = pos_grasps[:, np.r_[1:4, 8:11]].reshape(-1, 2, 3) 31 | contact_locs = contact_pairs.reshape(-1, 3) 32 | dists = np.linalg.norm(mesh.vertices[:, np.newaxis] - contact_locs, axis=-1, ord=2) 33 | guassian_dists = gaussian(dists, 0, 0.01) 34 | guassian_map = np.sum(guassian_dists, axis=1) 35 | guassian_map = guassian_map / np.max(guassian_map) 36 | # show_grasp_heatmap(mesh, guassian_map, contact_pairs) 37 | return guassian_map 38 | 39 | def show_grasp_heatmap(mesh, grasp_map, contact_pairs=[]): 40 | mesh.visual.vertex_colors = trimesh.visual.color.interpolate(grasp_map, color_map='hot') 41 | mesh.visual.vertex_colors[:, 3] = 0.8 * 255 42 | scene_list = [mesh] 43 | for contact_point_1, contact_point_2 in contact_pairs: 44 | # c1 = trimesh.creation.uv_sphere(radius=0.005) 45 | # c2 = trimesh.creation.uv_sphere(radius=0.005) 46 | # c1.vertices += contact_point 47 | # c2.vertices += another_contact_point 48 | grasp_axis = trimesh.creation.cylinder(0.005, sections=6, 49 | segment=np.vstack([contact_point_1, contact_point_2])) 50 | grasp_axis.visual.vertex_colors = [0, 0., 1.] 
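        # the blue cylinder above marks the grasp axis between the two contact points;
        # the commented lines below would additionally mark each contact with a red sphere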
51 | # c1.visual.vertex_colors = [1., 0, 0] 52 | # c2.visual.vertex_colors = [1., 0, 0] 53 | # scene_list += [c1, c2, grasp_axis] 54 | scene_list += [grasp_axis] 55 | trimesh.Scene(scene_list).show() 56 | 57 | 58 | def generate_grasp_data(meshes, frictions, sample_probs, max_normal_forces, weight, num_sample=100): 59 | pos_grasps, neg_grasps = find_contact_points_multi(meshes, frictions, (sample_probs * num_sample).astype('int')) 60 | pos_grasps, force_neg_grasps = filter_contact_points_by_force(pos_grasps, frictions, max_normal_forces, weight) 61 | grasp_map = generate_grasp_map(pos_grasps, trimesh.util.concatenate(meshes)) 62 | return grasp_map, pos_grasps, np.append(neg_grasps, force_neg_grasps, axis=0) 63 | 64 | 65 | def generate_random_config(meshes, available_materials): 66 | config_id = time.time() 67 | num_parts = len(meshes) 68 | material_ids = random.sample(range(0, len(available_materials)), num_parts) 69 | 70 | materials = [available_materials[material_id]['Material'] for material_id in material_ids] 71 | frictions = [available_materials[material_id]['Friction'] for material_id in material_ids] 72 | densities = [available_materials[material_id]['Density'] for material_id in material_ids] 73 | grasp_likelihoods = random.choice([None, np.random.random(num_parts)]) 74 | 75 | fragilities = [available_materials[material_id]['Fragility'] for material_id in material_ids] 76 | 77 | sample_probs = grasp_likelihoods / sum(grasp_likelihoods) if grasp_likelihoods is not None else np.ones(num_parts) / num_parts 78 | 79 | max_normal_forces = np.power(10, np.array(fragilities)) 80 | masses = shift_mass_center(meshes, densities) 81 | 82 | return Config(config_id, materials, frictions, densities, grasp_likelihoods, fragilities, 83 | sample_probs, max_normal_forces, masses) 84 | 85 | def test(args): 86 | print(len(args)) 87 | time.sleep(1) 88 | 89 | def generate_data(objects): 90 | dataset = Dataset('./data/objects') 91 | available_materials = json.load(open('./dataset/materials.json', 'r')) 92 | num_sample = 200 93 | for object_id in objects: 94 | meshes = dataset[object_id].load_meshes() 95 | name = dataset[object_id].name 96 | num_config = category.SAMPLES[name] 97 | for i in range(num_config): 98 | try: 99 | config = generate_random_config(meshes, available_materials) 100 | grasp_map, pos_grasps, neg_grasps = generate_grasp_data(meshes, config.frictions, config.sample_probs, 101 | config.max_normal_forces, sum(config.masses) * 9.81, 102 | num_sample) 103 | language = GPTsummary.summary(config, dataset[object_id], available_materials) 104 | dataset[object_id][config.id] = DataEntry(config, pos_grasps, neg_grasps, grasp_map, language) 105 | except Exception as e: 106 | # traceback.print_exc() 107 | # print(f'Object {object_id} config {i} error: {e}') 108 | with open('./dataset/error.txt', 'a') as f: 109 | f.write(f'Object {object_id} config {i} error: {e}\n') 110 | continue 111 | dataset[object_id].save() 112 | 113 | if __name__ == '__main__': 114 | dataset0 = Dataset('./data/objects') 115 | objs = dataset0.get_object_ids() 116 | NUM_PROCESS = 16 117 | tasks = [] 118 | random.shuffle(objs) 119 | for i in range(NUM_PROCESS): 120 | tasks.append(objs[i::NUM_PROCESS]) 121 | 122 | pool = Pool(NUM_PROCESS) 123 | pool.map(generate_data, tasks) 124 | pool.close() 125 | 126 | # dataset.load() -------------------------------------------------------------------------------- /dataset/hardset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import 
torch 3 | from torch.utils.data import DataLoader 4 | from dataset.LVDataset import LVDataset 5 | import numpy as np 6 | import time 7 | import random 8 | import os 9 | import openai 10 | import multiprocessing 11 | openai.api_type = "azure" 12 | openai.api_version = "2023-05-15" 13 | openai.api_key = os.getenv("AZURE_OPENAI_KEY") 14 | openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT") 15 | deployment_name='gpt35' 16 | print(openai.api_key, openai.api_base) 17 | from dataset.hard_prompt import A, B 18 | 19 | def random_set(entry_paths, size = 2000): 20 | random.shuffle(entry_paths) 21 | with open('./data/dataset/v1_random.pkl', 'wb') as f: 22 | pickle.dump(entry_paths[0:size], f) 23 | return entry_paths[0:size] 24 | 25 | def test_prompt(languages): 26 | cnt = 0 27 | for language in languages: 28 | if gpt_query(language): 29 | cnt += 1 30 | print("{}/{}".format(cnt, len(languages))) 31 | 32 | def gpt_query(language): 33 | completion = openai.ChatCompletion.create( 34 | engine=deployment_name, 35 | model="gpt-3.5-turbo", 36 | messages=[ 37 | {"role": "system", "content": "You should judge the following description based on the given information. The description is about a specific object and its parts. You should judge whether this object is category A or catergory B. Here are the characteristics of A and B. Material Usage: A (precious/uncommon) vs. B (common/practical); Friction Focus: A (ease of use/movement) vs. B (safety/stability); Density and Weight Considerations: A (heavier/solid) vs. B (lighter/varied); Fragility vs. Toughness: A (balanced) vs. B (durable/heavier use); Grasp Probability Guidance: A (specific) vs. B (general). Please Answer 'A' or 'B'. If you are not sure, please answer 'I don't know'."}, 38 | # {"role": "system", "content": PT,}, 39 | {"role": "user", "content": language}, 40 | ], 41 | ) 42 | input_tokens = completion.usage['prompt_tokens'] 43 | output_tokens = completion.usage['completion_tokens'] 44 | response = completion.choices[0].message['content'] 45 | print(response) #,'Tokens = ',input_tokens,'+',output_tokens,'=',input_tokens+output_tokens,'Price =', 1e-3*input_tokens*0.01+1e-3*output_tokens*0.03) 46 | # convert response to boolean 47 | return 'A' in response 48 | 49 | def get_hard(entry_paths): 50 | hard_entry_paths = [] 51 | for entry_path in entry_paths: 52 | with open(entry_path, 'rb') as f: 53 | entry = pickle.load(f) 54 | language = entry.language 55 | if gpt_query(language): 56 | hard_entry_paths.append(entry_path) 57 | # time.sleep(1.0) 58 | # with open('./data/dataset/hard_entry_paths_p2.pkl', 'wb') as f: 59 | # pickle.dump(hard_entry_paths, f) 60 | return hard_entry_paths 61 | 62 | def main(): 63 | hard_entry_paths = [] 64 | with open('./data/dataset/test_dataset_v2.pkl', 'rb') as f: 65 | test_dataset = pickle.load(f) 66 | entry_paths = test_dataset.entry_paths 67 | hard_q = pickle.load(open('./evaluation/hard_entry_paths_q.pkl', 'rb')) 68 | hard_x = pickle.load(open('./evaluation/hard_entry_paths_x.pkl', 'rb')) 69 | random.shuffle(hard_q) 70 | random.shuffle(hard_x) 71 | hards = hard_q[:3000] + hard_x[:3000] 72 | entry_paths = list(set(entry_paths) - set(hards)) 73 | random.shuffle(entry_paths) 74 | random.shuffle(hards) 75 | paths = entry_paths[:1800] + hards[:600] 76 | 77 | N_PROC = 2 78 | tasks = [] 79 | for i in range(N_PROC): 80 | tasks.append(paths[i::N_PROC]) 81 | pool = multiprocessing.Pool(processes=N_PROC) 82 | results = pool.map(get_hard, tasks) 83 | for result in results: 84 | hard_entry_paths += result 85 | with 
open('./data/dataset/hard_entry_paths_q2.pkl', 'wb') as f: 86 | pickle.dump(hard_entry_paths, f) 87 | print(len(hard_entry_paths)) 88 | 89 | if __name__ == "__main__": 90 | main() -------------------------------------------------------------------------------- /dataset/helpers.py: -------------------------------------------------------------------------------- 1 | def count_category(dataset): 2 | category = {} 3 | for obj in dataset.data_entries.values(): 4 | num = len(obj.data) 5 | category[obj.name] = category.get(obj.name, 0) + num 6 | return category 7 | 8 | def get_languages(dataset): 9 | languages = [] 10 | for obj in dataset.data_entries.values(): 11 | for entry in obj.data.values(): 12 | languages.append(entry.language) 13 | return languages -------------------------------------------------------------------------------- /dataset/label_grasp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import random 7 | import trimesh 8 | import traceback 9 | from multiprocessing import Pool 10 | import multiprocessing 11 | 12 | from dataset.Dataset import Dataset, Config, DataEntry 13 | from dataset.utils import shift_mass_center 14 | from grasp.force_optimization import filter_contact_points_by_force 15 | from grasp.generate_grasp import find_contact_points_multi, vis_grasp 16 | import dataset.GPTsummary as GPTsummary 17 | import dataset.category as category 18 | from tqdm import tqdm 19 | import pickle 20 | 21 | # start_time = time.time() 22 | MIN_POS_GRASP = 10 23 | updated_objects = pickle.load(open('./data/updated_objects.pkl', 'rb')) 24 | 25 | def min_distances(query_points, reference_points): 26 | squared_diff = np.sum((query_points[:, np.newaxis] - reference_points) ** 2, axis=-1) 27 | return np.sqrt(np.min(squared_diff, axis=1)) 28 | 29 | 30 | def gaussian(x, mean, variance): 31 | sigma = np.sqrt(variance) 32 | return (1.0 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(- (x - mean) ** 2 / (2 * variance)) 33 | 34 | 35 | def generate_grasp_map(pos_grasps, mesh): 36 | contact_pairs = pos_grasps[:, np.r_[1:4, 8:11]].reshape(-1, 2, 3) 37 | contact_locs = contact_pairs.reshape(-1, 3) 38 | dists = np.linalg.norm(mesh.vertices[:, np.newaxis] - contact_locs, axis=-1, ord=2) 39 | guassian_dists = gaussian(dists, 0, 0.01) 40 | guassian_map = np.sum(guassian_dists, axis=1) 41 | guassian_map = guassian_map / np.max(guassian_map) 42 | # global start_time 43 | # print(f'generate grasp map time: {time.time() - start_time} s') 44 | # is_show = input('show grasp map? (y/n)') 45 | # if is_show == 'y': 46 | # show_grasp_heatmap(mesh, guassian_map, contact_pairs) 47 | return guassian_map 48 | 49 | def show_grasp_heatmap(mesh, grasp_map, contact_pairs=[]): 50 | mesh.visual.vertex_colors = trimesh.visual.color.interpolate(grasp_map, color_map='hot') 51 | mesh.visual.vertex_colors[:, 3] = 0.8 * 255 52 | scene_list = [mesh] 53 | for contact_point_1, contact_point_2 in contact_pairs: 54 | grasp_axis = trimesh.creation.cylinder(0.005, sections=6, 55 | segment=np.vstack([contact_point_1, contact_point_2])) 56 | grasp_axis.visual.vertex_colors = [0, 0., 1.] 
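        # same blue grasp-axis rendering as show_grasp_heatmap in dataset/generate_data.py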
57 | scene_list += [grasp_axis] 58 | trimesh.Scene(scene_list).show() 59 | 60 | 61 | def generate_grasp_data(meshes, frictions, sample_probs, max_normal_forces, weight, num_sample=100): 62 | pos_grasps, neg_grasps = find_contact_points_multi(meshes, frictions, (sample_probs * num_sample).astype('int')) 63 | pos_grasps, force_neg_grasps = filter_contact_points_by_force(pos_grasps, frictions, max_normal_forces, weight) 64 | grasp_map = generate_grasp_map(pos_grasps, trimesh.util.concatenate(meshes)) 65 | return grasp_map, pos_grasps, np.append(neg_grasps, force_neg_grasps, axis=0) 66 | 67 | 68 | 69 | def test(object_id): 70 | objects = [object_id] 71 | label_grasp(objects, config_id_t=1704583341.9222496) 72 | exit() 73 | a = pickle.load(open('./data/objects/4573/4573_1704605092.5541441_v1.pkl', 'rb')) 74 | print(a.language) 75 | print(a.pos_grasps.shape) 76 | print(a.neg_grasps.shape) 77 | print(a.grasp_map.shape) 78 | print(a.config.id) 79 | exit() 80 | 81 | 82 | 83 | def label_grasp(objects, config_id_t=None): 84 | dataset = Dataset('./data/objects') 85 | available_materials = json.load(open('./dataset/materials.json', 'r')) 86 | n_total_entries = 0 87 | n_valid_entries = 0 88 | num_sample = 1000 89 | current_objects = [] 90 | for i, object_id in enumerate(objects): 91 | meshes = dataset[object_id].load_meshes() 92 | name = dataset[object_id].name 93 | dataset[object_id].load() 94 | entries = dataset[object_id].data 95 | n_total_entries += len(entries) 96 | invalid_entries = [] 97 | for entry in entries.values(): 98 | # print(f'Object {object_id} config {entry.config.id} language: {entry.language}') 99 | try: 100 | config = entry.config 101 | # if config_id_t is not None and config.id != config_id_t: 102 | # continue 103 | # else: 104 | # print(f'Object {object_id} config {entry.config.id} start') 105 | shift_mass_center(meshes, config.densities) 106 | grasp_map, pos_grasps, neg_grasps = generate_grasp_data(meshes, config.frictions, config.sample_probs, 107 | config.max_normal_forces, sum(config.masses) * 9.81, 108 | num_sample) 109 | if pos_grasps.shape[0] < MIN_POS_GRASP: 110 | invalid_entries.append(str(entry.config.id)) 111 | continue 112 | entry.pos_grasps = pos_grasps 113 | entry.neg_grasps = neg_grasps 114 | entry.language = GPTsummary.summary(config, dataset[object_id], available_materials) 115 | except Exception as e: 116 | traceback.print_exc() 117 | print(f'Object {object_id} config {entry.config.id} error: {e}') 118 | with open('./dataset/error.txt', 'a') as f: 119 | f.write(f'Object {object_id} config {entry.config.id} error: {e}\n') 120 | invalid_entries.append(str(entry.config.id)) 121 | continue 122 | 123 | # print("finish object {}".format(object_id)) 124 | for config_id in invalid_entries: 125 | del entries[config_id] 126 | n_valid_entries += len(dataset[object_id].data) 127 | 128 | # if config_id_t is not None: 129 | dataset[object_id].save(version='_v1') 130 | dataset[object_id].unload() 131 | current = multiprocessing.current_process() 132 | print("{}/{} entries are valid {}/{} objects done by process {}".format(n_valid_entries, n_total_entries, i + 1, len(objects), current.name)) 133 | current_objects.append(object_id) 134 | if i % 10 == 0 or i == len(objects) - 1: 135 | updated_objects = pickle.load(open('./data/updated_objects.pkl', 'rb')) 136 | updated_objects += current_objects 137 | updated_objects = list(set(updated_objects)) 138 | pickle.dump(updated_objects, open('./data/updated_objects.pkl', 'wb')) 139 | 140 | if __name__ == '__main__': 141 | # test('8001') 
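    # Relabel only the objects not yet in updated_objects, splitting them
    # round-robin across NUM_PROCESS worker processes.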
142 | dataset0 = Dataset('./data/objects') 143 | objs = dataset0.get_object_ids() 144 | objs = [obj for obj in objs if obj not in updated_objects] 145 | NUM_PROCESS = 24 146 | 147 | tasks = [] 148 | random.shuffle(objs) 149 | for i in range(NUM_PROCESS): 150 | tasks.append(objs[i::NUM_PROCESS]) 151 | 152 | pool = Pool(NUM_PROCESS) 153 | pool.map(label_grasp, tasks) 154 | pool.close() 155 | 156 | # dataset.load() -------------------------------------------------------------------------------- /dataset/material_count.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/dataset/material_count.pkl -------------------------------------------------------------------------------- /dataset/materials.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Material": "aluminum", 4 | "Friction": 0.35, 5 | "Density": 2700, 6 | "Fragility": 4 7 | }, 8 | { 9 | "Material": "steel", 10 | "Friction": 0.45, 11 | "Density": 7850, 12 | "Fragility": 5 13 | }, 14 | { 15 | "Material": "copper", 16 | "Friction": 0.36, 17 | "Density": 8960, 18 | "Fragility": 4 19 | }, 20 | { 21 | "Material": "iron", 22 | "Friction": 0.47, 23 | "Density": 7874, 24 | "Fragility": 5 25 | }, 26 | { 27 | "Material": "brass", 28 | "Friction": 0.38, 29 | "Density": 8530, 30 | "Fragility": 4 31 | }, 32 | { 33 | "Material": "bronze", 34 | "Friction": 0.37, 35 | "Density": 8800, 36 | "Fragility": 4 37 | }, 38 | { 39 | "Material": "nickel", 40 | "Friction": 0.39, 41 | "Density": 8908, 42 | "Fragility": 4 43 | }, 44 | { 45 | "Material": "tin", 46 | "Friction": 0.32, 47 | "Density": 7310, 48 | "Fragility": 3 49 | }, 50 | { 51 | "Material": "lead", 52 | "Friction": 0.3, 53 | "Density": 11340, 54 | "Fragility": 3 55 | }, 56 | { 57 | "Material": "zinc", 58 | "Friction": 0.33, 59 | "Density": 7140, 60 | "Fragility": 3 61 | }, 62 | { 63 | "Material": "titanium", 64 | "Friction": 0.38, 65 | "Density": 4500, 66 | "Fragility": 4 67 | }, 68 | { 69 | "Material": "silicon", 70 | "Friction": 0.6, 71 | "Density": 2330, 72 | "Fragility": 2 73 | }, 74 | { 75 | "Material": "gold", 76 | "Friction": 0.47, 77 | "Density": 19300, 78 | "Fragility": 3 79 | }, 80 | { 81 | "Material": "silver", 82 | "Friction": 0.38, 83 | "Density": 10490, 84 | "Fragility": 3 85 | }, 86 | { 87 | "Material": "platinum", 88 | "Friction": 0.4, 89 | "Density": 21450, 90 | "Fragility": 3 91 | }, 92 | { 93 | "Material": "quartz", 94 | "Friction": 0.45, 95 | "Density": 2650, 96 | "Fragility": 2 97 | }, 98 | { 99 | "Material": "granite", 100 | "Friction": 0.55, 101 | "Density": 2750, 102 | "Fragility": 2 103 | }, 104 | { 105 | "Material": "marble", 106 | "Friction": 0.5, 107 | "Density": 2560, 108 | "Fragility": 2 109 | }, 110 | { 111 | "Material": "glass", 112 | "Friction": 0.5, 113 | "Density": 2500, 114 | "Fragility": 1 115 | }, 116 | { 117 | "Material": "fiberglass", 118 | "Friction": 0.6, 119 | "Density": 2020, 120 | "Fragility": 3 121 | }, 122 | { 123 | "Material": "polyester", 124 | "Friction": 0.4, 125 | "Density": 1380, 126 | "Fragility": 2 127 | }, 128 | { 129 | "Material": "nylon", 130 | "Friction": 0.35, 131 | "Density": 1150, 132 | "Fragility": 4 133 | }, 134 | { 135 | "Material": "linen", 136 | "Friction": 0.4, 137 | "Density": 1500, 138 | "Fragility": 4 139 | }, 140 | { 141 | "Material": "leather", 142 | "Friction": 0.5, 143 | "Density": 860, 144 | "Fragility": 4 145 | }, 146 | { 147 | "Material": "rubber", 
148 | "Friction": 0.8, 149 | "Density": 1522, 150 | "Fragility": 5 151 | }, 152 | { 153 | "Material": "neoprene", 154 | "Friction": 0.7, 155 | "Density": 1700, 156 | "Fragility": 4 157 | }, 158 | { 159 | "Material": "latex", 160 | "Friction": 0.8, 161 | "Density": 1015, 162 | "Fragility": 4 163 | }, 164 | { 165 | "Material": "plywood", 166 | "Friction": 0.5, 167 | "Density": 600, 168 | "Fragility": 4 169 | }, 170 | { 171 | "Material": "bamboo", 172 | "Friction": 0.38, 173 | "Density": 700, 174 | "Fragility": 4 175 | }, 176 | { 177 | "Material": "ceramics", 178 | "Friction": 0.7, 179 | "Density": 2300, 180 | "Fragility": 1 181 | }, 182 | { 183 | "Material": "porcelain", 184 | "Friction": 0.7, 185 | "Density": 2400, 186 | "Fragility": 1 187 | }, 188 | { 189 | "Material": "terracotta", 190 | "Friction": 0.55, 191 | "Density": 1800, 192 | "Fragility": 2 193 | }, 194 | { 195 | "Material": "clay", 196 | "Friction": 0.5, 197 | "Density": 1600, 198 | "Fragility": 2 199 | }, 200 | { 201 | "Material": "cardboard", 202 | "Friction": 0.3, 203 | "Density": 689, 204 | "Fragility": 2 205 | }, 206 | { 207 | "Material": "wood", 208 | "Friction": 0.4, 209 | "Density": 700, 210 | "Fragility": 3 211 | }, 212 | { 213 | "Material": "plastic", 214 | "Friction": 0.4, 215 | "Density": 1400, 216 | "Fragility": 3 217 | } 218 | ] -------------------------------------------------------------------------------- /dataset/object_count.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/dataset/object_count.pkl -------------------------------------------------------------------------------- /dataset/prompt1.py: -------------------------------------------------------------------------------- 1 | raise NotImplementedError("Please modify prompt1.py to update with prompt2.py, especially the description of fragility. Then remove this line.") 2 | 3 | ROLE = "You are a grasping analytical assistant, skilled in summerizing the feature of different objects and materials with a natural language. You should provide much information with minimal words. You focus on the important features of every part with their materials, rather than the specific values. You will be given a paragraph describing the object and its parts with their materials, densities, frictions, fragilities (1 is the most fragile, 5 is the least fragile), and human grasp probabilities hints. You should follow such rules: \n1. Names: Describe the object names and material names precisely.\n2. Densities: Point out the most dense part or the most light part. If the density difference is not obvious, you can ignore it.\n3. Frictions: Point out the part with the highest friction and the lowest friction. If the friction difference is not obvious, you can ignore it.\n4. Fragilities: Point out the most fragile. If the fragility difference is not obvious, you can ignore it.\n5. Grasp Probabilities: Point out the part with the highest grasp probability and the lowest grasp probability. If the grasp probability difference is not obvious, you can ignore it. If some instances don't have grasp probabilites, you can discard this part. \n" 4 | 5 | # TODO: Prompt Example with grasp probability 6 | EXAMPLE_DESCRIPTION = "I will give you some examples.\n" 7 | 8 | 9 | EXAMPLE1 = "Example 1: \n" + "Input: There is a table, it has several parts including tabletop, table_base. 
The material of each part is latex, terracotta, with friction 0.8, 0.55, density 1015, 1800, fragility level 4, 2. The grasping probabilities of each part it 0.78, 0.22." + "Output: The table consists of two parts: top and base. The top is made of latex, which has a high friction of 0.8. The base is made of terracotta, which has a lower friction of 0.55. The top has a density of 1015 kg/m³, while the base has a higher density of 1800 kg/m³. In terms of fragility, the top has a fragility level of 4, while the base has a lower fragility level of 2. Personally, I suggest to grasp the top." 10 | 11 | EXAMPLE2 = "Example 2: \n" + "Input: The trash_can has 2 parts: container_box and container_bottom. The materials used are glass and steel. The friction values for glass and steel are 0.6 and 0.5 respectively. The density values for glass and steel are 2500 and 7850 respectively. The fragility levels for glass and steel are 1 and 5 respectively. \n" + "Output: The trash_can has two parts: box and bottom. The box is made of glass with a higher fragility of 1. The container_bottom's material is steel, with lower fragility 5. The box's density is 2500 kg/m³ and the bottom has a larger density of 7850 kg/m³. The friction values for box and bottom are 0.6 and 0.5 respectively.\n" 12 | 13 | 14 | EXAMPLE3 = "Example 3: \n" + "Input: There is an faucet, it has several parts including switch, frame, spout. The material of each part is plastic, brass, fiberglass, with friction 0.4, 0.38, 0.6, density 1400, 8530, 2020, fragility level 3, 4, 3." + "Output: The faucet consists of multiple components: switch, frame, and spout. The switch is crafted from plastic, characterized by a high friction coefficient of 0.4. The frame, made of brass, presents a slightly lower friction at 0.38. In contrast, the spout is made of fiberglass, notable for its higher friction of 0.6. Regarding density, the switch has a moderate density of 1400 kg/m³, whereas the frame is significantly denser at 8530 kg/m³. The spout, on the other hand, has a density of 2020 kg/m³. As for fragility, both the plastic switch and the fiberglass spout have a fragility level of 3, indicating an ordinary risk of damage, while the brass frame is slightly more robust with a level of 4." 15 | 16 | EXAMPLE4 = "Example 4: \n" + "Input: There is an chair, it has several parts including chair_back, chair_seat, chair_base. The material of each part is quartz, latex, steel, with friction 0.45, 0.8, 0.45, density 2650, 1015, 7850, fragility level 2, 4, 5. The grasping probabilities of each part it 0.06, 0.44, 0.5.\n" + "Output: The chair consists of three parts: chair_back, chair_seat, and chair_base. The chair_back is made of quartz with a friction value of 0.45. The chair_seat is made of latex, which has a high friction value of 0.8. The chair_base is made of steel, with a friction value of 0.45. The chair_back has a fragility level of 2, indicating it is more fragile compared to the other parts. The chair_seat has a fragility level of 4, while the chair_base is the least fragile with a fragility level of 5. The chair_base has the highest density of 7850 kg/m³, making it the heaviest part. The chair_seat has a density of 1015 kg/m³, and the chair_back has the lowest density of 2650 kg/m³. The chair_seat has the highest grasp probability of 0.44, indicating a higher chance of successful grasping. 
The chair_back has the lowest grasp probability of 0.06, suggesting more difficulty in grasping it.\n" 17 | 18 | EXAMPLE5 = "Example 5: \n" + "Input: There is an bowl, it has several parts including container. The material of each part is silver, with friction 0.38, density 10490, fragility level 3.\n" + "Output: The bowl has one part, the container. It is made of silver with a friction value of 0.38. The density is 10490 kg/m³. The fragility level of is 3.\n" 19 | 20 | EXAMPLES = EXAMPLE1 + EXAMPLE2 + EXAMPLE3 + EXAMPLE4 + EXAMPLE5 21 | 22 | INSTRUCTION = "Please process the following paragraph. Output in one paragraph. \n" -------------------------------------------------------------------------------- /dataset/prompt2.py: -------------------------------------------------------------------------------- 1 | ROLE = "You are a grasping analytical assistant, skilled in summerizing the feature of different objects and materials with a natural language. You should provide much information with minimal words. You focus on the important features of every part with their materials, rather than the specific values. You will be given a paragraph describing the object and its parts with their materials, densities, frictions, fragilities, and human grasp probabilities hints. You should follow such rules: \n1. Names: Describe the object names and material names precisely.\n2. Densities: Point out the most dense part or the most light part. If the density difference is not obvious, you can ignore it.\n3. Frictions: Point out the part with the highest friction and the lowest friction. If the friction difference is not obvious, you can ignore it.\n4. Fragilities: Point out the most fragile. If the fragility difference is not obvious, you can ignore it.\n5. Grasp Probabilities: Point out the part with the highest grasp probability and the lowest grasp probability. If the grasp probability difference is not obvious, you can ignore it. If some instances don't have grasp probabilites, you can discard this part. \n" 2 | 3 | EXAMPLE_DESCRIPTION = "I will give you some examples.\n" 4 | 5 | EXAMPLE1 = "Example 1: \n" + "Input: There is a table, it has several parts including tabletop, table_base. The material of each part is latex, terracotta, with friction: 0.8, 0.55, density: 1015, 1800, fragility: tough, fragile. The grasping probabilities of each part it 0.78, 0.22.\n" + "Output: The table consists of two parts: table top and table base. The tabletop is made of latex, which has a bigger friction. The material of table_base is terracotta, which has a lower friction. The density of table base is about twice that of the table top. The table base is more fragile. I will advice to grasp the table top.\n" 6 | 7 | EXAMPLE2 = "Example 2: \n" + "Input: There is a trash_can, it has several parts including container_box and container_bottom. The material for each part is glass and steel, with friction: 0.6 and 0.5 , density: 2500 and 7850, fragility: very fragile and very tough.\n" + "Output: The trash_can has two parts: container_box and container_bottom. The container_box is made of glass, which is very fragile. The container_bottom's material is steel, with much bigger density. They share a similar frictions.\n" 8 | 9 | 10 | EXAMPLE3 = "Example 3: \n" + "Input: There is an faucet, it has several parts including switch, frame, spout. The material of each part is plastic, brass, fiberglass, with friction: 0.4, 0.38, 0.6, density: 1400, 8530, 2020, fragility: normal, tough, normal." 
+ "Output: The faucet has three parts: switch, frame, spout. The spout is made of fiberglass with the highest friction. The switch's material is plastic and the frame is made of brass with the biggest density. \n" 11 | 12 | EXAMPLE4 = "Example 4: \n" + "Input: There is an chair, it has several parts including chair_back, chair_seat, chair_base. The material of each part is quartz, latex, steel, with friction: 0.45, 0.8, 0.45, density: 2650, 1015, 7850, fragility: fragile, tough, very tough. The grasping probabilities of each part it 0.06, 0.44, 0.5.\n" + "Output: The chair consists of three parts: chair_back, chair_seat, and chair_base. The chair_back is made of quartz, indicating it is more fragile compared to the other parts . The chair_seat is made of latex, which has the highest friction. The chair_base has the highest density, making it the heaviest part. I prefer not to grasp the chair_back.\n" 13 | 14 | EXAMPLE5 = "Example 5: \n" + "Input: There is an bowl, it has several parts including container. The material of each part is silver, with friction: 0.38, density: 10490, fragility: normal.\n" + "Output: The bowl has one part, the container. It is made of silver with a low friction value. Its density is relatively high. And it has normal fragility.\n" 15 | 16 | EXAMPLES = EXAMPLE1 + EXAMPLE2 + EXAMPLE3 + EXAMPLE4 + EXAMPLE5 17 | 18 | INSTRUCTION = "Please process the following paragraph. Output in one paragraph. \n" -------------------------------------------------------------------------------- /dataset/stat.py: -------------------------------------------------------------------------------- 1 | from dataset.Dataset import Dataset 2 | from dataset.category import OBJECTS 3 | import pickle 4 | 5 | OBJECT_COUNTS = {key: 0 for key in OBJECTS} 6 | assert len(OBJECTS) == len(OBJECT_COUNTS) 7 | MATERIALS_COUNT = {} 8 | material_path = './dataset/material_count.pkl' 9 | object_path = './dataset/object_count.pkl' 10 | 11 | dataset = Dataset('./data/objects') 12 | objs = dataset.get_object_ids() 13 | for i, obj_id in enumerate(objs): 14 | name = dataset[obj_id].name 15 | dataset[obj_id].load("_v1") 16 | OBJECT_COUNTS[name] += len(dataset[obj_id].data) 17 | for config_id, entry in dataset[obj_id].data.items(): 18 | for material in entry.config.materials: 19 | MATERIALS_COUNT[material] = MATERIALS_COUNT.get(material, 0) + 1 20 | dataset[obj_id].unload() 21 | if i % 100 == 0: 22 | print(f'Processed {i}/{len(objs)} objects') 23 | print("OBJECT_COUNTS", OBJECT_COUNTS) 24 | print("MATERIALS_COUNT", MATERIALS_COUNT) 25 | pickle.dump(OBJECT_COUNTS, open(object_path, 'wb')) 26 | pickle.dump(MATERIALS_COUNT, open(material_path, 'wb')) 27 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | 5 | def shift_mass_center(meshes, densities): 6 | weighted_mass_center = np.zeros(3) 7 | masses = [] 8 | for mesh, density in zip(meshes, densities): 9 | center_mass = mesh.centroid if mesh.volume < 1e-6 else mesh.center_mass 10 | volume = mesh.convex_hull.volume if mesh.volume < 1e-6 else mesh.volume 11 | mass = volume * density 12 | weighted_mass_center += center_mass * mass 13 | masses.append(mass) 14 | # green_ball = trimesh.creation.uv_sphere(radius=0.01) 15 | # green_ball.visual.vertex_colors = [0.0, 1.0, 0.0] 16 | # 17 | # red_ball = trimesh.creation.uv_sphere(radius=0.01) 18 | # red_ball.visual.vertex_colors = [1.0, 0.0, 0.0] 19 | 
# red_ball.vertices += center_mass 19 | # trimesh.Scene([mesh, green_ball, red_ball]).show() 20 | mass_center = weighted_mass_center / sum(masses) 21 | 22 | for mesh in meshes: 23 | mesh.vertices -= mass_center 24 | 25 | # b = trimesh.creation.uv_sphere(radius=0.05) 26 | # b.visual.vertex_colors = [0.0, 1.0, 0.0] 27 | # trimesh.Scene([combined_mesh, b]).show() 28 | 29 | return masses 30 | 31 | 32 | def compute_part_ids(p, meshes): # p: (n, 3) 33 | distance = np.zeros((len(meshes), len(p))) # (m, n): m meshes x n query points 34 | for id in range(len(meshes)): 35 | _, dis, _ = trimesh.proximity.closest_point(meshes[id], p) 36 | distance[id] = np.array(dis) 37 | return np.argmin(distance, axis=0) # (n,) 38 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/demo/__init__.py -------------------------------------------------------------------------------- /demo/demo_extract_language_feature.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from encoder.LLaMA27BEncoder import LLaMA27BEncoder 4 | 5 | if __name__ == '__main__': 6 | # conda activate bi 7 | # torchrun --nproc_per_node 1 -m demo.demo_extract_language_feature 8 | 9 | objects_path = '/home/gdk/Repositories/DualArmManipulation/demo/demo_objects' 10 | 11 | objects = [ 12 | { 13 | "name": "banana", 14 | "simple_description": "This is a banana.", 15 | "detailed_description": "The object consists of two parts: the cap and the middle part. " 16 | "The object is very ripe, indicating it is very fragile to grasp on the middle part. " 17 | "The cap is easier to grasp. " 18 | "I would recommend grasping the cap." 19 | }, 20 | { 21 | "name": "monitor", 22 | "simple_description": "There is a monitor.", 23 | "detailed_description": "The display consists of two parts: the display screen and the base. " 24 | "The display screen is made of glass, indicating it is very fragile compared to the base which is made of iron. " 25 | "The base has a higher density compared to the display screen. The base also has a higher friction compared to the display screen. " 26 | "I would recommend grasping the base." 27 | 28 | }, 29 | { 30 | "name": "pill_bottle", 31 | "simple_description": "There is a bottle.", 32 | "detailed_description": "There is a pill bottle. It consists of two parts: the cap and the bottle body. " 33 | "The bottle body is made of plastic, indicating it has low friction and is harder to grasp. " 34 | "The cap is made of rubber, which means it has high friction. " 35 | "I would recommend grasping the cap on the top." 36 | }, 37 | { 38 | "name": "plastic_hammer", 39 | "simple_description": "There is an object.", 40 | "detailed_description": "The object consists of two parts: the handle and the base. " 41 | "The handle is made of plastic, indicating it is light, smooth, and low-friction to grasp. " 42 | "The base is made of metal, indicating the center of mass is near the base. " 43 | "I would recommend grasping the base." 44 | }, 45 | { 46 | "name": "hammer", 47 | "simple_description": "There is a hammer.", 48 | "detailed_description": "There is a hammer. It consists of two parts: the handle and the head. " 49 | "The handle is made of plastic, indicating it is light and has low density. " 50 | "The head is made of metal, indicating it is heavy. 
" 51 | "According to the grasp probabilities, it is recommended to grasp the head." 52 | } 53 | ] 54 | 55 | # language feature 56 | encoder = LLaMA27BEncoder() 57 | for obj in objects[4:]: 58 | object_name = obj['name'] 59 | description_simple = obj['simple_description'] 60 | description_detailed = obj['detailed_description'] 61 | 62 | encoded_text = encoder.encode(description_simple, layer_nums=[15, 20, 25]) 63 | language_feature_20 = encoded_text[1] 64 | pickle.dump(language_feature_20, open(f'{objects_path}/{object_name}/simple_language_feature_20.pkl', 'wb')) 65 | 66 | encoded_text = encoder.encode(description_detailed, layer_nums=[15, 20, 25]) 67 | language_feature_20 = encoded_text[1] 68 | pickle.dump(language_feature_20, open(f'{objects_path}/{object_name}/detailed_language_feature_20.pkl', 'wb')) 69 | -------------------------------------------------------------------------------- /demo/demo_objects: -------------------------------------------------------------------------------- 1 | /home/gdk/Data/bimanual/demo_objects -------------------------------------------------------------------------------- /demo/main.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | import trimesh 7 | from sklearn.cluster import KMeans 8 | 9 | from model.data_utils import ramdom_sample_pos 10 | from model.model import Net 11 | from scripts.plot_affordance_map import show_heatmap 12 | 13 | 14 | def cluster_points(points, ncluster=5): 15 | """ 16 | Function to cluster a set of points into 5 groups using K-Means algorithm. 17 | 18 | Parameters: 19 | data (numpy.ndarray): A 2D array of shape (20, 3), where each row represents a point in 3D space. 20 | 21 | Returns: 22 | list: A list containing the first index of each cluster. 
23 | """ 24 | 25 | # Perform K-Means clustering to divide the points into 5 clusters 26 | kmeans = KMeans(n_clusters=ncluster) 27 | kmeans.fit(points) 28 | 29 | # Get the cluster labels for each point 30 | labels = kmeans.labels_ 31 | 32 | # Prepare the result: list of first indices for each cluster 33 | clusters = {i: [] for i in range(5)} 34 | for idx, label in enumerate(labels): 35 | clusters[label].append(idx) 36 | print(clusters) 37 | 38 | 39 | # Extracting the first index of each cluster 40 | first_indices = [cluster[0] for cluster in clusters.values()] 41 | 42 | return first_indices 43 | 44 | 45 | 46 | 47 | if __name__ == '__main__': 48 | seed = 32 49 | np.random.seed(seed) 50 | torch.manual_seed(seed) 51 | 52 | objects_path = '/home/gdk/Repositories/DualArmManipulation/demo/demo_objects' 53 | object_name = 'hammer' 54 | # language_level = 'simple' 55 | language_level = 'detailed' 56 | 57 | mesh = trimesh.load(f'{objects_path}/{object_name}/{object_name}.obj') 58 | 59 | if False and os.path.exists(f'{objects_path}/{object_name}/{language_level}_grasps.pkl'): 60 | print('loading grasps') 61 | pos = pickle.load(open(f'{objects_path}/{object_name}/{language_level}_grasps.pkl', 'rb')) 62 | k1, k2, _ = pos.shape 63 | for i in range(k1): 64 | for j in range(k2): 65 | print(i, j) 66 | p1, p2 = pos[i, j, :3], pos[i, j, 3:] 67 | ball1 = trimesh.primitives.Sphere(radius=0.01, center=p1) 68 | ball1.visual.face_colors = [255, 0, 0, 255] 69 | ball2 = trimesh.primitives.Sphere(radius=0.01, center=p2) 70 | ball2.visual.face_colors = [0, 255, 0, 255] 71 | scene = [mesh, ball1, ball2] 72 | trimesh.Scene(scene).show() 73 | exit() 74 | 75 | 76 | language_feature = pickle.load(open(f'{objects_path}/{object_name}/{language_level}_language_feature_20.pkl', 'rb')) 77 | language_feature = torch.Tensor(language_feature).unsqueeze(0) 78 | vision_features = pickle.load(open(f'{objects_path}/{object_name}/vision_features.pkl', 'rb')) 79 | vision_local = torch.Tensor(vision_features['local_features']).permute(0, 2, 1) 80 | vision_global = torch.Tensor(vision_features['global_features']) 81 | points = torch.Tensor(vision_features['points']).unsqueeze(0) 82 | matrix = vision_features['transform'] 83 | matrix = np.linalg.inv(matrix) 84 | # mesh.apply_transform(matrix) 85 | 86 | print(language_feature.shape, points.shape, vision_local.shape, vision_global.shape) 87 | 88 | model = Net() 89 | checkpoint_path = '/home/gdk/Repositories/DualArmManipulation/checkpoints/model_1706605305.8925593_39.pth' 90 | checkpoint = torch.load(checkpoint_path) 91 | model.load_state_dict(checkpoint['model']) 92 | model.eval() 93 | 94 | output_global, output_local = model(language_feature, points, vision_global, vision_local) 95 | 96 | # show grasp map 97 | pcd_points = points.squeeze().cpu().numpy() 98 | pcd_points = pcd_points @ matrix[:3, :3].T + matrix[:3, 3] 99 | mesh.visual = trimesh.visual.ColorVisuals(vertex_colors=[255, 0, 0, 255]) 100 | show_heatmap(pcd_points, output_global.squeeze().detach().cpu().numpy(), mesh) 101 | 102 | k1, k2, f1, f2 = 5, 5, 4, 10 103 | index1 = torch.topk(output_global.squeeze(), k=k1 * f1).indices # (batch_size, kp1) 104 | index1 = index1.unsqueeze(0) 105 | pos = model.get_pos(output_local, index1, points, kp2=k2 * f2) # (batch_size, kp1, kp2, 6) 106 | # pos = ramdom_sample_pos(pos) # (batch_size, 5, 6) 107 | pos = pos.squeeze(0).cpu().numpy() # (kp1, kp2, 6) 108 | 109 | grasp1_pos = pos[:, 0, 3:] 110 | grasp1_index = cluster_points(grasp1_pos, ncluster=k1) 111 | print(grasp1_index) 112 | pos = 
pos[grasp1_index] # (k1, kp2, 6) 113 | new_pos = np.zeros((k1, k2, 6)) 114 | for i in range(k1): 115 | grasp2_pos = pos[i, :, :3] 116 | grasp2_index = cluster_points(grasp2_pos, ncluster=k2) 117 | new_pos[i] = pos[i][grasp2_index] 118 | pos = new_pos 119 | 120 | for i in range(k1): 121 | for j in range(k2): 122 | print(i, j) 123 | p1, p2 = pos[i, j, :3], pos[i, j, 3:] 124 | p1 = matrix[:3, :3] @ p1 + matrix[:3, 3] 125 | p2 = matrix[:3, :3] @ p2 + matrix[:3, 3] 126 | pos[i, j, :3], pos[i, j, 3:] = p1, p2 127 | ball1 = trimesh.primitives.Sphere(radius=0.0075, center=p1) 128 | ball1.visual.face_colors = [0, 255, 0, 255] 129 | ball2 = trimesh.primitives.Sphere(radius=0.0075, center=p2) 130 | ball2.visual.face_colors = [0, 255, 0, 255] 131 | mesh.visual.vertex_colors[:, 3] = 0.8 * 255 132 | scene = [mesh, ball1, ball2] 133 | trimesh.Scene(scene).show() 134 | 135 | pickle.dump(pos, open(f'{objects_path}/{object_name}/{language_level}_grasps.pkl', 'wb')) 136 | 137 | -------------------------------------------------------------------------------- /demo/view_graspnet_grasps.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import trimesh 4 | 5 | 6 | if __name__ == '__main__': 7 | objects_path = '/home/gdk/Repositories/DualArmManipulation/demo/demo_objects' 8 | object_name = 'hammer' 9 | 10 | mesh = trimesh.load(f'{objects_path}/{object_name}/{object_name}.obj') 11 | 12 | pos = pickle.load(open(f'{objects_path}/{object_name}/graspnet_grasps.pkl', 'rb')) 13 | for i in range(len(pos)): 14 | p1, p2 = pos[i, :3], pos[i, 3:] 15 | ball1 = trimesh.primitives.Sphere(radius=0.01, center=p1) 16 | ball1.visual.face_colors = [255, 0, 0, 255] 17 | ball2 = trimesh.primitives.Sphere(radius=0.01, center=p2) 18 | ball2.visual.face_colors = [0, 255, 0, 255] 19 | scene = [mesh, ball1, ball2] 20 | trimesh.Scene(scene).show() 21 | -------------------------------------------------------------------------------- /encoder/LLaMA27BEncoder.py: -------------------------------------------------------------------------------- 1 | import llama 2 | import os 3 | from pathlib import Path 4 | import pickle 5 | 6 | class LLaMA27BEncoder: 7 | def __init__(self, model_name="llama-2-7b", tokenizer_path="tokenizer.model"): 8 | """ 9 | Initializes the LLaMA-2-7B encoder with the specified model. 10 | """ 11 | llama_path = Path(os.path.dirname(llama.__file__)).parent 12 | model_path = os.path.join(llama_path, model_name) 13 | tokenizer_path = os.path.join(llama_path, tokenizer_path) 14 | self.generator = llama.Llama.build( 15 | ckpt_dir=model_path, 16 | tokenizer_path=tokenizer_path, 17 | max_seq_len=512, 18 | max_batch_size=6, 19 | ) 20 | 21 | def encode(self, text, layer_nums=[15, 20, 25]): 22 | """ 23 | Encodes the given text using the LLaMA-2-7B model. 24 | 25 | Args: 26 | text (str): The text to be encoded. 27 | layer_nums (list): Indices of the transformer layers whose outputs are returned. 28 | Returns: 29 | list: One encoded array per requested layer. 30 | """ 31 | 32 | encoded_text = self.generator.encode_layer_output(text, layer_nums=layer_nums) 33 | for i in range(len(encoded_text)): 34 | encoded_text[i] = encoded_text[i].cpu().detach().squeeze().numpy() 35 | return encoded_text 36 | 37 | if __name__ == '__main__': 38 | ''' 39 | usage: torchrun --nproc_per_node 1 -m encoder.LLaMA27BEncoder 40 | ''' 41 | # Example usage 42 | llama_encoder = LLaMA27BEncoder() 43 | 44 | # Encode some text 45 | sample_text = "This is an object." 
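# Note: encode() returns one numpy array per requested layer (here layers 15, 20,
# and 25); index 1 below is the layer-20 feature, matching the language_feature_20
# convention used in encoder/language_encode.py.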
46 | encoded_text = llama_encoder.encode(sample_text) 47 | 48 | for i in range(len(encoded_text)): 49 | print(encoded_text[i]) 50 | pickle.dump(encoded_text[1], open('./data/dataset/naive_language.pkl', 'wb')) 51 | 52 | -------------------------------------------------------------------------------- /encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/encoder/__init__.py -------------------------------------------------------------------------------- /encoder/language_encode.py: -------------------------------------------------------------------------------- 1 | from encoder.LLaMA27BEncoder import LLaMA27BEncoder 2 | import dataset.Dataset as Dataset 3 | import random 4 | from multiprocessing import Pool 5 | 6 | def feature_extraction(dataset): 7 | encoder = LLaMA27BEncoder() 8 | for obj in dataset.data_entries.values(): 9 | for entry in obj.data.values(): 10 | entry.language_feature = encoder.encode(entry.language) 11 | obj.save() 12 | 13 | def feature_extraction_objs(objs): 14 | encoder = LLaMA27BEncoder() 15 | dataset = Dataset.Dataset('./data/objects') 16 | for i, obj in enumerate(objs): 17 | dataset[obj].load("_v1") 18 | to_save = False 19 | for entry in dataset[obj].data.values(): 20 | if not hasattr(entry, 'language_feature_20'): 21 | to_save = True 22 | encoded_text = encoder.encode(entry.language, layer_nums=[15, 20, 25]) 23 | # entry.language_feature = encoder.encode(entry.language) 24 | entry.language_feature_15 = encoded_text[0] 25 | entry.language_feature_20 = encoded_text[1] 26 | entry.language_feature_25 = encoded_text[2] 27 | if to_save: 28 | dataset[obj].save("_v1") 29 | dataset[obj].unload() 30 | if i % 10 == 0: 31 | print(f'{obj}: {i}/{len(objs)}') 32 | 33 | def get_objs(): 34 | dataset = Dataset.Dataset('./data/objects') 35 | objs = dataset.get_object_ids() 36 | return objs 37 | 38 | def main(): 39 | objs = get_objs() 40 | random.shuffle(objs) 41 | # tasks = [] 42 | # NUM_PROCESS = 4 43 | # for i in range(NUM_PROCESS): 44 | # tasks.append(objs[i::NUM_PROCESS]) 45 | # pool = Pool(NUM_PROCESS) 46 | # pool.map(feature_extraction_objs, tasks) 47 | # pool.close() 48 | 49 | feature_extraction_objs(objs) 50 | 51 | if __name__ == '__main__': 52 | ''' 53 | usage: torchrun --nproc_per_node 1 -m encoder.language_encode 54 | ''' 55 | main() -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/affordance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | # from dataset.Dataset import Config 5 | from model.model import Net 6 | from model.modelGlobal import NetG 7 | import logging 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | import torch.multiprocessing as mp 11 | from model.utils import plot, save_loss 12 | from model.data_utils import get_dataloader, get_dataloader_special, ramdom_sample_pos 13 | import argparse 14 | import time 15 | import os 16 | import pickle 17 | import trimesh 18 | import json 19 | import numpy as np 20 | 21 | 
mp.set_start_method('spawn', force=True) 22 | VISION_LOCAL_THRESHOLD = 1000.0 23 | 24 | 25 | 26 | # KL Divergence 27 | # kld(map2||map1) -- map2 is gt 28 | def KLD(map1, map2, eps=1e-12): 29 | map1, map2 = map1/(map1.sum()+eps), map2/(map2.sum() + eps) 30 | kld = np.sum(map2*np.log( map2/(map1+eps) + eps)) 31 | return kld 32 | 33 | # histogram intersection 34 | def SIM(map1, map2, eps=1e-12): 35 | map1, map2 = map1/(map1.sum()+eps), map2/(map2.sum() + eps) 36 | intersection = np.minimum(map1, map2) 37 | return np.sum(intersection) 38 | 39 | def AUC_Judd(saliency_map, fixation_map, jitter=True): 40 | saliency_map = np.array(saliency_map, copy=False) 41 | fixation_map = np.array(fixation_map, copy=False) > 0.5 42 | # If there are no fixations to predict, return NaN 43 | if not np.any(fixation_map): 44 | return np.nan 45 | # Make the saliency_map the size of the fixation_map 46 | if saliency_map.shape != fixation_map.shape: 47 | saliency_map = resize(saliency_map, fixation_map.shape) # assumes a resize() such as skimage.transform.resize is in scope 48 | # Jitter the saliency map slightly to disrupt ties of the same saliency value 49 | if jitter: 50 | saliency_map += np.random.rand(*saliency_map.shape) * 1e-7 51 | # Normalize saliency map to have values between [0,1] 52 | saliency_map = (saliency_map - np.min(saliency_map)) / (np.max(saliency_map) - np.min(saliency_map) + 1e-12) 53 | 54 | S = saliency_map.ravel() 55 | F = fixation_map.ravel() 56 | S_fix = S[F] # Saliency map values at fixation locations 57 | n_fix = len(S_fix) 58 | n_pixels = len(S) 59 | # Calculate AUC 60 | thresholds = sorted(S_fix, reverse=True) 61 | tp = np.zeros(len(thresholds)+2) 62 | fp = np.zeros(len(thresholds)+2) 63 | tp[0] = 0; tp[-1] = 1 64 | fp[0] = 0; fp[-1] = 1 65 | for k, thresh in enumerate(thresholds): 66 | above_th = np.sum(S >= thresh) # Total number of saliency map values above threshold 67 | tp[k+1] = (k + 1) / float(n_fix) # Ratio of saliency map values at fixation locations above threshold 68 | fp[k+1] = (above_th - k - 1) / float(n_pixels - n_fix) # Ratio of other saliency map values above threshold 69 | return np.trapz(tp, fp) # y, x 70 | 71 | def main(params): 72 | kld_list = [] 73 | sim_list = [] 74 | auc_list = [] 75 | kld_list_vgn = [] 76 | sim_list_vgn = [] 77 | auc_list_vgn = [] 78 | vgn_path = 'data_vgn_map.pickle' 79 | with open(vgn_path, 'rb') as f: 80 | vgn_vol_map = pickle.load(f) 81 | print(len(vgn_vol_map)) 82 | 83 | for i in range(303): 84 | maps_path = f'./checkpoints/maps_{params["model_id"]}_{i}.pkl' 85 | if not os.path.exists(maps_path): 86 | print(f'File {maps_path} does not exist') 87 | continue 88 | with open(maps_path, 'rb') as f: 89 | maps = pickle.load(f) 90 | gt = maps['grasp_map'] # (batch_size, 2048), 100 91 | pred = maps['prediction'] 92 | pc = maps['point_cloud'] 93 | entry_paths = maps['entry_path'] 94 | pred[pred < 0.] = 0. 95 | pred[pred > 1.] = 1. 
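# Predictions are clamped into [0, 1] so they can be treated as an unnormalized
# density; KLD/SIM normalize by the sum, and the asserts below guard the range.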
96 | assert (pred >= 0.).all(), 'Negative value in prediction' 97 | assert (pred <= 1.).all(), 'Value greater than 1 in prediction' 98 | for k in range(gt.shape[0]): 99 | kld_list.append(KLD(pred[k], gt[k])) 100 | sim_list.append(SIM(pred[k], gt[k])) 101 | auc_list.append(AUC_Judd(pred[k], gt[k])) 102 | object_id = entry_paths[k].split('/')[-2] 103 | vgn_vol = vgn_vol_map[object_id] 104 | vgn_pred = map_convert(vgn_vol, pc[k]) 105 | kld_list_vgn.append(KLD(vgn_pred, gt[k])) 106 | sim_list_vgn.append(SIM(vgn_pred, gt[k])) 107 | auc_list_vgn.append(AUC_Judd(vgn_pred, gt[k])) 108 | print(f'KL Divergence: {np.mean(kld_list)}') 109 | print(f'Histogram Intersection: {np.mean(sim_list)}') 110 | print(f'AUC Judd: {np.mean(auc_list)}') 111 | 112 | print(f'KL Divergence VGN: {np.mean(kld_list_vgn)}') 113 | print(f'Histogram Intersection VGN: {np.mean(sim_list_vgn)}') 114 | print(f'AUC Judd VGN: {np.mean(auc_list_vgn)}') 115 | 116 | def map_convert(vgn_vol, point_cloud): 117 | if len(vgn_vol) == 0: 118 | return np.ones(point_cloud.shape[0]) / point_cloud.shape[0] 119 | # normalize point cloud 120 | point_cloud = point_cloud - point_cloud.min(0) 121 | point_cloud = point_cloud / point_cloud.max(0) 122 | # voxel grid: 40x40x40 123 | heatmap = np.zeros(point_cloud.shape[0]) 124 | point_cloud = point_cloud * 39 125 | for i in range(point_cloud.shape[0]): 126 | x, y, z = point_cloud[i].astype(int) 127 | heatmap[i] = vgn_vol[x, y, z] 128 | # heatmap[heatmap < 0.] = 0. 129 | # heatmap[heatmap > 1.] = 1. 130 | if heatmap.sum() == 0: 131 | return np.ones(point_cloud.shape[0]) / point_cloud.shape[0] 132 | heatmap = heatmap / heatmap.sum() 133 | return heatmap 134 | 135 | if __name__ == '__main__': 136 | # vgn_path = 'data_vgn_map.pickle' 137 | # with open(vgn_path, 'rb') as f: 138 | # vgn = pickle.load(f) 139 | # print(len(vgn)) 140 | # for k, v in vgn.items(): 141 | # print(len(v)) 142 | # exit() 143 | argparser = argparse.ArgumentParser() 144 | argparser.add_argument('--dataset_dir', type=str, default='./data/objects/') 145 | argparser.add_argument('--epoch_id', type=int, default=16) 146 | argparser.add_argument('--model_id', type=str, default='') 147 | argparser.add_argument('--global', default=False, action='store_true') 148 | argparser.add_argument('--gt', default=False, action='store_true') 149 | 150 | args = argparser.parse_args() 151 | params = vars(args) 152 | main(params) -------------------------------------------------------------------------------- /evaluation/demo_graspnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | import trimesh 7 | 8 | from graspnetAPI import GraspGroup 9 | 10 | from evaluation.graspnet_baseline.models.graspnet import GraspNet, pred_decode 11 | from evaluation.graspnet_baseline.utils.collision_detector import ModelFreeCollisionDetector 12 | 13 | 14 | def get_net(): 15 | # Init the model 16 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 17 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01, 0.02, 0.03, 0.04], is_training=False) 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | net.to(device) 20 | # Load checkpoint 21 | checkpoint = torch.load(cfgs.checkpoint_path) 22 | net.load_state_dict(checkpoint['model_state_dict']) 23 | start_epoch = checkpoint['epoch'] 24 | print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch)) 25 | # set model to eval mode 26 | 
net.eval() 27 | return net 28 | 29 | 30 | def get_grasps(net, end_points): 31 | # Forward pass 32 | with torch.no_grad(): 33 | end_points = net(end_points) 34 | grasp_preds = pred_decode(end_points) 35 | gg_array = grasp_preds[0].detach().cpu().numpy() 36 | gg = GraspGroup(gg_array) 37 | return gg 38 | 39 | 40 | def collision_detection(gg, cloud): 41 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 42 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 43 | gg = gg[~collision_mask] 44 | return gg 45 | 46 | def gg_to_positions(p, r, w, d): 47 | p1 = [] 48 | p2 = [] 49 | for i in range(len(p)): 50 | left_offset = np.array([d[i], -w[i]/2, 0]) 51 | right_offset = np.array([d[i], w[i]/2, 0]) 52 | p1.append(p[i] + np.dot(r[i], left_offset)) 53 | p2.append(p[i] + np.dot(r[i], right_offset)) 54 | grasp_positions = np.hstack((p1, p2)) # (n, 6) 55 | return grasp_positions 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 61 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 62 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 63 | parser.add_argument('--collision_thresh', type=float, default=0.01, 64 | help='Collision Threshold in collision detection [default: 0.01]') 65 | parser.add_argument('--voxel_size', type=float, default=0.01, 66 | help='Voxel Size to process point clouds before collision detection [default: 0.01]') 67 | cfgs = parser.parse_args() 68 | device = torch.device("cuda:0") 69 | 70 | objects_path = '/home/gdk/Repositories/DualArmManipulation/demo/demo_objects' 71 | 72 | net = get_net() 73 | 74 | object_name = 'banana' 75 | mesh = trimesh.load(f'{objects_path}/{object_name}/{object_name}.obj') 76 | points = trimesh.sample.sample_surface(mesh, 2048)[0] 77 | print(points.shape) 78 | 79 | end_points = { 80 | 'point_clouds': torch.Tensor(points).unsqueeze(0).to(device), 81 | } 82 | 83 | gg = get_grasps(net, end_points) 84 | if cfgs.collision_thresh > 0: 85 | gg = collision_detection(gg, np.array(points)) 86 | gg.nms() 87 | gg.sort_by_score() 88 | gg = gg[:10] 89 | 90 | positions = gg_to_positions(gg.translations, gg.rotation_matrices, gg.widths, gg.depths) 91 | 92 | print(positions) 93 | 94 | pickle.dump(positions, open(f'{objects_path}/{object_name}/graspnet_grasps.pkl', 'wb')) 95 | 96 | -------------------------------------------------------------------------------- /evaluation/generate_eval_table.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | import pickle 6 | 7 | from dataset.Dataset import Dataset 8 | 9 | 10 | if __name__ == '__main__': 11 | # modes = ['ours_1706605305.8925593', 'ours_1706613034.3918543', 12 | # 'ours_1706620680.2738318', 'ours_1706628441.2965753'] 13 | modes = ['ours_1706605305.8925593wol'] 14 | # modes = ['vgn'] 15 | save_names = [] 16 | grasp_datas = [] 17 | for mode in modes: 18 | if mode == 'analytical_pos': 19 | save_name = 'analytical_pos_not_tested' 20 | grasp_data = {} 21 | elif mode == 'analytical_neg': 22 | save_name = 'analytical_neg_not_tested' 23 | grasp_data = {} 24 | elif mode == 'graspnet': 25 | save_name = 'graspnet_not_tested' 26 | grasp_data = pickle.load(open('./evaluation/results/graspnet.pickle', 'rb')) 27 | elif mode == 'vgn': 
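# VGN baseline: these grasps are precomputed by evaluation/test_vgn.py,
# which writes data_vgn_grasp.pickle at the repository root.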
28 | save_name = 'vgn_not_tested' 29 | grasp_data = pickle.load(open('./data_vgn_grasp.pickle', 'rb')) 30 | elif 'ours' in mode: 31 | model_id = mode.split('_')[1] 32 | save_name = f'ours_{model_id}_not_tested' 33 | grasp_data = {} 34 | for i in range(303): 35 | if not os.path.exists(f'./checkpoints/maps_{model_id}_{i}.pkl'): 36 | continue 37 | test_dataset = pickle.load(open(f'./checkpoints/maps_{model_id}_{i}.pkl', 'rb')) 38 | if 'pos' not in test_dataset: 39 | print('no pos in map', i) 40 | continue 41 | paths = test_dataset['entry_path'] 42 | positions = test_dataset['pos'] 43 | assert len(paths) == len(positions) 44 | for path, position in zip(paths, positions): 45 | object_id = path.split('/')[-2] 46 | config_id = path.split('/')[-1].split('_')[1] 47 | if 'pkl' in config_id: 48 | continue 49 | grasp_data[(object_id, config_id)] = position #(n, 6) 50 | save_names.append(save_name) 51 | grasp_datas.append(grasp_data) 52 | 53 | 54 | source_df = pd.read_csv('./data/dataset/test_object_config_ids.csv', dtype = str).sort_values(['object_id'], ignore_index=True)#[:10] 55 | 56 | dfs = [] 57 | for _ in range(len(modes)): 58 | df = pd.DataFrame( 59 | index=range(len(source_df) * 10), 60 | columns=['object_id', 'config_id', 'f1', 'p1', 'n1', 'f2', 'p2', 'n2', 'obj_mass', 'top_n', 'success'] 61 | ) 62 | df['top_n'] = -1 63 | df['success'] = -1 64 | df = df.astype({'object_id': str, 'config_id': str, 'f1': float, 'p1': object, 'n1': object, 'f2': float, 'p2': object, 'n2': object, 'obj_mass': float, 'top_n': int, 'success': int}) 65 | dfs.append(df) 66 | 67 | # print(df) 68 | 69 | dataset = Dataset('/home/gdk/Repositories/DualArmManipulation/data/objects') 70 | 71 | prev_object_id = -1 72 | for i in tqdm(range(len(source_df))): 73 | object_id, config_id = source_df.loc[i, ['object_id', 'config_id']] 74 | 75 | if prev_object_id != object_id: 76 | if prev_object_id != -1: 77 | dataset[prev_object_id].unload() 78 | dataset[object_id].load('_v1') 79 | prev_object_id = object_id 80 | 81 | data_entry = dataset[object_id].data[config_id] 82 | frictions = data_entry.config.frictions 83 | 84 | mass_center = data_entry.mass_center 85 | mass_center_shift = np.zeros(14) 86 | mass_center_shift[1:4] = mass_center 87 | mass_center_shift[8:11] = mass_center 88 | 89 | for m, mode in enumerate(modes): 90 | if mode == 'analytical_pos': 91 | grasps = data_entry.pos_grasps[:10] + mass_center_shift 92 | elif mode == 'analytical_neg': 93 | grasps = data_entry.neg_grasps[:10] + mass_center_shift 94 | elif mode == 'graspnet': 95 | grasps = grasp_datas[m][object_id] 96 | if len(grasps) == 0: 97 | print(object_id, 'has no grasps') 98 | elif mode == 'vgn': 99 | grasps = grasp_datas[m][object_id] 100 | if len(grasps) == 0: 101 | print(object_id, 'has no grasps') 102 | 103 | if 'ours' not in mode: 104 | for j in range(min(10, len(grasps))): 105 | grasp = grasps[j] 106 | dfs[m].at[i * 10 + j, 'object_id'] = object_id 107 | dfs[m].at[i * 10 + j, 'config_id'] = config_id 108 | dfs[m].at[i * 10 + j, 'f1'] = frictions[int(grasp[0])] 109 | dfs[m].at[i * 10 + j, 'p1'] = grasp[1:4] 110 | dfs[m].at[i * 10 + j, 'n1'] = grasp[4:7] 111 | dfs[m].at[i * 10 + j, 'f2'] = frictions[int(grasp[7])] 112 | dfs[m].at[i * 10 + j, 'p2'] = grasp[8:11] 113 | dfs[m].at[i * 10 + j, 'n2'] = grasp[11:14] 114 | dfs[m].at[i * 10 + j, 'obj_mass'] = sum(data_entry.config.masses) 115 | elif 'ours' in mode: 116 | if (object_id, config_id) in grasp_datas[m]: 117 | grasp_positions = grasp_datas[m][(object_id, config_id)] # (n, 6) 118 | for j in range(min(10, 
len(grasp_positions))): 119 | dfs[m].at[i * 10 + j, 'object_id'] = object_id 120 | dfs[m].at[i * 10 + j, 'config_id'] = config_id 121 | dfs[m].at[i * 10 + j, 'p1'] = grasp_positions[j, 0:3] 122 | dfs[m].at[i * 10 + j, 'p2'] = grasp_positions[j, 3:6] 123 | dfs[m].at[i * 10 + j, 'obj_mass'] = sum(data_entry.config.masses) 124 | 125 | for j in range(10): 126 | dfs[m].at[i * 10 + j, 'top_n'] = j + 1 127 | 128 | for m, mode in enumerate(modes): 129 | dfs[m].to_pickle(f'./evaluation/results/{save_names[m]}.pkl') 130 | -------------------------------------------------------------------------------- /evaluation/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import os 5 | import sys 6 | import numpy as np 7 | import pybullet as p 8 | import pybullet_planning as pp 9 | import trimesh 10 | from tqdm import tqdm 11 | import pandas as pd 12 | import argparse 13 | 14 | from dataset.Dataset import Dataset 15 | from dataset.generate_data import show_grasp_heatmap 16 | from dataset.utils import compute_part_ids 17 | 18 | 19 | 20 | def test_grasp(grasps, mesh, obj_mass=None): 21 | pp.reset_simulation() 22 | p.setGravity(0, 0, -9.81) 23 | 24 | obj_mass = obj_mass if obj_mass else default_obj_mass 25 | 26 | if visualize: 27 | p.addUserDebugLine([0, 0, 0], [1, 0, 0], [1, 0, 0]) 28 | p.addUserDebugLine([0, 0, 0], [0, 1, 0], [0, 1, 0]) 29 | p.addUserDebugLine([0, 0, 0], [0, 0, 1], [0, 0, 1]) 30 | 31 | fingers = [] 32 | for g in grasps: 33 | pos = g['pos'] - g['normal'] * sphere_radius 34 | f = g['friction'] 35 | sphere_collision = p.createCollisionShape(p.GEOM_SPHERE, radius=sphere_radius) 36 | sphere_visual = p.createVisualShape(p.GEOM_SPHERE, radius=sphere_radius, rgbaColor=color) 37 | finger = p.createMultiBody(sphere_mass, sphere_collision, sphere_visual, basePosition=pos) 38 | p.changeDynamics(finger, -1, lateralFriction=f, spinningFriction=f, rollingFriction=f) 39 | fingers.append(finger) 40 | 41 | # Load object 42 | # mesh = trimesh.load(obj_urdf_path) 43 | vertices = mesh.vertices 44 | faces = mesh.faces 45 | indices = faces.reshape(-1) 46 | objId = p.createCollisionShape(p.GEOM_MESH, vertices=vertices, indices=indices) 47 | obj = p.createMultiBody(baseMass=obj_mass, baseCollisionShapeIndex=objId, basePosition=[0, 0, 0], baseOrientation=[0, 0, 0, 1]) 48 | p.changeDynamics(obj, -1, lateralFriction=objfriction, spinningFriction=objfriction, rollingFriction=objfriction) 49 | 50 | 51 | # p.loadURDF(obj_urdf_path, [0, -0.5, 0], globalScaling=10) 52 | 53 | # force_magnitude = obj_mass * 9.81 * 10 54 | force_magnitude = sphere_mass * 9.81 * 10 55 | 56 | for i in range(100): 57 | for finger, grasp in zip(fingers, grasps): 58 | # Gravity compensation force 59 | force = -sphere_mass * np.array([0, 0, -9.81]) 60 | # Apply force at the contact point in the direction of the contact normal 61 | finger_pos, finger_quat = p.getBasePositionAndOrientation(finger) 62 | p.resetBasePositionAndOrientation(finger, finger_pos, [0, 0, 0, 1]) 63 | 64 | force += np.array(force_magnitude * grasp['normal'] / np.linalg.norm(grasp['normal'])) 65 | force += np.array(sphere_mass * 9.81 * (grasp['pos'] - finger_pos)) 66 | p.applyExternalForce(finger, -1, force, grasp['pos'], p.WORLD_FRAME) 67 | # p.applyExternalTorque(finger, -1, [0, 0, 0]) 68 | 69 | pp.step_simulation() 70 | if visualize: 71 | time.sleep(0.02) 72 | 73 | distanceThreshold = 0.01 # 1 cm tolerance: fingers within this distance of the object still count as in contact 74 | closestPoints_1 = p.getClosestPoints(fingers[0], obj, 
distanceThreshold) 75 | closestPoints_2 = p.getClosestPoints(fingers[1], obj, distanceThreshold) 76 | success = True if closestPoints_1 and closestPoints_2 else False 77 | 78 | return success 79 | 80 | 81 | def print_stats(df): 82 | n_success = len(df[(df['success'] == 1) & (df['top_n'] == 1)]) 83 | n_total = len(df[(df['success'] != -1) & (df['top_n'] == 1)]) 84 | print(f'Top_1 Success rate: {n_success / n_total} ({n_success}/{n_total})') 85 | 86 | if top_n > 1: 87 | success_mask = df[df['top_n'] == 1]['success'] == 2 # all False 88 | success_mask = success_mask.to_numpy() 89 | for i in range(top_n): 90 | success_mask_i = df[df['top_n'] == i + 1]['success'] == 1 91 | success_mask_i = success_mask_i.to_numpy() 92 | success_mask = success_mask | success_mask_i 93 | n_success = success_mask.sum() 94 | print(f'Top_{top_n} Success rate: {n_success / n_total} ({n_success}/{n_total})') 95 | 96 | 97 | if __name__ == '__main__': 98 | # mode = 'graspnet' # 0.402 99 | # mode = 'analytical_pos' # 0.873 100 | # mode = 'analytical_neg' # 0.518 101 | mode = 'ours_1706493574.6243873' # 0.673 102 | # mode = 'random' 103 | 104 | argparser = argparse.ArgumentParser() 105 | argparser.add_argument('--mode', '-m', type=str, default=mode) 106 | argparser.add_argument('--top_n', type=int, default=1) 107 | argparser.add_argument('--n_test_config', '-n', type=int, default=1000) 108 | args = argparser.parse_args() 109 | 110 | mode = args.mode 111 | top_n = args.top_n 112 | n_test_config = args.n_test_config 113 | n_test = n_test_config * top_n 114 | useMass = True 115 | skip_tested = True 116 | source_path = f'./evaluation/results/{mode}_not_tested.pkl' 117 | save_path = f'./evaluation/results/{mode}.pkl' 118 | load_path = save_path if os.path.exists(save_path) else source_path 119 | 120 | visualize = False 121 | sphere_radius = 0.05 122 | sphere_mass = 1000 if useMass else 1 123 | default_obj_mass = 1 124 | color = [1, 0, 0, 1] 125 | downsample = True 126 | objfriction = 0.5 127 | 128 | pp.connect(use_gui=visualize) 129 | start_time = time.time() 130 | dataset = Dataset('/home/gdk/Repositories/DualArmManipulation/data/objects') 131 | df = pd.read_pickle(load_path) 132 | 133 | # print(df[:100]) 134 | 135 | prev_object_id = -1 136 | mesh = None 137 | to_test = df[df['top_n'] <= top_n][:n_test] 138 | not_tested = to_test[to_test['success'] == -1] 139 | effective_test_index = not_tested.index if skip_tested else to_test.index 140 | num_tested = 0 141 | for i in tqdm(effective_test_index): 142 | object_id = df.loc[i, 'object_id'] 143 | if object_id == 'nan': 144 | print(f'object {object_id} has no grasp at top', df.loc[i, 'top_n']) 145 | df.at[i, 'success'] = 0 146 | continue 147 | if prev_object_id != object_id: 148 | meshes = dataset[object_id].load_meshes() 149 | mesh = trimesh.util.concatenate(dataset[object_id].load_meshes()) 150 | if downsample: 151 | mesh = mesh.simplify_quadric_decimation(5000) 152 | prev_object_id = object_id 153 | 154 | f1, p1, n1, f2, p2, n2 = df.loc[i, ['f1', 'p1', 'n1', 'f2', 'p2', 'n2']] 155 | obj_mass = df.loc[i, 'obj_mass'] if useMass and 'obj_mass' in df.columns else None 156 | 157 | if 'ours' in mode and np.isnan(f1): 158 | dataset[object_id].load('_v1') 159 | data_entry = dataset[object_id].data[df.loc[i, 'config_id']] 160 | frictions = data_entry.config.frictions 161 | obj_mass = sum(data_entry.config.masses) 162 | part_ids = compute_part_ids(np.array([p1, p2]), meshes) 163 | norm = (p1- p2) / np.linalg.norm(p1 - p2) 164 | df.at[i, 'f1'] = f1 = frictions[part_ids[0]] 165 | df.at[i, 
'f2'] = f2 = frictions[part_ids[1]] 166 | df.at[i, 'n1'] = n1 = norm 167 | df.at[i, 'n2'] = n2 = -norm 168 | dataset[object_id].unload() 169 | 170 | if mode == 'random': 171 | samples, _ = trimesh.sample.sample_surface(mesh, 2) 172 | df.at[i, 'p1'] = p1 = samples[0] * 1.1 173 | df.at[i, 'p2'] = p2 = samples[1] * 1.1 174 | df.at[i, 'n1'] = n1 = (p1 - p2) / np.linalg.norm(p1 - p2) 175 | df.at[i, 'n2'] = n2 = -n1 176 | part_ids = compute_part_ids(np.array([p1, p2]), meshes) 177 | dataset[object_id].load('_v1') 178 | data_entry = dataset[object_id].data[df.loc[i, 'config_id']] 179 | frictions = data_entry.config.frictions 180 | df.at[i, 'f1'] = f1 = frictions[part_ids[0]] 181 | df.at[i, 'f2'] = f2 = frictions[part_ids[1]] 182 | dataset[object_id].unload() 183 | 184 | grasps = [ 185 | { 186 | 'pos': p1, 187 | 'normal': n1, 188 | 'friction': f1 ** 2 / objfriction, 189 | }, 190 | { 191 | 'pos': p2, 192 | 'normal': n2, 193 | 'friction': f2 ** 2 / objfriction, 194 | } 195 | ] 196 | 197 | try: 198 | success = test_grasp(grasps, mesh, obj_mass) 199 | df.at[i, 'success'] = 1 if success else 0 200 | except Exception as e: 201 | print(i, object_id) 202 | print(e) 203 | df.at[i, 'success'] = -2 204 | 205 | if visualize: 206 | print(i, object_id, df.loc[i, 'config_id'], df.loc[i, 'top_n'], df.at[i, 'success']) 207 | 208 | num_tested += 1 209 | 210 | if num_tested % 1000 == 0: 211 | tt = time.time() 212 | df.to_pickle(save_path) 213 | print(f'Saved in {time.time() - tt} seconds') 214 | print_stats(df) 215 | 216 | 217 | tt = time.time() 218 | df.to_pickle(save_path) 219 | print(f'Saved in {time.time() - tt} seconds') 220 | 221 | tt = time.time() 222 | print_stats(df) 223 | -------------------------------------------------------------------------------- /evaluation/results: -------------------------------------------------------------------------------- 1 | /home/gdk/Data/bimanual/evaluation_results -------------------------------------------------------------------------------- /evaluation/run.py: -------------------------------------------------------------------------------- 1 | # run "python -m evaluation.main_multi --top_n 5 -n 10000" multiple times 2 | 3 | import os 4 | 5 | if __name__ == '__main__': 6 | total = 10000 7 | every_time = 800 8 | start = 2000 9 | for n in range(start, total + every_time, every_time): 10 | print(f'n = {n}') 11 | try: 12 | os.system(f"python -m evaluation.main_multi --top_n 5 -n {n}") 13 | except KeyboardInterrupt: 14 | break 15 | 16 | -------------------------------------------------------------------------------- /evaluation/test_hardset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import time 4 | import logging 5 | import os 6 | import sys 7 | from multiprocessing import Pool 8 | 9 | import numpy as np 10 | import trimesh 11 | from tqdm import tqdm 12 | import pybullet as p 13 | import pandas as pd 14 | import argparse 15 | 16 | from dataset.Dataset import Dataset 17 | from dataset.generate_data import show_grasp_heatmap 18 | from dataset.utils import compute_part_ids 19 | from evaluation.hardset_list import entry_paths 20 | 21 | 22 | if __name__ == '__main__': 23 | entry_paths = pickle.load(open('./data/dataset/hard_entry_paths_q2.pkl', 'rb')) 24 | # entry_paths = pickle.load(open('./evaluation/easy_entry_paths_t.pkl', 'rb')) 25 | print('len(entry_paths)', len(entry_paths)) 26 | obj_configs = [] 27 | for path in entry_paths: 28 | object_id = path.split('/')[-2] 29 | config_id = 
path.split('/')[-1].split('_')[1] 30 | obj_configs.append((object_id, config_id)) 31 | 32 | modes = ['analytical_pos', 'ours_1706605305.8925593', 33 | 'ours_1706613034.3918543', 'ours_1706620680.2738318', 'ours_1706628441.2965753', 34 | 'graspnet', 'vgn'] 35 | 36 | top_n = 5 37 | 38 | dfs = [] 39 | save_paths = [] 40 | for mode in modes: 41 | save_path = f'/home/gdk/Repositories/DualArmManipulation/evaluation/results/{mode}_v2.pkl' 42 | save_paths.append(save_path) 43 | load_path = save_path 44 | print(f'Loading {load_path}') 45 | df = pd.read_pickle(load_path) 46 | dfs.append(df) 47 | 48 | test_indices = [] 49 | 50 | for object_id, config_id in obj_configs: 51 | df = dfs[0] 52 | ii = df[(df['object_id'] == object_id) & (df['config_id'] == config_id)].index.to_list() 53 | test_indices.extend(ii) 54 | 55 | 56 | for m, mode in enumerate(modes): 57 | print(mode) 58 | df = dfs[m].loc[test_indices] 59 | n_success = len(df[(df['success'] == 1) & (df['top_n'] == 1)]) 60 | n_total = len(df[(df['success'] > -1) & (df['top_n'] == 1)]) 61 | print(f'Top_1 Success rate: {n_success / n_total} ({n_success}/{n_total})') 62 | 63 | if top_n > 1: 64 | success_mask = df[df['top_n'] == 1]['success'] == 2 # all False 65 | success_mask = success_mask.to_numpy() 66 | for i in range(top_n): 67 | success_mask_i = df[df['top_n'] == i + 1]['success'] == 1 68 | success_mask_i = success_mask_i.to_numpy() 69 | success_mask = success_mask | success_mask_i 70 | n_success = success_mask.sum() 71 | print(f'Top_{top_n} Success rate: {n_success / n_total} ({n_success}/{n_total})') 72 | -------------------------------------------------------------------------------- /evaluation/test_to_train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | if __name__ == '__main__': 7 | seed = 100 8 | np.random.seed(seed) 9 | 10 | df = pd.read_pickle('/home/gdk/Repositories/DualArmManipulation/evaluation/results/analytical_pos_not_tested.pkl') 11 | print(len(df), len(df) // 10) 12 | max_index = len(df) // 10 13 | indices = np.random.choice(max_index, 10000, replace=False) * 10 14 | print(len(indices), min(indices), max(indices)) 15 | df_test = df.loc[indices] 16 | entry_paths = [f'./data/objects/{object_id}/{object_id}_{config_id}_v1.pkl' for object_id, config_id in df_test[['object_id', 'config_id']].values] 17 | print(len(entry_paths)) 18 | print(len(set(entry_paths))) 19 | print(entry_paths[:10]) 20 | 21 | # pickle.dump(entry_paths, open('/home/gdk/Repositories/DualArmManipulation/data/dataset/test_entry_paths_v2.pkl', 'wb')) 22 | # pickle.dump(indices, open('/home/gdk/Repositories/DualArmManipulation/data/dataset/test_indices.pkl', 'wb')) 23 | 24 | 25 | -------------------------------------------------------------------------------- /evaluation/test_vgn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from scipy.spatial.transform import Rotation as R 4 | import sys 5 | sys.path.append('/home/gdk/Repositories/DualArmManipulation') 6 | from vgn.detection import VGN 7 | from vgn.experiments import clutter_removal 8 | from vgn.perception import * 9 | from dataset.Dataset import Dataset 10 | import json 11 | import numpy as np 12 | import trimesh 13 | import pickle 14 | import pyrender 15 | from matplotlib import pyplot as plt 16 | from IPython import embed 17 | 18 | 19 | def main(args): 20 | 21 | if args.rviz or str(args.model) == "gpd": 22 | 
import rospy 23 | 24 | rospy.init_node("sim_grasp", anonymous=True) 25 | 26 | if str(args.model) == "gpd": 27 | from vgn.baselines import GPD 28 | 29 | grasp_planner = GPD() 30 | else: 31 | grasp_planner = VGN(args.model, rviz=args.rviz) 32 | 33 | count = 0 34 | count_hun = 0 35 | pickle_file_path = 'data_vgn_grasp.pickle' 36 | pickle_file_path_2 = 'data_vgn_map.pickle' 37 | dataset = Dataset('/home/gdk/Repositories/DualArmManipulation/data/objects') 38 | json_file_path = '/home/gdk/Repositories/DualArmManipulation/data/dataset/test_object_ids.json' 39 | # image = np.random.rand(480, 640).astype(np.float32) 40 | # with open(json_file_path, 'r') as file: 41 | # object_ids = json.load(file) 42 | object_ids = ['10519', '9000', '4931', '11405', '13214', '2056', '5174', '10850', '16208'] 43 | data_grasp_all = dict() 44 | data_vol_all = dict() 45 | fx, fy, cx, cy = 540.0, 540.0, 320.0, 240.0 46 | camera_in = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=0.01, zfar=2) 47 | intrinsic = CameraIntrinsic(640, 480, fx, fy, cx, cy) 48 | 49 | for object_id in object_ids: 50 | count = count + 1 51 | meshes = dataset[object_id].load_meshes() 52 | 53 | # TODO add transformation of the meshes here with mesh.apply_transform() 54 | 55 | grasps, scores, grasp_info, vol = clutter_removal.run_baseline( 56 | grasp_plan_fn=grasp_planner, 57 | meshes=meshes, 58 | intrinsic=intrinsic, 59 | camera_in =camera_in, 60 | ) 61 | if len(scores) != 0: 62 | # best_g_p, best_g_r = grasps[0].pose.translation, grasps[0].pose.rotation.as_matrix() 63 | # data_all[object_id]['grasps'] = data 64 | # data_all[object_id]['map'] = scores 65 | # data_all[object_id]['grasp_info'] = grasp_info 66 | data_grasp_all[object_id] = grasp_info 67 | data_vol_all[object_id] = vol 68 | print(object_id) 69 | print("grasp_info", np.array(grasp_info).shape) 70 | print("vol", np.array(vol).shape) 71 | else: 72 | data_grasp_all[object_id] = [] 73 | data_vol_all[object_id] = [] 74 | if count%100 == 0: 75 | count_hun = count_hun + 1 76 | print("finished processing ", count_hun*100) 77 | count = 0 78 | with open(pickle_file_path, 'wb') as file: 79 | pickle.dump(data_grasp_all, file) 80 | print("finish writing grasp file") 81 | with open(pickle_file_path_2, 'wb') as file_2: 82 | pickle.dump(data_vol_all, file_2) 83 | print("finish writing map file") 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--model", type=Path, default="./evaluation/vgn/data/models/vgn_conv.pth") 88 | parser.add_argument("--logdir", type=Path, default="data/experiments") 89 | parser.add_argument("--description", type=str, default="") 90 | parser.add_argument("--scene", type=str, choices=["pile", "packed"], default="pile") 91 | parser.add_argument("--object-set", type=str, default="blocks") 92 | parser.add_argument("--num-objects", type=int, default=5) 93 | parser.add_argument("--num-rounds", type=int, default=100) 94 | parser.add_argument("--seed", type=int, default=42) 95 | parser.add_argument("--sim-gui", action="store_true") 96 | parser.add_argument("--rviz", action="store_true") 97 | args = parser.parse_args() 98 | main(args) 99 | # pickle_file_path = 'object_count.pkl' 100 | # with open(pickle_file_path, 'rb') as file: 101 | # loaded_dict = pickle.load(file) 102 | # print(loaded_dict) 103 | # with open('data2.csv', 'w') as f: 104 | # [f.write('{0}\n'.format(key)) for key, value in loaded_dict.items()] 105 | # [f.write('{0}\n'.format(value)) for key, value in loaded_dict.items()] 106 | # import pandas as pd 107 | # 
df = pd.DataFrame(loaded_dict) 108 | # df.to_csv('my_file.csv', index=False, header=True) -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | from dataset.generate_data import min_distances 5 | from grasp.generate_grasp import find_contact_points_multi 6 | 7 | 8 | def vis_grasp_heatmap(mesh, contact_pairs): 9 | contact_locs = contact_pairs.reshape(-1, 3) 10 | dists = min_distances(mesh.vertices, contact_locs) 11 | # TODO: change to Gaussian 12 | mesh.visual.vertex_colors = trimesh.visual.color.interpolate(np.sqrt(dists), color_map='hot') 13 | mesh.visual.vertex_colors[:, 3] = 0.8 * 255 14 | scene_list = [mesh] 15 | 16 | for contact_point, another_contact_point in contact_pairs: 17 | # c1 = trimesh.creation.uv_sphere(radius=0.005) 18 | # c2 = trimesh.creation.uv_sphere(radius=0.005) 19 | # c1.vertices += contact_point 20 | # c2.vertices += another_contact_point 21 | grasp_axis = trimesh.creation.cylinder(0.005, sections=6, 22 | segment=np.vstack([contact_point, another_contact_point])) 23 | grasp_axis.visual.vertex_colors = [0, 0., 1.] 24 | # c1.visual.vertex_colors = [1., 0, 0] 25 | # c2.visual.vertex_colors = [1., 0, 0] 26 | # scene_list += [c1, c2, grasp_axis] 27 | scene_list += [grasp_axis] 28 | 29 | trimesh.Scene(scene_list).show() 30 | 31 | 32 | if __name__ == '__main__': 33 | meshes = [] 34 | for i in range(3): 35 | obj_path = f'./demo_objects/knife/meshes/new-{i}.obj' 36 | meshes.append(trimesh.load(obj_path)) 37 | 38 | frictions = [0.5, 0.2, 0.1] 39 | sample_nums = [100, 20, 10] 40 | 41 | # for i in range(1, 7): 42 | # obj_path = f'./demo_objects/table/meshes/original-{i}.obj' 43 | # meshes.append(trimesh.load(obj_path)) 44 | # 45 | # frictions = [0.5, 0.2, 0.1, 0.1, 0.1, 0.1] 46 | # sample_nums = [100, 20, 10, 10, 10, 10] 47 | 48 | combined_mesh = trimesh.util.concatenate(meshes) 49 | 50 | grasps = find_contact_points_multi(meshes, frictions, sample_nums) 51 | 52 | # vis_grasp(combined_mesh, np.array(grasps, dtype=object)[:, (1, 4)]) 53 | 54 | vis_grasp_heatmap(combined_mesh, grasps[:, np.r_[1:4, 8:11]].reshape(-1, 2, 3)) -------------------------------------------------------------------------------- /grasp/example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | from tqdm import tqdm 4 | 5 | from force_optimization import solve_force 6 | from generate_grasp import find_contact_points, vis_grasp 7 | 8 | if __name__ == '__main__': 9 | """A simple example with the Stanford bunny.""" 10 | obj_path = 'demo_objects/stanford_bunny.obj' 11 | n_sample_point = 100 12 | friction = 1 13 | contact_rad = 0.05 # used for collision check 14 | 15 | # Load mesh 16 | mesh = trimesh.load(obj_path) 17 | # mesh.vertices -= mesh.center_mass 18 | # mesh.vertices /= np.linalg.norm(mesh.vertices, axis=1).max() 19 | 20 | contacts = find_contact_points(mesh, n_sample_point, friction, contact_rad) 21 | 22 | print(len(contacts), 'contacts found.') 23 | 24 | # vis_grasp 25 | # vis_grasp(mesh, contacts) 26 | 27 | grasps = [] 28 | for contact_points, contact_normals in tqdm(contacts): 29 | force = solve_force(contact_points, contact_normals, [friction] * len(contact_points), 1.0, soft_contact=True) # per-contact friction list and scalar weight, matching solve_force's signature 30 | if force is not None: 31 | grasps.append((contact_points, contact_normals, force)) 32 | 33 | print(grasps) 34 | 35 | 36 | 37 | 38 | 39 | 40 | 
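# Pipeline recap for the example above: sample candidate antipodal contact pairs
# on the mesh, then keep only those pairs for which solve_force() finds contact
# forces inside each friction cone that balance the object's weight.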
-------------------------------------------------------------------------------- /grasp/force_optimization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cvxpy as cp 3 | 4 | from grasp.generate_grasp import parse_grasp 5 | 6 | 7 | def normalize(x): 8 | mag = np.linalg.norm(x) 9 | if mag == 0: 10 | mag = mag + 1e-10 11 | return x / mag 12 | 13 | 14 | def hat(v): 15 | if v.shape == (3, 1) or v.shape == (3,): 16 | return np.array([ 17 | [0, -v[2], v[1]], 18 | [v[2], 0, -v[0]], 19 | [-v[1], v[0], 0] 20 | ]) 21 | else: 22 | raise ValueError 23 | 24 | 25 | def generate_contact_frame(pos, normal): 26 | """Generate contact frame, whose z-axis aligns with the normal direction (inward to the object) 27 | """ 28 | up = normalize(np.random.rand(3)) 29 | z = normalize(normal) 30 | x = normalize(np.cross(up, z)) 31 | y = normalize(np.cross(z, x)) 32 | 33 | result = np.eye(4) 34 | result[0:3, 0] = x 35 | result[0:3, 1] = y 36 | result[0:3, 2] = z 37 | result[0:3, 3] = pos 38 | return result 39 | 40 | 41 | def adj_T(frame): 42 | """Compute the adjoint matrix for the contact frame 43 | """ 44 | assert frame.shape[0] == frame.shape[1] == 4, 'Frame needs to be 4x4' 45 | 46 | R = frame[0:3, 0:3] 47 | p = frame[0:3, 3] 48 | result = np.zeros((6, 6)) 49 | result[0:3, 0:3] = R 50 | result[3:6, 0:3] = hat(p) @ R 51 | result[3:6, 3:6] = R 52 | return result 53 | 54 | 55 | def compute_grasp_map(contact_pos, contact_normal, soft_contact=False): 56 | """ Computes the grasp map for all contact points. 57 | Check chapter 5 of http://www.cse.lehigh.edu/~trink/Courses/RoboticsII/reading/murray-li-sastry-94-complete.pdf for details. 58 | Args: 59 | contact_pos: location of contact in the object frame 60 | contact_normal: surface normals at the contact location, point inward !!!, N x 3, in the object frame 61 | soft_contact: whether use soft contact model. Defaults to False. 62 | Returns: 63 | G: grasp map for the contacts 64 | """ 65 | n_point = len(contact_pos) 66 | 67 | # Compute the contact basis B 68 | if soft_contact: 69 | B = np.zeros((6, 4)) 70 | B[0:3, 0:3] = np.eye(3) 71 | B[5, 3] = 1 72 | else: # use point contact w/ friction 73 | B = np.zeros((6, 3)) 74 | B[0:3, 0:3] = np.eye(3) 75 | 76 | # Compute the contact frames, adjoint matrix, and grasp map 77 | contact_frames = [] 78 | grasp_maps = [] 79 | for pos, normal in zip(contact_pos, contact_normal): 80 | contact_frame = generate_contact_frame(pos, normal) 81 | contact_frames.append(contact_frame) 82 | 83 | adj_matrix = adj_T(contact_frame) 84 | grasp_map = adj_matrix @ B 85 | grasp_maps.append(grasp_map) 86 | 87 | G = np.hstack(grasp_maps) 88 | assert G.shape == (6, n_point * B.shape[1]), 'Grasp map shape does not match' 89 | 90 | return G 91 | 92 | 93 | def solve_force(contact_positions, contact_normals, frictions, weight, soft_contact=False): 94 | w_ext = np.array([0.0, 0.0, weight, 0.0, 0.0, 0.0]) 95 | 96 | num_contact = len(contact_positions) 97 | f = cp.Variable(4 * num_contact) if soft_contact else cp.Variable(3 * num_contact) 98 | s = cp.Variable(1) 99 | 100 | G = compute_grasp_map(contact_pos=contact_positions, contact_normal=contact_normals, soft_contact=soft_contact) 101 | 102 | constraints = [ 103 | G @ f == - w_ext, 104 | s >= -1 105 | ] 106 | 107 | # cp.SOC(t, x) creates the SOC constraint ||x||_2 <= t. 
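# Friction cone as a second-order cone constraint: for contact i with friction
# coefficient mu_i, the tangential components must satisfy ||f_t||_2 <= mu_i * (f_n + s),
# where s is the slack variable minimized in the objective below.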
108 |     # Each contact contributes 4 wrench components under the soft-contact model (3 forces + 1 torsion), else 3.
109 |     dim = 4 if soft_contact else 3
110 |     for i in range(num_contact):
111 |         constraints += [
112 |             cp.SOC(frictions[i] * (f[dim * i + 2] + s),
113 |                    f[dim * i: dim * i + 2])
114 |         ]
115 | 
116 |     prob = cp.Problem(cp.Minimize(s), constraints)
117 |     prob.solve()
118 | 
119 |     if f.value is None:
120 |         # print("Cannot find a feasible solution")
121 |         return None
122 | 
123 |     # print("The optimal value for s is", prob.value)
124 |     # print("The optimal value for f is", f.value.reshape(num_contact, -1))
125 |     return f.value.reshape(num_contact, -1)
126 | 
127 | 
128 | def filter_contact_points_by_force(grasps, part_frictions, part_max_normal_forces, weight):
129 |     neg_grasps, pos_grasps = [], []
130 |     for grasp in grasps:
131 |         part_idx_1, contact_point_1, contact_normal_1, part_idx_2, contact_point_2, contact_normal_2 = parse_grasp(grasp)
132 |         contact_positions = [contact_point_1, contact_point_2]
133 |         contact_normals = [contact_normal_1, contact_normal_2]
134 |         frictions = [part_frictions[part_idx_1], part_frictions[part_idx_2]]
135 |         try:
136 |             forces = solve_force(contact_positions, contact_normals, frictions, weight, soft_contact=True)
137 |         except Exception:
138 |             forces = None
139 |         if forces is None \
140 |                 or abs(forces[0, 2]) > part_max_normal_forces[part_idx_1] \
141 |                 or abs(forces[1, 2]) > part_max_normal_forces[part_idx_2]:
142 |             neg_grasps.append(grasp)
143 |         else:
144 |             pos_grasps.append(grasp)
145 |     return np.array(pos_grasps), np.array(neg_grasps)
146 | 
147 | 
148 | if __name__ == '__main__':
149 |     contact_pos = np.array([[0, 1, 0], [0, -1, 0]])
150 |     contact_normal = np.array([[0, -1, 0], [0, 1, 0]])
151 | 
152 |     print(solve_force(contact_pos, contact_normal, [0.5, 0.5], 1.0, soft_contact=True))
153 | 
154 |     grasps = np.array([[0, 0, 1, 0, 0, -1, 0, 1, 0, -1, 0, 0, 1, 0],
155 |                        [0, 0, 1, 0, 0, -1, 0, 1, 0, -1, 0, 0, -1, 0]])
156 |     part_frictions = [0.3, 0.5]
157 |     max_forces = [0.5, 0.5]
158 |     pos_grasps, neg_grasps = filter_contact_points_by_force(grasps, part_frictions, max_forces, 1.0)
159 |     print(pos_grasps, neg_grasps)
160 | 
--------------------------------------------------------------------------------
/grasp/generate_grasp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | 
4 | import logging
5 | 
6 | logger = logging.getLogger("trimesh")
7 | logger.setLevel(logging.ERROR)
8 | 
9 | """
10 | Install
11 | pip install numpy trimesh rtree
12 | """
13 | 
14 | 
15 | def sample_from_cone(tx, ty, tz, friction_coef=0.3):
16 |     """ Samples a direction from within the friction cone using uniform sampling.
17 | 
18 |     Parameters
19 |     ----------
20 |     tx : 3x1 normalized :obj:`numpy.ndarray`
21 |         tangent x vector
22 |     ty : 3x1 normalized :obj:`numpy.ndarray`
23 |         tangent y vector
24 |     tz : 3x1 normalized :obj:`numpy.ndarray`
25 |         surface normal
26 |     friction_coef : float
27 |         friction coefficient that bounds the opening of the cone
28 | 
29 |     Returns
30 |     -------
31 |     v : 3x1 :obj:`numpy.ndarray`
32 |         a single sampled direction in the friction cone
33 |     """
34 |     theta = 2 * np.pi * np.random.rand()
35 |     r = friction_coef * np.random.rand()
36 |     v = tz + r * np.cos(theta) * tx + r * np.sin(theta) * ty
37 |     return normalize(v)
38 | 
39 | 
40 | def normalize(x):
41 |     norm = np.linalg.norm(x)
42 |     norm = 1e-5 if norm == 0 else norm
43 |     return x / norm
44 | 
45 | 
46 | def find_contact(mesh, vertices, directions, offset=20.0):
47 |     start_points = vertices + offset * directions
48 |     hit_indices = mesh.ray.intersects_first(ray_origins=start_points, ray_directions=-directions)
49 |     return hit_indices
50 | 
51 | 
52 | def vis_grasp(mesh, contact_points):
53 |     mesh.visual.vertex_colors = [0.5, 0.5, 0.5, 0.5]
54 |     scene_list = [mesh]
55 |     # scene_list = []
56 | 
57 |     for contact_point, another_contact_point in contact_points:
58 |         c1 = trimesh.creation.uv_sphere(radius=0.015)
59 |         c2 = trimesh.creation.uv_sphere(radius=0.015)
60 |         c1.vertices += contact_point
61 |         c2.vertices += another_contact_point
62 |         grasp_axis = trimesh.creation.cylinder(0.005, sections=6,
63 |                                                segment=np.vstack([contact_point, another_contact_point]))
64 |         grasp_axis.visual.vertex_colors = [1.0, 0.0, 0.0]
65 |         c1.visual.vertex_colors = [0.0, 1.0, 0.0]
66 |         c2.visual.vertex_colors = [0.0, 1.0, 0.0]
67 |         scene_list += [c1, c2, grasp_axis]
68 | 
69 |     trimesh.Scene(scene_list).show()
70 | 
71 | 
72 | def generate_contact_rays(surface_vertices, surface_normals, friction):
73 |     rays = np.zeros_like(surface_vertices)
74 |     for i, normal in enumerate(surface_normals):
75 |         tz = normalize(normal)
76 |         up = normalize(np.random.rand(3))
77 |         tx = normalize(np.cross(tz, up))
78 |         ty = normalize(np.cross(tz, tx))
79 |         ray = sample_from_cone(tx, ty, tz, friction_coef=friction)
80 |         rays[i] = ray
81 |     return rays
82 | 
83 | 
84 | def sample_contact_points(mesh, n_sample_point):
85 |     surface_vertices, face_idx = trimesh.sample.sample_surface_even(mesh, count=n_sample_point)
86 |     surface_vertices = np.asarray(surface_vertices)
87 |     surface_normals = - mesh.face_normals[face_idx]  # flip normals to point inward
88 |     return surface_vertices, surface_normals
89 | 
90 | 
91 | def check_collision_points(contact_rad, points_volume, contact_point, another_contact_point, contact_normal,
92 |                            another_contact_normal):
93 |     center_1 = contact_point - contact_rad * contact_normal
94 |     center_2 = another_contact_point - contact_rad * another_contact_normal
95 |     dist_1 = np.linalg.norm(points_volume - center_1, axis=1)
96 |     dist_2 = np.linalg.norm(points_volume - center_2, axis=1)
97 |     c1_not_in_collision = all(dist_1 > contact_rad)
98 |     c2_not_in_collision = all(dist_2 > contact_rad)
99 |     is_collision_free = c1_not_in_collision and c2_not_in_collision
100 |     return is_collision_free
101 | 
102 | 
103 | def find_contact_points(mesh, n_sample_point, friction, contact_rad):
104 |     surface_vertices, surface_normals = sample_contact_points(mesh, n_sample_point)
105 |     rays = generate_contact_rays(surface_vertices, surface_normals, friction)
106 |     hit_indices = find_contact(mesh, surface_vertices, rays, offset=np.linalg.norm(mesh.extents) * 4.0)
107 | 
108 | 
109 |     alpha = np.arctan(friction)
110 |     grasps = []
111 | 
112 |     # Sample surface and volume point cloud for collision check
113 |     points_volume = trimesh.sample.volume_mesh(mesh, count=4096)
114 |     points_surface, _ = trimesh.sample.sample_surface_even(mesh, count=2048)
115 |     points_volume = np.vstack([points_volume, points_surface])
116 | 
117 |     for contact_point, contact_normal, ray, hit_index in zip(surface_vertices, surface_normals, rays, hit_indices):
118 |         if hit_index == -1:
119 |             continue
120 |         another_contact_point, another_contact_normal = mesh.triangles_center[hit_index], -mesh.face_normals[hit_index]
121 | 
122 |         # Check whether force closure
123 |         is_force_closure = np.arccos(-ray.dot(another_contact_normal)) <= alpha
124 | 
125 |         # Check whether collision free
126 |         if is_force_closure:
127 |             is_collision_free = check_collision_points(contact_rad, points_volume, contact_point, another_contact_point,
128 |                                                        contact_normal, another_contact_normal)
129 | 
130 |             if is_collision_free:
131 |                 # g = np.hstack([contact_point, another_contact_point, contact_normal, another_contact_normal])
132 |                 grasps.append((np.array([contact_point, another_contact_point]),
133 |                                np.array([contact_normal, another_contact_normal])))
134 | 
135 |     return grasps
136 | 
137 | 
138 | def find_contact_points_multi(meshes, frictions, sample_nums, contact_rad=0.05):
139 |     """
140 |     Return (pos_grasps, neg_grasps), two (n, 14) arrays; each row is part_idx_1, contact_point_1, contact_normal_1, part_idx_2, contact_point_2, contact_normal_2
141 |     """
142 |     part_indices, surface_vertices, surface_normals, rays = [], [], [], []
143 |     for i, (mesh, friction, n_sample) in enumerate(zip(meshes, frictions, sample_nums)):
144 |         v, n = sample_contact_points(mesh, n_sample)
145 |         r = generate_contact_rays(v, n, friction)
146 |         surface_vertices.extend(v)
147 |         surface_normals.extend(n)
148 |         rays.extend(r)
149 |         part_indices.extend([i] * n_sample)
150 | 
151 |     combined_mesh = trimesh.util.concatenate(meshes)
152 |     hit_indices = find_contact(combined_mesh, surface_vertices, np.array(rays), offset=np.linalg.norm(combined_mesh.extents) * 4.0)
153 |     part_i_thres = np.cumsum([len(mesh.face_normals) for mesh in meshes])
154 |     hit_part_indices = np.sum(np.repeat([hit_indices], len(meshes), axis=0).T >= part_i_thres, axis=1)
155 | 
156 |     # Sample surface and volume point cloud for collision check
157 |     # points_volume = trimesh.sample.volume_mesh(combined_mesh, count=4096)
158 |     # points_surface, _ = trimesh.sample.sample_surface_even(combined_mesh, count=2048)
159 |     # points_volume = np.vstack([points_volume, points_surface])
160 | 
161 |     pos_grasps = []
162 |     neg_grasps = []
163 |     for part_idx_1, contact_point_1, contact_normal_1, ray, hit_idx, part_idx_2 in zip(
164 |             part_indices, surface_vertices, surface_normals, rays, hit_indices, hit_part_indices):
165 |         if hit_idx == -1:
166 |             continue
167 |         contact_point_2, contact_normal_2 = combined_mesh.triangles_center[hit_idx], -combined_mesh.face_normals[hit_idx]
168 | 
169 |         grasp = np.concatenate(([part_idx_1], contact_point_1, contact_normal_1, [part_idx_2], contact_point_2, contact_normal_2))
170 | 
171 |         if np.arccos(-ray.dot(contact_normal_2)) <= np.arctan(frictions[part_idx_2]):  # and \
172 |                 # check_collision_points(contact_rad, points_volume, contact_point_1, contact_point_2,
173 |                 #                        contact_normal_1, contact_normal_2):
174 |             pos_grasps.append(grasp)
175 |         else:
176 |             neg_grasps.append(grasp)
177 |     return np.array(pos_grasps), np.array(neg_grasps)
178 | 
179 | 
180 | def parse_grasp(grasp):
181 |     """
182 |     :param grasp: (14,) numpy array
183
| :return: part_idx_1, contact_point_1, contact_normal_1, part_idx_2, contact_point_2, contact_normal_2 184 | """ 185 | part_idx_1, contact_point_1, contact_normal_1, part_idx_2, contact_point_2, contact_normal_2 = \ 186 | int(grasp[0]), grasp[1:4], grasp[4:7], int(grasp[7]), grasp[8:11], grasp[11:14] 187 | return part_idx_1, contact_point_1, contact_normal_1, part_idx_2, contact_point_2, contact_normal_2 188 | 189 | 190 | if __name__ == "__main__": 191 | obj_path = 'stanford_bunny.obj' 192 | n_contact_point = 100 193 | friction = 0.2 194 | contact_rad = 0.05 # used for collision check 195 | 196 | # Load mesh 197 | mesh = trimesh.load(obj_path) 198 | 199 | # Compute grasps 200 | grasps = find_contact_points(mesh, n_contact_point, friction, contact_rad) 201 | 202 | # Vis 203 | vis_grasp(mesh, grasps) 204 | -------------------------------------------------------------------------------- /grasp/transform.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Transformation utilities 3 | ''' 4 | 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | 9 | def get_transform_matrix(state, com=np.zeros(3)): 10 | ''' 11 | Get transformation matrix of the given state and center of mass 12 | ''' 13 | if len(state) == 3: # translation only 14 | transform = np.eye(4) 15 | transform[:3, 3] = state 16 | return transform 17 | elif len(state) == 6: # translation + rotation 18 | translation, rotation = state[:3], state[3:] 19 | rotation = Rotation.from_rotvec(rotation).as_matrix() 20 | trans0_mat = np.eye(4) 21 | trans0_mat[:3, 3] = -com 22 | rot_mat = np.eye(4) 23 | rot_mat[:3, :3] = rotation 24 | trans1_mat = np.eye(4) 25 | trans1_mat[:3, 3] = translation + com 26 | return trans1_mat.dot(rot_mat).dot(trans0_mat) 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def get_state_from_matrix(matrix, com=np.zeros(3), full_dof=False): 32 | ''' 33 | Get state from the given transformation matrix and center of mass 34 | ''' 35 | translation = matrix[:3, 3] 36 | if (matrix[:3, :3] == np.eye(3)).all(): 37 | if not full_dof: 38 | return translation 39 | else: 40 | return np.concatenate([translation, np.zeros(3)]) 41 | trans0_mat = np.eye(4) 42 | trans0_mat[:3, 3] = -com 43 | rot_mat = np.eye(4) 44 | rot_mat[:3, :3] = matrix[:3, :3] 45 | trans1_mat = matrix.dot(np.linalg.inv(rot_mat.dot(trans0_mat))) 46 | translation = trans1_mat[:3, 3] - com 47 | rotation = Rotation.from_matrix(rot_mat[:3, :3]).as_rotvec() 48 | state = np.concatenate([translation, rotation]) 49 | return state 50 | 51 | 52 | def transform_pts_by_matrix(pts, matrix): 53 | ''' 54 | Transform an array of xyz pts (n, 3) by a 4x4 matrix 55 | ''' 56 | pts = np.array(pts) 57 | if len(pts.shape) == 1: 58 | if len(pts) == 3: 59 | v = np.append(pts, 1.0) 60 | elif len(pts) == 4: 61 | v = pts 62 | else: 63 | raise NotImplementedError 64 | v = matrix @ v 65 | return v[0:3] 66 | elif len(pts.shape) == 2: 67 | # transpose first 68 | if pts.shape[1] == 3: 69 | # pad the points with ones to be (n, 4) 70 | v = np.hstack([pts, np.ones((len(pts), 1))]).T 71 | elif pts.shape[1] == 4: 72 | v = pts.T 73 | else: 74 | raise NotImplementedError 75 | v = matrix @ v 76 | # transpose and crop back to (n, 3) 77 | return v.T[:, 0:3] 78 | else: 79 | raise NotImplementedError 80 | 81 | 82 | def transform_pts_by_state(pts, state, com=np.zeros(3)): 83 | matrix = get_transform_matrix(state, com) 84 | return transform_pts_by_matrix(pts, matrix) 85 | 
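To make the state convention in transform.py concrete, here is a minimal round-trip sketch. It assumes the repository root is on PYTHONPATH so that grasp.transform imports; the state, center of mass, and points are arbitrary values chosen for illustration:

```python
import numpy as np

from grasp.transform import get_transform_matrix, get_state_from_matrix, transform_pts_by_state

# 6-DoF state: translation (x, y, z) followed by an axis-angle rotation vector.
state = np.array([0.1, 0.0, 0.2, 0.0, 0.0, np.pi / 2])
com = np.array([0.05, 0.0, 0.0])  # rotations are taken about the center of mass

# Building the matrix and reading the state back should be the identity.
T = get_transform_matrix(state, com)
recovered = get_state_from_matrix(T, com, full_dof=True)
assert np.allclose(state, recovered)

# Apply the same state directly to a (n, 3) point array.
pts = np.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]])
print(transform_pts_by_state(pts, state, com))  # (2, 3) transformed points
```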
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/model/__init__.py
--------------------------------------------------------------------------------
/model/data_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from dataset.LVDataset import LVDataset
5 | import numpy as np
6 | import random
7 | 
8 | def get_dataloader(params):
9 |     batch_size = params['batch_size']
10 |     shuffle = params['shuffle']
11 |     num_workers = params['num_workers']
12 |     with open('./data/dataset/train_dataset_v2.pkl', 'rb') as f:
13 |         train_dataset = pickle.load(f)
14 |     with open('./data/dataset/val_dataset_v2.pkl', 'rb') as f:
15 |         val_dataset = pickle.load(f)
16 |     with open('./data/dataset/test_dataset_v2.pkl', 'rb') as f:
17 |         test_dataset = pickle.load(f)
18 | 
19 |     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
20 |     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
21 |     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
22 |     return train_loader, val_loader, test_loader
23 | 
24 | def get_dataloader_special(params):
25 |     batch_size = params['batch_size']
26 |     shuffle = params['shuffle']
27 |     num_workers = params['num_workers']
28 |     dataset_path = f'./checkpoints/test_dataset_{params["model_id"]}.pkl'
29 |     with open(dataset_path, 'rb') as f:
30 |         dataset = pickle.load(f)
31 |     dataset.version = ""
32 |     test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
33 |     return test_loader
34 | 
35 | def ramdom_sample_pos(pos):
36 |     '''
37 |     pos: tensor of shape (batch_size, kp1, 6)
38 |     Greedy farthest-point selection per batch element: keep the top-ranked candidate, then repeatedly add the candidate with the largest distance to those already chosen, until 5 contact-point pairs are selected.
39 |     '''
40 | 
41 |     batch_size, kp1, _ = pos.shape
42 |     pos_sample = torch.zeros((batch_size, 5, 6)).to(pos.device)
43 |     for i in range(batch_size):
44 |         pos_sample[i, 0] = pos[i, 0]
45 |         indexes = [0]
46 |         for j in range(1, 5):
47 |             pos_i = pos[i].unsqueeze(1).repeat(1, j, 1)  # (kp1, j, 6)
48 |             pos_sample_i = pos_sample[i, :j].unsqueeze(0)  # (1, j, 6)
49 |             dis = torch.norm(pos_i - pos_sample_i, dim=(1,2))  # (kp1)
50 |             ids = torch.topk(dis, k=5).indices
51 |             for id in ids:
52 |                 if id not in indexes:
53 |                     break
54 |             pos_sample[i, j] = pos[i, id]
55 |             indexes.append(id)
56 |             # check if the index is repeated
57 |             assert len(indexes) == len(set(indexes)), "index is repeated"
58 |     return pos_sample
59 | 
60 | def split_data(params):
61 |     dataset_dir = params['dataset_dir']
62 |     # assert False, "split_data is deprecated"
63 | 
64 |     dataset = LVDataset()
65 |     dataset.load(dataset_dir, version="_v1")
66 |     dataset_size = len(dataset)
67 |     train_size = int(dataset_size * 0.70)
68 |     val_size = int(dataset_size * 0.10)
69 |     test_size = dataset_size - train_size - val_size
70 | 
71 |     train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
72 |     with open('./data/dataset/train_dataset_v1.pkl', 'wb') as f:
73 |         pickle.dump(train_dataset, f)
74 |     with open('./data/dataset/val_dataset_v1.pkl',
'wb') as f: 75 | pickle.dump(val_dataset, f) 76 | with open('./data/dataset/test_dataset_v1.pkl', 'wb') as f: 77 | pickle.dump(test_dataset, f) 78 | 79 | def place_data(): 80 | with open('./data/dataset/train_dataset_v1.pkl', 'rb') as f: 81 | train_dataset = pickle.load(f) 82 | train_entry_paths = [entry for entry in train_dataset] 83 | pickle.dump(train_entry_paths, open('./data/dataset/train_entry_paths_v1.pkl', 'wb')) 84 | print(len(train_entry_paths)) 85 | with open('./data/dataset/val_dataset_v1.pkl', 'rb') as f: 86 | val_dataset = pickle.load(f) 87 | val_entry_paths = [entry for entry in val_dataset] 88 | pickle.dump(val_entry_paths, open('./data/dataset/val_entry_paths_v1.pkl', 'wb')) 89 | print(len(val_entry_paths)) 90 | with open('./data/dataset/test_dataset_v1.pkl', 'rb') as f: 91 | test_dataset = pickle.load(f) 92 | test_entry_paths = [entry for entry in test_dataset] 93 | pickle.dump(test_entry_paths, open('./data/dataset/test_entry_paths_v1.pkl', 'wb')) 94 | print(len(test_entry_paths)) 95 | 96 | def move_data(): 97 | with open('./data/dataset/train_entry_paths_v1.pkl', 'rb') as f: 98 | train_entry_paths = pickle.load(f) 99 | with open('./data/dataset/val_entry_paths_v1.pkl', 'rb') as f: 100 | val_entry_paths = pickle.load(f) 101 | random.shuffle(val_entry_paths) 102 | val_entry_paths_move = val_entry_paths[0:9385] 103 | val_entry_paths = val_entry_paths[9385:] 104 | train_entry_paths += val_entry_paths_move 105 | with open('./data/dataset/test_entry_paths_v1.pkl', 'rb') as f: 106 | test_entry_paths = pickle.load(f) 107 | 108 | with open('./data/dataset/test_entry_paths_v2.pkl', 'rb') as f: 109 | test_entry_paths_remain = pickle.load(f) 110 | test_entry_paths_move = set(test_entry_paths) - set(test_entry_paths_remain) 111 | train_entry_paths += list(test_entry_paths_move) 112 | 113 | train_dataset = LVDataset(version="_v1") 114 | train_dataset.entry_paths = train_entry_paths 115 | val_dataset = LVDataset(version="_v1") 116 | val_dataset.entry_paths = val_entry_paths 117 | test_dataset = LVDataset(version="_v1") 118 | test_dataset.entry_paths = test_entry_paths_remain 119 | print(len(train_dataset), len(val_dataset), len(test_dataset)) 120 | # with open('./data/dataset/train_dataset_v2.pkl', 'wb') as f: 121 | # pickle.dump(train_dataset, f) 122 | # with open('./data/dataset/val_dataset_v2.pkl', 'wb') as f: 123 | # pickle.dump(val_dataset, f) 124 | # with open('./data/dataset/test_dataset_v2.pkl', 'wb') as f: 125 | # pickle.dump(test_dataset, f) 126 | 127 | if __name__ == "__main__": 128 | move_data() -------------------------------------------------------------------------------- /model/eval_ab_l.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | # from dataset.Dataset import Config 5 | from model.model_ab_l import NetABL 6 | from model.modelGlobal import NetG 7 | import logging 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | import torch.multiprocessing as mp 11 | from model.utils import plot, save_loss 12 | from model.data_utils import get_dataloader, get_dataloader_special, ramdom_sample_pos 13 | import argparse 14 | import time 15 | import os 16 | import pickle 17 | import trimesh 18 | import json 19 | import numpy as np 20 | 21 | mp.set_start_method('spawn', force=True) 22 | VISION_LOCAL_THRESHOLD = 1000.0 23 | 24 | def evaluate(model: NetABL, test_loader, device, params): 25 | model.to(device) 26 | model.eval() 
27 | with torch.no_grad(): 28 | for i, data in enumerate(test_loader): 29 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 30 | language = language.to(device).float() 31 | point_cloud = point_cloud.to(device).float() 32 | vision_global = vision_global.to(device).float() 33 | vision_local = vision_local.to(device).float() 34 | grasp_map = grasp_map.to(device).float() 35 | pos_index = pos_index.to(device).long() 36 | neg_index = neg_index.to(device).long() 37 | pos_neg_num = pos_neg_num.to(device).long() 38 | point_cloud_copy = point_cloud.clone().detach() 39 | 40 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 41 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 42 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 43 | for i in index: 44 | print("invalid data entry path", entry_paths[i]) 45 | print("vision_local mean", vision_local_mean[i]) 46 | print("vision_local", vision_local[i]) 47 | continue 48 | 49 | if params['global']: 50 | output_global = model(language, point_cloud, vision_global, vision_local) 51 | else: 52 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 53 | loss_global, loss_local_pos, loss_local_neg, _, _ = model.get_loss(output_global, output_local, grasp_map, 54 | pos_index, neg_index, pos_neg_num, 55 | params['delta_v'], params['delta_d']) 56 | index1 = torch.topk(output_global.squeeze(), k=params['kp1']).indices # (batch_size, kp1) 57 | score = model.get_score(output_local, index1) # (batch_size, kp1, 2048, 1) 58 | pos = model.get_pos(output_local, index1, point_cloud_copy) # (batch_size, kp1, 6) 59 | pos = ramdom_sample_pos(pos) # (batch_size, 5, 6) 60 | 61 | print(f'Evaluating Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}') 62 | 63 | if params['maps_save']: 64 | maps = { 65 | # 'point_cloud': point_cloud.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 66 | 'point_cloud': point_cloud_copy.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 67 | 'grasp_map': grasp_map.squeeze().cpu().numpy(), # (batch_size, 2048) 68 | 'prediction': output_global.squeeze().cpu().numpy(), # (batch_size, 2048) 69 | 'embeddings': output_local.squeeze().cpu().numpy(), # (batch_size, 2048, 32) 70 | 'entry_path': entry_paths, 71 | 'index20': index1.squeeze().cpu().numpy(), # (batch_size, kp1) 72 | # 'index1': index_1.squeeze().cpu().numpy(), # (batch_size,) 73 | 'score': score.squeeze().cpu().numpy(), 74 | 'pos': pos.squeeze().cpu().numpy(), # (batch_size, 5, 6) 75 | } 76 | pickle.dump(maps, open('./checkpoints/maps_{}_{}.pkl'.format(params['model_id'], i), 'wb')) 77 | # print('Saved maps_{}_{}.pkl'.format(params['model_id'], i)) 78 | 79 | def main(params): 80 | _, _, test_loader = get_dataloader(params) 81 | # test_loader = get_dataloader_special(params) 82 | if params['global']: 83 | checkpoint_path = './checkpoints/modelGlobal_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 84 | model = NetG() 85 | else: 86 | checkpoint_path = './checkpoints/model_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 87 | model = NetABL() 88 | checkpoint = torch.load(checkpoint_path) 89 | model.load_state_dict(checkpoint['model']) 90 | # torch.set_printoptions(threshold=1000000) 91 | device = torch.device('cuda:0') 92 | evaluate(model, test_loader, device, params) 93 | 94 | if __name__ == '__main__': 95 | argparser = argparse.ArgumentParser() 96 | 
argparser.add_argument('--dataset_dir', type=str, default='./data/objects/') 97 | argparser.add_argument('--shuffle', type=bool, default=False) 98 | argparser.add_argument('--epoch_id', type=int, default=16) 99 | argparser.add_argument('--batch_size', type=int, default=128) 100 | argparser.add_argument('--num_workers', type=int, default=16) 101 | argparser.add_argument('--lr', type=float, default=2e-3) 102 | # argparser.add_argument('--object_id', type=str, default='42') 103 | # argparser.add_argument('--config_id', type=str, default='1704368644.959488') 104 | 105 | argparser.add_argument('--model_id', type=str, default='') 106 | argparser.add_argument('--global', default=False, action='store_true') 107 | argparser.add_argument('--gt', default=False, action='store_true') 108 | argparser.add_argument('--maps_save', default=False, action='store_true') 109 | argparser.add_argument('--delta_v', type=float, default=0.5) 110 | argparser.add_argument('--delta_d', type=float, default=3.0) 111 | argparser.add_argument('--kp1', type=int, default=20) 112 | argparser.add_argument('--kp2', type=int, default=10) 113 | # 1000_1704368410.6270053 114 | 115 | args = argparser.parse_args() 116 | params = vars(args) 117 | # params['model_id'] = time.time() 118 | main(params) -------------------------------------------------------------------------------- /model/eval_ab_vg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | # from dataset.Dataset import Config 5 | from model.model_ab_vg import NetABVG 6 | from model.modelGlobal import NetG 7 | import logging 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | import torch.multiprocessing as mp 11 | from model.utils import plot, save_loss 12 | from model.data_utils import get_dataloader, get_dataloader_special, ramdom_sample_pos 13 | import argparse 14 | import time 15 | import os 16 | import pickle 17 | import trimesh 18 | import json 19 | import numpy as np 20 | 21 | mp.set_start_method('spawn', force=True) 22 | VISION_LOCAL_THRESHOLD = 1000.0 23 | 24 | def evaluate(model: NetABVG, test_loader, device, params): 25 | model.to(device) 26 | model.eval() 27 | with torch.no_grad(): 28 | for i, data in enumerate(test_loader): 29 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 30 | language = language.to(device).float() 31 | point_cloud = point_cloud.to(device).float() 32 | vision_global = vision_global.to(device).float() 33 | vision_local = vision_local.to(device).float() 34 | grasp_map = grasp_map.to(device).float() 35 | pos_index = pos_index.to(device).long() 36 | neg_index = neg_index.to(device).long() 37 | pos_neg_num = pos_neg_num.to(device).long() 38 | point_cloud_copy = point_cloud.clone().detach() 39 | 40 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 41 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 42 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 43 | for i in index: 44 | print("invalid data entry path", entry_paths[i]) 45 | print("vision_local mean", vision_local_mean[i]) 46 | print("vision_local", vision_local[i]) 47 | continue 48 | 49 | if params['global']: 50 | output_global = model(language, point_cloud, vision_global, vision_local) 51 | else: 52 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 53 | loss_global, loss_local_pos, 
loss_local_neg, _, _ = model.get_loss(output_global, output_local, grasp_map, 54 | pos_index, neg_index, pos_neg_num, 55 | params['delta_v'], params['delta_d']) 56 | index1 = torch.topk(output_global.squeeze(), k=params['kp1']).indices # (batch_size, kp1) 57 | score = model.get_score(output_local, index1) # (batch_size, kp1, 2048, 1) 58 | pos = model.get_pos(output_local, index1, point_cloud_copy) # (batch_size, kp1, 6) 59 | pos = ramdom_sample_pos(pos) # (batch_size, 5, 6) 60 | 61 | print(f'Evaluating Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}') 62 | 63 | if params['maps_save']: 64 | maps = { 65 | # 'point_cloud': point_cloud.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 66 | 'point_cloud': point_cloud_copy.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 67 | 'grasp_map': grasp_map.squeeze().cpu().numpy(), # (batch_size, 2048) 68 | 'prediction': output_global.squeeze().cpu().numpy(), # (batch_size, 2048) 69 | 'embeddings': output_local.squeeze().cpu().numpy(), # (batch_size, 2048, 32) 70 | 'entry_path': entry_paths, 71 | 'index20': index1.squeeze().cpu().numpy(), # (batch_size, kp1) 72 | # 'index1': index_1.squeeze().cpu().numpy(), # (batch_size,) 73 | 'score': score.squeeze().cpu().numpy(), 74 | 'pos': pos.squeeze().cpu().numpy(), # (batch_size, 5, 6) 75 | } 76 | pickle.dump(maps, open('./checkpoints/maps_{}_{}.pkl'.format(params['model_id'], i), 'wb')) 77 | # print('Saved maps_{}_{}.pkl'.format(params['model_id'], i)) 78 | 79 | def main(params): 80 | _, _, test_loader = get_dataloader(params) 81 | # test_loader = get_dataloader_special(params) 82 | if params['global']: 83 | checkpoint_path = './checkpoints/modelGlobal_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 84 | model = NetG() 85 | else: 86 | checkpoint_path = './checkpoints/model_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 87 | model = NetABVG() 88 | checkpoint = torch.load(checkpoint_path) 89 | model.load_state_dict(checkpoint['model']) 90 | # torch.set_printoptions(threshold=1000000) 91 | device = torch.device('cuda:0') 92 | evaluate(model, test_loader, device, params) 93 | 94 | if __name__ == '__main__': 95 | argparser = argparse.ArgumentParser() 96 | argparser.add_argument('--dataset_dir', type=str, default='./data/objects/') 97 | argparser.add_argument('--shuffle', type=bool, default=False) 98 | argparser.add_argument('--epoch_id', type=int, default=16) 99 | argparser.add_argument('--batch_size', type=int, default=128) 100 | argparser.add_argument('--num_workers', type=int, default=16) 101 | argparser.add_argument('--lr', type=float, default=2e-3) 102 | # argparser.add_argument('--object_id', type=str, default='42') 103 | # argparser.add_argument('--config_id', type=str, default='1704368644.959488') 104 | 105 | argparser.add_argument('--model_id', type=str, default='') 106 | argparser.add_argument('--global', default=False, action='store_true') 107 | argparser.add_argument('--gt', default=False, action='store_true') 108 | argparser.add_argument('--maps_save', default=False, action='store_true') 109 | argparser.add_argument('--delta_v', type=float, default=0.5) 110 | argparser.add_argument('--delta_d', type=float, default=3.0) 111 | argparser.add_argument('--kp1', type=int, default=20) 112 | argparser.add_argument('--kp2', type=int, default=10) 113 | # 1000_1704368410.6270053 114 | 115 | args = argparser.parse_args() 116 | params = vars(args) 117 | # params['model_id'] = time.time() 118 | main(params) 
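All three ablation evaluators (eval_ab_l above, eval_ab_vg, and eval_ab_vl below) write the same per-batch dictionary to ./checkpoints/ when --maps_save is set. A minimal sketch for inspecting one of those dumps; MODEL_ID and the batch index 0 are placeholders for whatever run actually produced the file, and the shapes follow the annotations in the eval scripts:

```python
import pickle

# Placeholder path: substitute the model_id and batch index of a real run.
with open('./checkpoints/maps_MODEL_ID_0.pkl', 'rb') as f:
    maps = pickle.load(f)

print(maps['entry_path'])        # paths of the evaluated data entries
print(maps['grasp_map'].shape)   # (batch_size, 2048) ground-truth grasp map
print(maps['prediction'].shape)  # (batch_size, 2048) predicted global map
print(maps['embeddings'].shape)  # (batch_size, 2048, 32) per-point embeddings
print(maps['index20'].shape)     # (batch_size, kp1) top-kp1 point indices
print(maps['pos'].shape)         # (batch_size, 5, 6) sampled contact-point pairs
```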
-------------------------------------------------------------------------------- /model/eval_ab_vl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | # from dataset.Dataset import Config 5 | from model.model_ab_vl import NetABVL 6 | from model.modelGlobal import NetG 7 | import logging 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import DataLoader 10 | import torch.multiprocessing as mp 11 | from model.utils import plot, save_loss 12 | from model.data_utils import get_dataloader, get_dataloader_special, ramdom_sample_pos 13 | import argparse 14 | import time 15 | import os 16 | import pickle 17 | import trimesh 18 | import json 19 | import numpy as np 20 | 21 | mp.set_start_method('spawn', force=True) 22 | VISION_LOCAL_THRESHOLD = 1000.0 23 | 24 | def evaluate(model: NetABVL, test_loader, device, params): 25 | model.to(device) 26 | model.eval() 27 | with torch.no_grad(): 28 | for i, data in enumerate(test_loader): 29 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 30 | language = language.to(device).float() 31 | point_cloud = point_cloud.to(device).float() 32 | vision_global = vision_global.to(device).float() 33 | vision_local = vision_local.to(device).float() 34 | grasp_map = grasp_map.to(device).float() 35 | pos_index = pos_index.to(device).long() 36 | neg_index = neg_index.to(device).long() 37 | pos_neg_num = pos_neg_num.to(device).long() 38 | point_cloud_copy = point_cloud.clone().detach() 39 | 40 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 41 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 42 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 43 | for i in index: 44 | print("invalid data entry path", entry_paths[i]) 45 | print("vision_local mean", vision_local_mean[i]) 46 | print("vision_local", vision_local[i]) 47 | continue 48 | 49 | if params['global']: 50 | output_global = model(language, point_cloud, vision_global, vision_local) 51 | else: 52 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 53 | loss_global, loss_local_pos, loss_local_neg, _, _ = model.get_loss(output_global, output_local, grasp_map, 54 | pos_index, neg_index, pos_neg_num, 55 | params['delta_v'], params['delta_d']) 56 | index1 = torch.topk(output_global.squeeze(), k=params['kp1']).indices # (batch_size, kp1) 57 | score = model.get_score(output_local, index1) # (batch_size, kp1, 2048, 1) 58 | pos = model.get_pos(output_local, index1, point_cloud_copy) # (batch_size, kp1, 6) 59 | pos = ramdom_sample_pos(pos) # (batch_size, 5, 6) 60 | 61 | print(f'Evaluating Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}') 62 | 63 | if params['maps_save']: 64 | maps = { 65 | # 'point_cloud': point_cloud.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 66 | 'point_cloud': point_cloud_copy.squeeze().cpu().numpy(), # (batch_size, 2048, 3) 67 | 'grasp_map': grasp_map.squeeze().cpu().numpy(), # (batch_size, 2048) 68 | 'prediction': output_global.squeeze().cpu().numpy(), # (batch_size, 2048) 69 | 'embeddings': output_local.squeeze().cpu().numpy(), # (batch_size, 2048, 32) 70 | 'entry_path': entry_paths, 71 | 'index20': index1.squeeze().cpu().numpy(), # (batch_size, kp1) 72 | # 'index1': index_1.squeeze().cpu().numpy(), # (batch_size,) 73 | 'score': 
score.squeeze().cpu().numpy(), 74 | 'pos': pos.squeeze().cpu().numpy(), # (batch_size, 5, 6) 75 | } 76 | pickle.dump(maps, open('./checkpoints/maps_{}_{}.pkl'.format(params['model_id'], i), 'wb')) 77 | # print('Saved maps_{}_{}.pkl'.format(params['model_id'], i)) 78 | 79 | def main(params): 80 | _, _, test_loader = get_dataloader(params) 81 | # test_loader = get_dataloader_special(params) 82 | if params['global']: 83 | checkpoint_path = './checkpoints/modelGlobal_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 84 | model = NetG() 85 | else: 86 | checkpoint_path = './checkpoints/model_{}_{}.pth'.format(params['model_id'], params['epoch_id']) 87 | model = NetABVL() 88 | checkpoint = torch.load(checkpoint_path) 89 | model.load_state_dict(checkpoint['model']) 90 | # torch.set_printoptions(threshold=1000000) 91 | device = torch.device('cuda:0') 92 | evaluate(model, test_loader, device, params) 93 | 94 | if __name__ == '__main__': 95 | argparser = argparse.ArgumentParser() 96 | argparser.add_argument('--dataset_dir', type=str, default='./data/objects/') 97 | argparser.add_argument('--shuffle', type=bool, default=False) 98 | argparser.add_argument('--epoch_id', type=int, default=16) 99 | argparser.add_argument('--batch_size', type=int, default=128) 100 | argparser.add_argument('--num_workers', type=int, default=16) 101 | argparser.add_argument('--lr', type=float, default=2e-3) 102 | # argparser.add_argument('--object_id', type=str, default='42') 103 | # argparser.add_argument('--config_id', type=str, default='1704368644.959488') 104 | 105 | argparser.add_argument('--model_id', type=str, default='') 106 | argparser.add_argument('--global', default=False, action='store_true') 107 | argparser.add_argument('--gt', default=False, action='store_true') 108 | argparser.add_argument('--maps_save', default=False, action='store_true') 109 | argparser.add_argument('--delta_v', type=float, default=0.5) 110 | argparser.add_argument('--delta_d', type=float, default=3.0) 111 | argparser.add_argument('--kp1', type=int, default=20) 112 | argparser.add_argument('--kp2', type=int, default=10) 113 | # 1000_1704368410.6270053 114 | 115 | args = argparser.parse_args() 116 | params = vars(args) 117 | # params['model_id'] = time.time() 118 | main(params) -------------------------------------------------------------------------------- /model/grasp_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | 5 | 6 | def show_heatmap(points, grasp_map): 7 | pcd = trimesh.PointCloud(points) 8 | pcd.visual.vertex_colors = trimesh.visual.color.interpolate(grasp_map, color_map='hot') 9 | pcd.visual.vertex_colors[:, 3] = 0.8 * 255 10 | scene_list = [pcd] 11 | trimesh.Scene(scene_list).show() 12 | 13 | ''' 14 | def show_score_map(pid, points, score,): 15 | pcd = trimesh.PointCloud(points) 16 | pid2 = np.argmax(score) 17 | score_max = np.max(score) 18 | score_min = np.min(score) 19 | print("score_min {}, score_max {}".format(score_min, score_max)) 20 | print("score {}".format(score)) 21 | score = (score - score_min) / (score_max - score_min) 22 | 23 | pcd.visual.vertex_colors = trimesh.visual.color.interpolate(score, color_map='hot') 24 | pcd.visual.vertex_colors[:, 3] = 0.8 * 255 25 | # highlight the point pid with larger size and different color 26 | pcd.visual.vertex_colors[pid] = [0, 255., 0., .8 * 255] 27 | pcd.visual.vertex_colors[pid2] = [0, 0., 255., .8 * 255] 28 | scene_list = [pcd] 29 | trimesh.Scene(scene_list).show() 30 | ''' 31 | 
32 | def show_score_map(pid, points, score,): 33 | pcd = trimesh.PointCloud(points) 34 | pid2 = np.argmax(score) 35 | score_max = np.max(score) 36 | score_min = np.min(score) 37 | print("score_min {}, score_max {}".format(score_min, score_max)) 38 | print("score {}".format(score)) 39 | score = (score - score_min) / (score_max - score_min) 40 | score = np.square(score) 41 | 42 | pcd.visual.vertex_colors = trimesh.visual.color.interpolate(score, color_map='cividis') 43 | pcd.visual.vertex_colors[:, 3] = 1.0 * 255 44 | # highlight the point pid with larger size and different color 45 | # pcd.visual.vertex_colors[pid] = [0, 255., 0., .8 * 255] 46 | # pcd.visual.vertex_colors[pid2] = [0, 0., 255., .8 * 255] 47 | ball = trimesh.creation.uv_sphere(radius=0.025) 48 | ball.visual.vertex_colors = [255., 0., 0., .8 * 255] 49 | ball.apply_translation(points[pid]) 50 | ball2 = trimesh.creation.uv_sphere(radius=0.025) 51 | ball2.visual.vertex_colors = [255., 255., 0., .8 * 255] 52 | ball2.apply_translation(points[pid2]) 53 | 54 | scene_list = [pcd, ball, ball2] 55 | trimesh.Scene(scene_list).show() 56 | 57 | def show_embedding_map(pid, points, embeddings): 58 | embed_dists = embeddings_map(pid, embeddings) 59 | pcd = trimesh.PointCloud(points) 60 | pcd.visual.vertex_colors = trimesh.visual.color.interpolate(embed_dists, color_map='hot') 61 | 62 | pcd.visual.vertex_colors[:, 3] = 0.8 * 255 63 | # highlight the point pid with larger size and different color 64 | 65 | pcd.visual.vertex_colors[pid] = [0, 255., 0., .8 * 255] 66 | 67 | scene_list = [pcd] 68 | trimesh.Scene(scene_list).show() 69 | 70 | def embeddings_map(pid, embeddings): 71 | assert embeddings.shape == (2048, 32), "embeddings shape is {}".format(embeddings.shape) 72 | embed_dists = np.linalg.norm(embeddings - embeddings[pid], axis=-1, ord=2) 73 | embed_dist_copy = embed_dists.copy() 74 | embed_min = np.min(embed_dist_copy) 75 | embed_max = np.max(embed_dist_copy) 76 | print("embed_min {}, embed_max {}".format(embed_min, embed_max)) 77 | embed_dist_copy = np.delete(embed_dist_copy, pid) 78 | embed_min = np.min(embed_dist_copy) 79 | embed_max = np.max(embed_dist_copy) 80 | print("embed_min {}, embed_max {}".format(embed_min, embed_max)) 81 | print("embed_dists {}".format(embed_dists)) 82 | embed_dists = (embed_dists - embed_min) / (embed_max - embed_min) 83 | embed_dists[pid] = 0. 
84 | # print("embed_dists {}".format(embed_dists)) 85 | 86 | return embed_dists -------------------------------------------------------------------------------- /model/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | if __name__ == '__main__': 3 | # n_awl: 2: w/o embedding loss, 3: w/ embedding loss 4 | os.system("python -u -m model.trainer --lr 1e-3 --n_awl 3 --delta_v 0.4 --delta_d 2.0 --lambda_p 2.0") 5 | # os.system("python -u -m model.trainer --lr 1e-3 --n_awl 2 --delta_v 0.4 --delta_d 2.0 --lambda_p 2.0") 6 | os.system("python -u -m model.trainer_ab_l --lr 1e-3 --n_awl 3 --delta_v 0.4 --delta_d 2.0 --lambda_p 2.0") 7 | os.system("python -u -m model.trainer_ab_vg --lr 1e-3 --n_awl 3 --delta_v 0.4 --delta_d 2.0 --lambda_p 2.0") 8 | os.system("python -u -m model.trainer_ab_vl --lr 1e-3 --n_awl 3 --delta_v 0.4 --delta_d 2.0 --lambda_p 2.0") 9 | 10 | -------------------------------------------------------------------------------- /model/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | from model.model import Net 5 | import matplotlib.pyplot as plt 6 | from torch.utils.data import DataLoader 7 | import torch.multiprocessing as mp 8 | from model.utils import plot, save_loss, save_loss10 9 | from model.data_utils import get_dataloader 10 | import argparse 11 | import time 12 | import pickle 13 | from AutomaticWeightedLoss.AutomaticWeightedLoss import AutomaticWeightedLoss 14 | 15 | mp.set_start_method('spawn', force=True) 16 | VISION_LOCAL_THRESHOLD = 1000.0 17 | 18 | def train(model: Net, train_loader, val_loader, params, device): 19 | model.to(device) 20 | model_id = params['model_id'] 21 | awl = AutomaticWeightedLoss(params['n_awl']) 22 | optimizer, scheduler = model.get_optimizer(lr=params['lr'], steps_per_epoch=len(train_loader), num_epochs=params['epochs'], awl=awl) 23 | train_losses_global = [] 24 | train_losses_local_pos = [] 25 | train_losses_local_neg = [] 26 | train_losses_local_pos1 = [] 27 | train_losses_local_neg1 = [] 28 | test_losses_global = [] 29 | test_losses_local_pos = [] 30 | test_losses_local_neg = [] 31 | test_losses_local_pos1 = [] 32 | test_losses_local_neg1 = [] 33 | for epoch in range(params['epochs']): 34 | model.train() 35 | for i, data in enumerate(train_loader): 36 | optimizer.zero_grad() 37 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 38 | language = language.to(device).float() 39 | point_cloud = point_cloud.to(device).float() 40 | vision_global = vision_global.to(device).float() 41 | vision_local = vision_local.to(device).float() 42 | grasp_map = grasp_map.to(device).float() 43 | pos_index = pos_index.to(device).long() 44 | neg_index = neg_index.to(device).long() 45 | pos_neg_num = pos_neg_num.to(device).long() 46 | 47 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 48 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 49 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 50 | for i in index: 51 | print("invalid data entry path", entry_paths[i]) 52 | print("vision_local mean", vision_local_mean[i]) 53 | print("vision_local", vision_local[i]) 54 | continue 55 | 56 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 57 | loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = 
model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d']) 58 | 59 | print(f'Train Epoch {epoch}, Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}') 60 | if params['n_awl'] == 2: 61 | loss = awl(loss_global, loss_local_pos * params['lambda_pos'] + loss_local_neg) 62 | else: 63 | assert params['n_awl'] == 3 64 | loss = awl(loss_global, loss_local_pos + loss_local_neg, loss_local_pos1 * params['lambda_pos'] + loss_local_neg1) 65 | loss.backward() 66 | optimizer.step() 67 | scheduler.step() 68 | 69 | train_losses_global.append(loss_global.item()) 70 | train_losses_local_pos.append(loss_local_pos.item()) 71 | train_losses_local_neg.append(loss_local_neg.item()) 72 | train_losses_local_pos1.append(loss_local_pos1.item()) 73 | train_losses_local_neg1.append(loss_local_neg1.item()) 74 | 75 | 76 | test_loss_global, test_loss_local_pos, test_loss_local_neg, test_loss_local_pos1, test_loss_local_neg1 = validate(model, val_loader, device) 77 | 78 | test_losses_global += test_loss_global 79 | test_losses_local_pos += test_loss_local_pos 80 | test_losses_local_neg += test_loss_local_neg 81 | test_losses_local_pos1 += test_loss_local_pos1 82 | test_losses_local_neg1 += test_loss_local_neg1 83 | # save_loss(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, model_id=model_id) 84 | save_loss10(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, train_losses_local_pos1=train_losses_local_pos1, train_losses_local_neg1=train_losses_local_neg1, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, test_losses_local_pos1=test_losses_local_pos1, test_losses_local_neg1=test_losses_local_neg1, model_id=model_id) 85 | checkpoint = { 86 | 'model': model.state_dict(), 87 | 'optimizer': optimizer.state_dict(), 88 | 'awl': awl.state_dict(), 89 | 'epoch': epoch 90 | } 91 | checkpoint_path = './checkpoints/model_{}_{}.pth'.format(model_id, epoch) 92 | torch.save(checkpoint, checkpoint_path) 93 | 94 | return train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg 95 | 96 | def validate(model: Net, val_loader, device): 97 | model.to(device) 98 | model.eval() 99 | with torch.no_grad(): 100 | test_losses_global = [] 101 | test_losses_local_pos = [] 102 | test_losses_local_neg = [] 103 | test_losses_local_pos1 = [] 104 | test_losses_local_neg1 = [] 105 | for i, data in enumerate(val_loader): 106 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 107 | language = language.to(device).float() 108 | point_cloud = point_cloud.to(device).float() 109 | vision_global = vision_global.to(device).float() 110 | vision_local = vision_local.to(device).float() 111 | grasp_map = grasp_map.to(device).float() 112 | pos_index = pos_index.to(device).long() 113 | neg_index = neg_index.to(device).long() 114 | pos_neg_num = pos_neg_num.to(device).long() 115 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 
116 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 117 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 118 | for i in index: 119 | print("invalid data entry path", entry_paths[i]) 120 | print("vision_local mean", vision_local_mean[i]) 121 | print("vision_local", vision_local[i]) 122 | continue 123 | 124 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 125 | loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d']) 126 | 127 | test_losses_global.append(loss_global.item()) 128 | test_losses_local_pos.append(loss_local_pos.item()) 129 | test_losses_local_neg.append(loss_local_neg.item()) 130 | test_losses_local_pos1.append(loss_local_pos1.item()) 131 | test_losses_local_neg1.append(loss_local_neg1.item()) 132 | print(f'Validate Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}') 133 | 134 | return test_losses_global, test_losses_local_pos, test_losses_local_neg, test_losses_local_pos1, test_losses_local_neg1 135 | 136 | 137 | def main(params): 138 | train_loader, val_loader, test_loader = get_dataloader(params) 139 | model = Net() 140 | device = torch.device('cuda:0') 141 | train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg = train(model, train_loader, val_loader, params, device) 142 | print("finished training") 143 | validate(model, test_loader, device) 144 | # plot(train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg, model_id=params['model_id']) 145 | 146 | if __name__ == '__main__': 147 | argparser = argparse.ArgumentParser() 148 | argparser.add_argument('--dataset_dir', type=str, default='./data/objects/') 149 | argparser.add_argument('--shuffle', type=bool, default=True) 150 | argparser.add_argument('--epochs', type=int, default=40) 151 | argparser.add_argument('--batch_size', type=int, default=128) 152 | argparser.add_argument('--num_workers', type=int, default=16) 153 | argparser.add_argument('--lr', type=float, default=2e-3) 154 | argparser.add_argument('--small_data', default=False, action='store_true') 155 | argparser.add_argument('--delta_v', type=float, default=0.5) 156 | argparser.add_argument('--delta_d', type=float, default=3.0) 157 | argparser.add_argument('--lambda_pos', type=float, default=1.0) 158 | argparser.add_argument('--n_awl', type=int, default=2) 159 | args = argparser.parse_args() 160 | params = vars(args) 161 | params['model_id'] = time.time() 162 | main(params) -------------------------------------------------------------------------------- /model/trainer_ab_l.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataset.LVDataset import LVDataset 4 | from model.model_ab_l import NetABL 5 | import matplotlib.pyplot as plt 6 | from torch.utils.data import DataLoader 7 | import torch.multiprocessing as mp 8 | from model.utils import plot, save_loss, save_loss10 9 | from model.data_utils import get_dataloader 10 | import argparse 11 | import time 12 | import pickle 13 | from AutomaticWeightedLoss.AutomaticWeightedLoss import AutomaticWeightedLoss 14 
| 15 | mp.set_start_method('spawn', force=True) 16 | VISION_LOCAL_THRESHOLD = 1000.0 17 | 18 | def train(model: NetABL, train_loader, val_loader, params, device): 19 | model.to(device) 20 | model_id = params['model_id'] 21 | awl = AutomaticWeightedLoss(params['n_awl']) 22 | optimizer, scheduler = model.get_optimizer(lr=params['lr'], steps_per_epoch=len(train_loader), num_epochs=params['epochs'], awl=awl) 23 | train_losses_global = [] 24 | train_losses_local_pos = [] 25 | train_losses_local_neg = [] 26 | train_losses_local_pos1 = [] 27 | train_losses_local_neg1 = [] 28 | test_losses_global = [] 29 | test_losses_local_pos = [] 30 | test_losses_local_neg = [] 31 | test_losses_local_pos1 = [] 32 | test_losses_local_neg1 = [] 33 | for epoch in range(params['epochs']): 34 | model.train() 35 | for i, data in enumerate(train_loader): 36 | optimizer.zero_grad() 37 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 38 | language = language.to(device).float() 39 | point_cloud = point_cloud.to(device).float() 40 | vision_global = vision_global.to(device).float() 41 | vision_local = vision_local.to(device).float() 42 | grasp_map = grasp_map.to(device).float() 43 | pos_index = pos_index.to(device).long() 44 | neg_index = neg_index.to(device).long() 45 | pos_neg_num = pos_neg_num.to(device).long() 46 | 47 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 48 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 49 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 50 | for i in index: 51 | print("invalid data entry path", entry_paths[i]) 52 | print("vision_local mean", vision_local_mean[i]) 53 | print("vision_local", vision_local[i]) 54 | continue 55 | 56 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 57 | loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d']) 58 | 59 | print(f'Train Epoch {epoch}, Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}') 60 | if params['n_awl'] == 2: 61 | loss = awl(loss_global, loss_local_pos * params['lambda_pos'] + loss_local_neg) 62 | else: 63 | assert params['n_awl'] == 3 64 | loss = awl(loss_global, loss_local_pos + loss_local_neg, loss_local_pos1 * params['lambda_pos'] + loss_local_neg1) 65 | loss.backward() 66 | optimizer.step() 67 | scheduler.step() 68 | 69 | train_losses_global.append(loss_global.item()) 70 | train_losses_local_pos.append(loss_local_pos.item()) 71 | train_losses_local_neg.append(loss_local_neg.item()) 72 | train_losses_local_pos1.append(loss_local_pos1.item()) 73 | train_losses_local_neg1.append(loss_local_neg1.item()) 74 | 75 | 76 | test_loss_global, test_loss_local_pos, test_loss_local_neg, test_loss_local_pos1, test_loss_local_neg1 = validate(model, val_loader, device) 77 | 78 | test_losses_global += test_loss_global 79 | test_losses_local_pos += test_loss_local_pos 80 | test_losses_local_neg += test_loss_local_neg 81 | test_losses_local_pos1 += test_loss_local_pos1 82 | test_losses_local_neg1 += test_loss_local_neg1 83 | # save_loss(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, 
test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, model_id=model_id) 84 | save_loss10(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, train_losses_local_pos1=train_losses_local_pos1, train_losses_local_neg1=train_losses_local_neg1, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, test_losses_local_pos1=test_losses_local_pos1, test_losses_local_neg1=test_losses_local_neg1, model_id=model_id) 85 | checkpoint = { 86 | 'model': model.state_dict(), 87 | 'optimizer': optimizer.state_dict(), 88 | 'awl': awl.state_dict(), 89 | 'epoch': epoch 90 | } 91 | checkpoint_path = './checkpoints/model_{}_{}.pth'.format(model_id, epoch) 92 | torch.save(checkpoint, checkpoint_path) 93 | 94 | return train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg 95 | 96 | def validate(model: NetABL, val_loader, device): 97 | model.to(device) 98 | model.eval() 99 | with torch.no_grad(): 100 | test_losses_global = [] 101 | test_losses_local_pos = [] 102 | test_losses_local_neg = [] 103 | test_losses_local_pos1 = [] 104 | test_losses_local_neg1 = [] 105 | for i, data in enumerate(val_loader): 106 | language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data 107 | language = language.to(device).float() 108 | point_cloud = point_cloud.to(device).float() 109 | vision_global = vision_global.to(device).float() 110 | vision_local = vision_local.to(device).float() 111 | grasp_map = grasp_map.to(device).float() 112 | pos_index = pos_index.to(device).long() 113 | neg_index = neg_index.to(device).long() 114 | pos_neg_num = pos_neg_num.to(device).long() 115 | vision_local_mean = torch.mean(vision_local, dim=(1,2)) 116 | if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD): 117 | index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD) 118 | for i in index: 119 | print("invalid data entry path", entry_paths[i]) 120 | print("vision_local mean", vision_local_mean[i]) 121 | print("vision_local", vision_local[i]) 122 | continue 123 | 124 | output_global, output_local = model(language, point_cloud, vision_global, vision_local) 125 | loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d']) 126 | 127 | test_losses_global.append(loss_global.item()) 128 | test_losses_local_pos.append(loss_local_pos.item()) 129 | test_losses_local_neg.append(loss_local_neg.item()) 130 | test_losses_local_pos1.append(loss_local_pos1.item()) 131 | test_losses_local_neg1.append(loss_local_neg1.item()) 132 | print(f'Validate Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}') 133 | 134 | return test_losses_global, test_losses_local_pos, test_losses_local_neg, test_losses_local_pos1, test_losses_local_neg1 135 | 136 | 137 | def main(params): 138 | train_loader, val_loader, test_loader = get_dataloader(params) 139 | model = NetABL() 140 | device = torch.device('cuda:0') 141 | train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, 
135 | 
136 | 
137 | def main(params):
138 |     train_loader, val_loader, test_loader = get_dataloader(params)
139 |     model = NetABL()
140 |     device = torch.device('cuda:0')
141 |     train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg = train(model, train_loader, val_loader, params, device)
142 |     print("finished training")
143 |     validate(model, test_loader, params, device)
144 |     # plot(train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg, model_id=params['model_id'])
145 | 
146 | if __name__ == '__main__':
147 |     argparser = argparse.ArgumentParser()
148 |     argparser.add_argument('--dataset_dir', type=str, default='./data/objects/')
149 |     argparser.add_argument('--shuffle', default=True, action=argparse.BooleanOptionalAction)  # boolean flag (Python 3.9+); `type=bool` would parse any non-empty string as True
150 |     argparser.add_argument('--epochs', type=int, default=40)
151 |     argparser.add_argument('--batch_size', type=int, default=128)
152 |     argparser.add_argument('--num_workers', type=int, default=16)
153 |     argparser.add_argument('--lr', type=float, default=2e-3)
154 |     argparser.add_argument('--small_data', default=False, action='store_true')
155 |     argparser.add_argument('--delta_v', type=float, default=0.5)
156 |     argparser.add_argument('--delta_d', type=float, default=3.0)
157 |     argparser.add_argument('--lambda_pos', type=float, default=1.0)
158 |     argparser.add_argument('--n_awl', type=int, default=2)
159 |     args = argparser.parse_args()
160 |     params = vars(args)
161 |     params['model_id'] = time.time()
162 |     main(params)
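# Example invocation (assuming the repository root is on PYTHONPATH and the
# script is launched from there):
#   python -m model.trainer_ab_l --epochs 40 --batch_size 128 --lr 2e-3 --n_awl 3
# With --n_awl 3, a third uncertainty weight covers the pos1/neg1 local losses
# separately from the pos/neg pair, as in the branch above.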
--------------------------------------------------------------------------------
/model/trainer_ab_vg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from dataset.LVDataset import LVDataset
4 | from model.model_ab_vg import NetABVG
5 | import matplotlib.pyplot as plt
6 | from torch.utils.data import DataLoader
7 | import torch.multiprocessing as mp
8 | from model.utils import plot, save_loss, save_loss10
9 | from model.data_utils import get_dataloader
10 | import argparse
11 | import time
12 | import pickle
13 | from AutomaticWeightedLoss.AutomaticWeightedLoss import AutomaticWeightedLoss
14 | 
15 | mp.set_start_method('spawn', force=True)
16 | VISION_LOCAL_THRESHOLD = 1000.0
17 | 
18 | def train(model: NetABVG, train_loader, val_loader, params, device):
19 |     model.to(device)
20 |     model_id = params['model_id']
21 |     awl = AutomaticWeightedLoss(params['n_awl'])
22 |     optimizer, scheduler = model.get_optimizer(lr=params['lr'], steps_per_epoch=len(train_loader), num_epochs=params['epochs'], awl=awl)
23 |     train_losses_global = []
24 |     train_losses_local_pos = []
25 |     train_losses_local_neg = []
26 |     train_losses_local_pos1 = []
27 |     train_losses_local_neg1 = []
28 |     test_losses_global = []
29 |     test_losses_local_pos = []
30 |     test_losses_local_neg = []
31 |     test_losses_local_pos1 = []
32 |     test_losses_local_neg1 = []
33 |     for epoch in range(params['epochs']):
34 |         model.train()
35 |         for i, data in enumerate(train_loader):
36 |             optimizer.zero_grad()
37 |             language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data
38 |             language = language.to(device).float()
39 |             point_cloud = point_cloud.to(device).float()
40 |             vision_global = vision_global.to(device).float()
41 |             vision_local = vision_local.to(device).float()
42 |             grasp_map = grasp_map.to(device).float()
43 |             pos_index = pos_index.to(device).long()
44 |             neg_index = neg_index.to(device).long()
45 |             pos_neg_num = pos_neg_num.to(device).long()
46 | 
47 |             vision_local_mean = torch.mean(vision_local, dim=(1, 2))
48 |             if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD):
49 |                 index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD)
50 |                 for idx in index.flatten().tolist():  # a distinct name; do not reuse the batch index `i`
51 |                     print("invalid data entry path", entry_paths[idx])
52 |                     print("vision_local mean", vision_local_mean[idx])
53 |                     print("vision_local", vision_local[idx])
54 |                 continue  # skip the whole batch if any sample is corrupted
55 | 
56 |             output_global, output_local = model(language, point_cloud, vision_global, vision_local)
57 |             loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d'])
58 | 
59 |             print(f'Train Epoch {epoch}, Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}')
60 |             if params['n_awl'] == 2:
61 |                 loss = awl(loss_global, loss_local_pos * params['lambda_pos'] + loss_local_neg)
62 |             else:
63 |                 assert params['n_awl'] == 3
64 |                 loss = awl(loss_global, loss_local_pos + loss_local_neg, loss_local_pos1 * params['lambda_pos'] + loss_local_neg1)
65 |             loss.backward()
66 |             optimizer.step()
67 |             scheduler.step()
68 | 
69 |             train_losses_global.append(loss_global.item())
70 |             train_losses_local_pos.append(loss_local_pos.item())
71 |             train_losses_local_neg.append(loss_local_neg.item())
72 |             train_losses_local_pos1.append(loss_local_pos1.item())
73 |             train_losses_local_neg1.append(loss_local_neg1.item())
74 | 
75 | 
76 |         test_loss_global, test_loss_local_pos, test_loss_local_neg, test_loss_local_pos1, test_loss_local_neg1 = validate(model, val_loader, params, device)
77 | 
78 |         test_losses_global += test_loss_global
79 |         test_losses_local_pos += test_loss_local_pos
80 |         test_losses_local_neg += test_loss_local_neg
81 |         test_losses_local_pos1 += test_loss_local_pos1
82 |         test_losses_local_neg1 += test_loss_local_neg1
83 |         # save_loss(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, model_id=model_id)
84 |         save_loss10(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, train_losses_local_pos1=train_losses_local_pos1, train_losses_local_neg1=train_losses_local_neg1, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, test_losses_local_pos1=test_losses_local_pos1, test_losses_local_neg1=test_losses_local_neg1, model_id=model_id)
85 |         checkpoint = {
86 |             'model': model.state_dict(),
87 |             'optimizer': optimizer.state_dict(),
88 |             'awl': awl.state_dict(),
89 |             'epoch': epoch
90 |         }
91 |         checkpoint_path = './checkpoints/model_{}_{}.pth'.format(model_id, epoch)
92 |         torch.save(checkpoint, checkpoint_path)
93 | 
94 |     return train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg
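# Note: scheduler.step() is called once per batch above, which matches
# schedulers configured with steps_per_epoch, such as
# torch.optim.lr_scheduler.OneCycleLR; the exact scheduler is created inside
# model.get_optimizer, which is not shown in this dump.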
95 | 
96 | def validate(model: NetABVG, val_loader, params, device):
97 |     model.to(device)
98 |     model.eval()
99 |     with torch.no_grad():
100 |         test_losses_global = []
101 |         test_losses_local_pos = []
102 |         test_losses_local_neg = []
103 |         test_losses_local_pos1 = []
104 |         test_losses_local_neg1 = []
105 |         for i, data in enumerate(val_loader):
106 |             language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data
107 |             language = language.to(device).float()
108 |             point_cloud = point_cloud.to(device).float()
109 |             vision_global = vision_global.to(device).float()
110 |             vision_local = vision_local.to(device).float()
111 |             grasp_map = grasp_map.to(device).float()
112 |             pos_index = pos_index.to(device).long()
113 |             neg_index = neg_index.to(device).long()
114 |             pos_neg_num = pos_neg_num.to(device).long()
115 |             vision_local_mean = torch.mean(vision_local, dim=(1, 2))
116 |             if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD):
117 |                 index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD)
118 |                 for idx in index.flatten().tolist():
119 |                     print("invalid data entry path", entry_paths[idx])
120 |                     print("vision_local mean", vision_local_mean[idx])
121 |                     print("vision_local", vision_local[idx])
122 |                 continue
123 | 
124 |             output_global, output_local = model(language, point_cloud, vision_global, vision_local)
125 |             loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d'])
126 | 
127 |             test_losses_global.append(loss_global.item())
128 |             test_losses_local_pos.append(loss_local_pos.item())
129 |             test_losses_local_neg.append(loss_local_neg.item())
130 |             test_losses_local_pos1.append(loss_local_pos1.item())
131 |             test_losses_local_neg1.append(loss_local_neg1.item())
132 |             print(f'Validate Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}')
133 | 
134 |     return test_losses_global, test_losses_local_pos, test_losses_local_neg, test_losses_local_pos1, test_losses_local_neg1
135 | 
136 | 
137 | def main(params):
138 |     train_loader, val_loader, test_loader = get_dataloader(params)
139 |     model = NetABVG()
140 |     device = torch.device('cuda:0')
141 |     train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg = train(model, train_loader, val_loader, params, device)
142 |     print("finished training")
143 |     validate(model, test_loader, params, device)
144 |     # plot(train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg, model_id=params['model_id'])
145 | 
146 | if __name__ == '__main__':
147 |     argparser = argparse.ArgumentParser()
148 |     argparser.add_argument('--dataset_dir', type=str, default='./data/objects/')
149 |     argparser.add_argument('--shuffle', default=True, action=argparse.BooleanOptionalAction)  # boolean flag (Python 3.9+); `type=bool` would parse any non-empty string as True
150 |     argparser.add_argument('--epochs', type=int, default=40)
151 |     argparser.add_argument('--batch_size', type=int, default=128)
152 |     argparser.add_argument('--num_workers', type=int, default=16)
153 |     argparser.add_argument('--lr', type=float, default=2e-3)
154 |     argparser.add_argument('--small_data', default=False, action='store_true')
155 |     argparser.add_argument('--delta_v', type=float, default=0.5)
156 |     argparser.add_argument('--delta_d', type=float, default=3.0)
157 |     argparser.add_argument('--lambda_pos', type=float, default=1.0)
158 |     argparser.add_argument('--n_awl', type=int, default=2)
159 |     args = argparser.parse_args()
160 |     params = vars(args)
161 |     params['model_id'] = time.time()
162 |     main(params)
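# model/utils.py is not included in this section; a minimal stand-in for
# save_loss10, consistent with the keyword-only call in the trainers (the file
# name and pickle format are assumptions, not the repository's implementation):
import pickle

def save_loss10(model_id, **loss_curves):
    # persist the ten running train/val loss curves for later plotting
    with open('./checkpoints/losses_{}.pkl'.format(model_id), 'wb') as f:
        pickle.dump(loss_curves, f)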
--------------------------------------------------------------------------------
/model/trainer_ab_vl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from dataset.LVDataset import LVDataset
4 | from model.model_ab_vl import NetABVL
5 | import matplotlib.pyplot as plt
6 | from torch.utils.data import DataLoader
7 | import torch.multiprocessing as mp
8 | from model.utils import plot, save_loss, save_loss10
9 | from model.data_utils import get_dataloader
10 | import argparse
11 | import time
12 | import pickle
13 | from AutomaticWeightedLoss.AutomaticWeightedLoss import AutomaticWeightedLoss
14 | 
15 | mp.set_start_method('spawn', force=True)
16 | VISION_LOCAL_THRESHOLD = 1000.0
17 | 
18 | def train(model: NetABVL, train_loader, val_loader, params, device):
19 |     model.to(device)
20 |     model_id = params['model_id']
21 |     awl = AutomaticWeightedLoss(params['n_awl'])
22 |     optimizer, scheduler = model.get_optimizer(lr=params['lr'], steps_per_epoch=len(train_loader), num_epochs=params['epochs'], awl=awl)
23 |     train_losses_global = []
24 |     train_losses_local_pos = []
25 |     train_losses_local_neg = []
26 |     train_losses_local_pos1 = []
27 |     train_losses_local_neg1 = []
28 |     test_losses_global = []
29 |     test_losses_local_pos = []
30 |     test_losses_local_neg = []
31 |     test_losses_local_pos1 = []
32 |     test_losses_local_neg1 = []
33 |     for epoch in range(params['epochs']):
34 |         model.train()
35 |         for i, data in enumerate(train_loader):
36 |             optimizer.zero_grad()
37 |             language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data
38 |             language = language.to(device).float()
39 |             point_cloud = point_cloud.to(device).float()
40 |             vision_global = vision_global.to(device).float()
41 |             vision_local = vision_local.to(device).float()
42 |             grasp_map = grasp_map.to(device).float()
43 |             pos_index = pos_index.to(device).long()
44 |             neg_index = neg_index.to(device).long()
45 |             pos_neg_num = pos_neg_num.to(device).long()
46 | 
47 |             vision_local_mean = torch.mean(vision_local, dim=(1, 2))
48 |             if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD):
49 |                 index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD)
50 |                 for idx in index.flatten().tolist():  # a distinct name; do not reuse the batch index `i`
51 |                     print("invalid data entry path", entry_paths[idx])
52 |                     print("vision_local mean", vision_local_mean[idx])
53 |                     print("vision_local", vision_local[idx])
54 |                 continue  # skip the whole batch if any sample is corrupted
55 | 
56 |             output_global, output_local = model(language, point_cloud, vision_global, vision_local)
57 |             loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d'])
58 | 
59 |             print(f'Train Epoch {epoch}, Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}')
60 |             if params['n_awl'] == 2:
61 |                 loss = awl(loss_global, loss_local_pos * params['lambda_pos'] + loss_local_neg)
62 |             else:
63 |                 assert params['n_awl'] == 3
64 |                 loss = awl(loss_global, loss_local_pos + loss_local_neg, loss_local_pos1 * params['lambda_pos'] + loss_local_neg1)
65 |             loss.backward()
66 |             optimizer.step()
67 |             scheduler.step()
68 | 
69 |             train_losses_global.append(loss_global.item())
70 |             train_losses_local_pos.append(loss_local_pos.item())
71 |             train_losses_local_neg.append(loss_local_neg.item())
72 |             train_losses_local_pos1.append(loss_local_pos1.item())
73 |             train_losses_local_neg1.append(loss_local_neg1.item())
74 | 
75 | 
76 |         test_loss_global, test_loss_local_pos, test_loss_local_neg, test_loss_local_pos1, test_loss_local_neg1 = validate(model, val_loader, params, device)
77 | 
78 |         test_losses_global += test_loss_global
79 |         test_losses_local_pos += test_loss_local_pos
80 |         test_losses_local_neg += test_loss_local_neg
81 |         test_losses_local_pos1 += test_loss_local_pos1
82 |         test_losses_local_neg1 += test_loss_local_neg1
83 |         # save_loss(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, model_id=model_id)
84 |         save_loss10(train_losses_global=train_losses_global, train_losses_local_pos=train_losses_local_pos, train_losses_local_neg=train_losses_local_neg, train_losses_local_pos1=train_losses_local_pos1, train_losses_local_neg1=train_losses_local_neg1, test_losses_global=test_losses_global, test_losses_local_pos=test_losses_local_pos, test_losses_local_neg=test_losses_local_neg, test_losses_local_pos1=test_losses_local_pos1, test_losses_local_neg1=test_losses_local_neg1, model_id=model_id)
85 |         checkpoint = {
86 |             'model': model.state_dict(),
87 |             'optimizer': optimizer.state_dict(),
88 |             'awl': awl.state_dict(),
89 |             'epoch': epoch
90 |         }
91 |         checkpoint_path = './checkpoints/model_{}_{}.pth'.format(model_id, epoch)
92 |         torch.save(checkpoint, checkpoint_path)
93 | 
94 |     return train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg
95 | 
96 | def validate(model: NetABVL, val_loader, params, device):
97 |     model.to(device)
98 |     model.eval()
99 |     with torch.no_grad():
100 |         test_losses_global = []
101 |         test_losses_local_pos = []
102 |         test_losses_local_neg = []
103 |         test_losses_local_pos1 = []
104 |         test_losses_local_neg1 = []
105 |         for i, data in enumerate(val_loader):
106 |             language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data
107 |             language = language.to(device).float()
108 |             point_cloud = point_cloud.to(device).float()
109 |             vision_global = vision_global.to(device).float()
110 |             vision_local = vision_local.to(device).float()
111 |             grasp_map = grasp_map.to(device).float()
112 |             pos_index = pos_index.to(device).long()
113 |             neg_index = neg_index.to(device).long()
114 |             pos_neg_num = pos_neg_num.to(device).long()
115 |             vision_local_mean = torch.mean(vision_local, dim=(1, 2))
116 |             if torch.any(vision_local_mean > VISION_LOCAL_THRESHOLD):
117 |                 index = torch.nonzero(vision_local_mean > VISION_LOCAL_THRESHOLD)
118 |                 for idx in index.flatten().tolist():
119 |                     print("invalid data entry path", entry_paths[idx])
120 |                     print("vision_local mean", vision_local_mean[idx])
121 |                     print("vision_local", vision_local[idx])
122 |                 continue
123 | 
124 |             output_global, output_local = model(language, point_cloud, vision_global, vision_local)
125 |             loss_global, loss_local_pos, loss_local_neg, loss_local_pos1, loss_local_neg1 = model.get_loss(output_global, output_local, grasp_map, pos_index, neg_index, pos_neg_num, params['delta_v'], params['delta_d'])
126 | 
127 |             test_losses_global.append(loss_global.item())
128 |             test_losses_local_pos.append(loss_local_pos.item())
129 |             test_losses_local_neg.append(loss_local_neg.item())
130 |             test_losses_local_pos1.append(loss_local_pos1.item())
131 |             test_losses_local_neg1.append(loss_local_neg1.item())
132 |             print(f'Validate Batch {i}, Loss Global {loss_global.item()}, Loss Local Pos {loss_local_pos.item()}, Loss Local Neg {loss_local_neg.item()}, Loss Local Pos1 {loss_local_pos1.item()}, Loss Local Neg1 {loss_local_neg1.item()}')
133 | 
134 |     return test_losses_global, test_losses_local_pos, test_losses_local_neg, test_losses_local_pos1, test_losses_local_neg1
135 | 
136 | 
137 | def main(params):
138 |     train_loader, val_loader, test_loader = get_dataloader(params)
139 |     model = NetABVL()
140 |     device = torch.device('cuda:0')
141 |     train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg = train(model, train_loader, val_loader, params, device)
142 |     print("finished training")
143 |     validate(model, test_loader, params, device)
144 |     # plot(train_losses_global, train_losses_local_pos, train_losses_local_neg, test_losses_global, test_losses_local_pos, test_losses_local_neg, model_id=params['model_id'])
145 | 
146 | if __name__ == '__main__':
147 |     argparser = argparse.ArgumentParser()
148 |     argparser.add_argument('--dataset_dir', type=str, default='./data/objects/')
149 |     argparser.add_argument('--shuffle', default=True, action=argparse.BooleanOptionalAction)  # boolean flag (Python 3.9+); `type=bool` would parse any non-empty string as True
150 |     argparser.add_argument('--epochs', type=int, default=40)
151 |     argparser.add_argument('--batch_size', type=int, default=128)
152 |     argparser.add_argument('--num_workers', type=int, default=16)
153 |     argparser.add_argument('--lr', type=float, default=2e-3)
154 |     argparser.add_argument('--small_data', default=False, action='store_true')
155 |     argparser.add_argument('--delta_v', type=float, default=0.5)
156 |     argparser.add_argument('--delta_d', type=float, default=3.0)
157 |     argparser.add_argument('--lambda_pos', type=float, default=1.0)
158 |     argparser.add_argument('--n_awl', type=int, default=2)
159 |     args = argparser.parse_args()
160 |     params = vars(args)
161 |     params['model_id'] = time.time()
162 |     main(params)
--------------------------------------------------------------------------------
/scripts/data_exam.py:
--------------------------------------------------------------------------------
1 | from dataset.LVDataset import LVDataset
2 | from model.data_utils import get_dataloader
3 | import pickle
4 | import argparse
5 | 
6 | def load_entry(entry_path):
7 |     with open(entry_path, 'rb') as f:
8 |         entry = pickle.load(f)
9 |     language = entry.language
10 |     language_feature = entry.language_feature
11 |     point_cloud = entry.point_cloud
12 |     vision_global = entry.global_feature
13 |     vision_local = entry.local_feature.T
14 |     grasp_map = entry.grasp_map
15 |     pos_index = entry.pos_index
16 |     neg_index = entry.neg_index
17 |     language_feature_15 = entry.language_feature_15
18 |     # print(point_cloud.shape, vision_global.shape, vision_local.shape, grasp_map.shape, pos_index.shape, neg_index.shape)
19 |     # print(type(pos_index), pos_index.shape, type(neg_index), neg_index.shape, type(vision_local), vision_local.shape)
20 |     # print(pos_index)
21 |     # print(language_feature_15.shape)
22 |     print(language)
23 | 
24 |     print(entry_path)
25 |     return language
26 | 
27 | 
28 | 
29 | # language = load_entry("./data/objects/40855/40855_1704580390.4439733_v1.pkl")
30 | if __name__ == '__main__':
31 |     parser = argparse.ArgumentParser()  # do not shadow the `argparse` module with the parser instance
32 |     parser.add_argument('--dataset_dir', type=str, default='./data/objects/')
33 |     parser.add_argument('--shuffle', default=False, action=argparse.BooleanOptionalAction)
34 |     parser.add_argument('--batch_size', type=int, default=128)
35 |     parser.add_argument('--num_workers', type=int, default=16)
36 |     args = parser.parse_args()
37 |     params = vars(args)
38 |     _, _, test_loader = get_dataloader(params)
39 |     for i, data in enumerate(test_loader):
40 |         language, point_cloud, vision_global, vision_local, grasp_map, pos_index, neg_index, pos_neg_num, entry_paths = data
41 |         for j, entry_path in enumerate(entry_paths):
42 |             language = load_entry(entry_path)
43 |             if 'display' in language:
44 |                 print(i, j)
45 |                 exit()
46 | 
47 | 
48 | 
49 | 
--------------------------------------------------------------------------------
/scripts/heatmap.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import pickle
3 | import numpy as np
4 | import argparse
5 | import sys
6 | sys.path.append("/home/gdk/Repositories/DualArmManipulation")
7 | from model.grasp_utils import show_heatmap, show_embedding_map, show_score_map
8 | 
9 | def main(params):
10 |     filename = './checkpoints/maps_{}_{}.pkl'.format(params['model_id'], params['batch_id'])
11 | 
12 |     maps = pickle.load(open(filename, 'rb'))
13 |     points = maps['point_cloud']
14 |     grasp_map = maps['grasp_map']
15 |     prediction = maps['prediction']
16 |     entry_path = maps['entry_path']
17 |     embeddings = maps['embeddings']
18 |     # best_index = maps['index1']
19 |     index20 = maps['index20']
20 |     pos = maps['pos']
21 |     score = maps['score']
22 |     print('point_cloud shape: ', points.shape, 'grasp_map shape: ', grasp_map.shape, 'prediction shape: ', prediction.shape, 'entry_path len: ', len(entry_path))
23 |     np.set_printoptions(threshold=np.inf)
24 | 
25 |     for idx in range(len(entry_path)):  # `idx` avoids shadowing the builtin `id`
26 |         print('entry_path: ', entry_path[idx])
27 |         with open(entry_path[idx], 'rb') as f:
28 |             entry = pickle.load(f)
29 |         print(entry.language)
30 |         show_heatmap(points[idx], prediction[idx])
31 | 
32 | if __name__ == '__main__':
33 |     argparser = argparse.ArgumentParser()
34 |     argparser.add_argument('--model_id', type=str, default='1706613034.3918543')
35 |     argparser.add_argument('--batch_id', type=int, default=1)
36 |     args = argparser.parse_args()
37 |     params = vars(args)
38 |     main(params)
--------------------------------------------------------------------------------
/scripts/plot_affordance_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | import trimesh
5 | import numpy as np
6 | 
7 | from dataset.Dataset import Dataset
8 | from evaluation.affordance import map_convert
9 | 
10 | 
11 | def show_heatmap(points, grasp_map, mesh=None):
12 |     if mesh is not None:
13 |         dists = np.linalg.norm(mesh.vertices[:, np.newaxis] - points, axis=-1, ord=2)
14 |         closest_points = np.argmin(dists, axis=1)
15 |         colors = grasp_map[closest_points]
16 |         mesh.visual.vertex_colors = trimesh.visual.color.interpolate(colors, color_map='hot')
17 |         mesh.show()
18 |     else:
19 |         pcd = trimesh.PointCloud(points)
20 |         pcd.visual.vertex_colors = trimesh.visual.color.interpolate(grasp_map, color_map='hot')
21 |         pcd.visual.vertex_colors[:, 3] = 0.8 * 255
22 |         scene_list = [pcd]
23 |         trimesh.Scene(scene_list).show()
24 | 
25 | 
26 | def show_score_map(pid, points, score, mesh=None):
27 |     pcd = trimesh.PointCloud(points)
28 |     pid2 = np.argmax(score)
29 |     score_max = np.max(score)
30 |     score_min = np.min(score)
31 |     score = (score - score_min) / (score_max - score_min)
32 |     score = np.square(score)
33 | 
34 |     ball = trimesh.creation.uv_sphere(radius=0.05)
35 |     ball.visual.vertex_colors = [255., 0., 0., 0.8 * 255]
36 |     ball.apply_translation(points[pid])
37 |     ball2 = trimesh.creation.uv_sphere(radius=0.05)
38 |     ball2.visual.vertex_colors = [255., 255., 0., 0.8 * 255]
39 |     ball2.apply_translation(points[pid2])
40 | 
41 |     if mesh is None:
42 |         pcd.visual.vertex_colors = trimesh.visual.color.interpolate(score, color_map='cividis')
43 |         pcd.visual.vertex_colors[:, 3] = 1.0 * 255
44 |         scene_list = [pcd, ball, ball2]
45 |     else:
46 |         dists = np.linalg.norm(mesh.vertices[:, np.newaxis] - points, axis=-1, ord=2)
47 |         closest_points = np.argmin(dists, axis=1)
48 |         second_closest_points = np.argsort(dists, axis=1)[:, 1]
49 |         colors = (score[closest_points] + score[second_closest_points]) / 2
50 |         mesh.visual.vertex_colors = trimesh.visual.color.interpolate(colors, color_map='cividis')
51 |         mesh.visual.vertex_colors[:, 3] = 0.8 * 255
52 |         scene_list = [mesh, ball, ball2]
53 | 
54 |     trimesh.Scene(scene_list).show()
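# A quick synthetic check of the two visualizers above (random values, purely
# illustrative; opens interactive trimesh windows, so it is left commented out):
#
#   points = np.random.rand(2048, 3)
#   show_heatmap(points, np.random.rand(2048))       # point-cloud branch (mesh=None)
#   show_score_map(0, points, np.random.rand(2048))  # marks point 0 (red) and the argmax (yellow)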
55 | 
56 | def find_prediction(entry_path):
57 |     for m in maps:  # `m` avoids shadowing the builtin `map`
58 |         entry_paths = m['entry_path']
59 |         if entry_path in entry_paths:
60 |             id_in_map = entry_paths.index(entry_path)
61 |             points = m['point_cloud'][id_in_map]
62 |             grasp_map = m['prediction'][id_in_map]
63 |             score = m['score'][id_in_map]
64 |             return points, grasp_map, score
65 | 
66 | def find_config(object_id):
67 |     found_entry_paths = []
68 |     for m in maps:
69 |         entry_paths = m['entry_path']
70 |         for entry_path in entry_paths:
71 |             object_id_in_entry_path = entry_path.split('/')[-2]
72 |             if object_id == object_id_in_entry_path:
73 |                 found_entry_paths.append(entry_path)
74 |     return found_entry_paths
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     # entry_paths = ['./data/objects/4931/4931_1704412294.0766087_v1.pkl']
79 |     # entry_paths = ['./data/objects/1168/1168_1704368664.9076674_v1.pkl']
80 | 
81 |     model_id = '1706605305.8925593'
82 |     # model_id = '1706613034.3918543'
83 |     maps = []
84 |     all_entry_paths = []
85 |     for batch_id in range(80):
86 |         filename = './checkpoints/maps_{}_{}.pkl'.format(model_id, batch_id)
87 |         if not os.path.exists(filename):
88 |             continue
89 |         maps.append(pickle.load(open(filename, 'rb')))
90 |         all_entry_paths += maps[-1]['entry_path']
91 | 
92 |     print(len(all_entry_paths))
93 | 
94 |     dataset = Dataset('./data/objects')
95 | 
96 |     # vgn_maps = pickle.load(open('./data_vgn_map.pickle', 'rb'))
97 | 
98 |     # object_id = '4935'
99 |     # meshes = dataset[object_id].load_meshes()
100 |     # for mesh in meshes:
101 |     #     mesh.visual.vertex_colors = np.random.randint(0, 255, size=4, dtype=np.uint8)
102 |     #     mesh.visual.vertex_colors[:, 3] = 255
103 |     # mesh = trimesh.util.concatenate(meshes)
104 |     #
105 |     # # find other config
106 |     # found_entry_paths = find_config(object_id)
107 |     # for entry_path in found_entry_paths:
108 |     #     print(entry_path)
109 |     #     points, grasp_map, score = find_prediction(entry_path)
110 |     #     with open(entry_path, 'rb') as f:
111 |     #         entry = pickle.load(f)
112 |     #     show_heatmap(points, grasp_map, mesh)
113 |     #     show_score_map(0, points, score[0], mesh)
114 |     #     print(entry.language)
115 |     #
116 |     # exit()
117 | 
118 |     found_entry_paths = find_config('10782')
119 | 
120 |     for entry_path in found_entry_paths:
121 |     # for entry_path in all_entry_paths[60:]:
122 |     # for entry_path in entry_paths:
123 |         object_id = entry_path.split('/')[-2]
124 |         print(object_id)
125 | 
126 |         meshes = dataset[object_id].load_meshes()
127 |         for mesh in meshes:
128 |             mesh.visual.vertex_colors = np.random.randint(0, 255, size=4, dtype=np.uint8)
129 |             mesh.visual.vertex_colors[:, 3] = 255
130 |         mesh = trimesh.util.concatenate(meshes)
131 |         # mesh.show()
132 | 
133 |         points, grasp_map, score = find_prediction(entry_path)
134 |         entry = pickle.load(open(entry_path, 'rb'))
135 |         print(entry.language)
136 | 
137 |         gt_map = entry.grasp_map
138 |         show_heatmap(points, gt_map, mesh)
139 | 
140 |         print(gt_map - grasp_map)
141 | 
142 |         # pcd = trimesh.PointCloud(points)
143 |         # pcd.show()
144 | 
145 |         # if object_id in vgn_maps:
146 |         #     print('VGN map found')
147 |         #     grasp_map = map_convert(vgn_maps[object_id], points)
148 |         #     show_heatmap(points, grasp_map, mesh)
149 |         #     continue
150 | 
151 | 
152 |         # show_heatmap(points, grasp_map, mesh)
153 |         for k in range(20):
154 |             show_score_map(k, points, score[k], mesh)
155 |         # show_score_map(10, points, score[10], mesh)
156 |         #
157 |         # # find other config
158 |         # found_entry_paths = find_config(object_id)
159 |         # for entry_path in found_entry_paths:
160 |         #     print(entry_path)
161 |         #     points, grasp_map, score = find_prediction(entry_path)
162 |         #     with open(entry_path, 'rb') as f:
163 |         #         entry = pickle.load(f)
164 |         #     show_heatmap(points, grasp_map, mesh)
165 |         #     show_score_map(0, points, score[0], mesh)
166 | 
167 | 
168 | 
169 | 
--------------------------------------------------------------------------------
/vision_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dkguo/PhyGrasp/7ed7af0b1406ef95cc6b1d4a2513bc469a7f3f59/vision_encoder/__init__.py
--------------------------------------------------------------------------------
/vision_encoder/demo_extract_vision_feature.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | 
4 | import numpy as np
5 | import torch
6 | import trimesh
7 | 
8 | from openpoints.models import build_model_from_cfg
9 | from openpoints.utils import EasyConfig, load_checkpoint
10 | 
11 | if __name__ == '__main__':
12 |     # conda activate iae
13 |     # cd vision_encoder
14 |     # CUDA_VISIBLE_DEVICES=0 python demo_extract_vision_feature.py
15 | 
16 |     logging.basicConfig(level=logging.INFO)
17 | 
18 |     cfg = EasyConfig()
19 |     cfg_path = './shapenetpart_pointnext-s_c64.yaml'
20 |     pretrain_path = './shapenetpart-train-pointnext-s_c64-ngpus4-seed7798-20220822-024210-ZcJ8JwCgc7yysEBWzkyAaE_ckpt_best.pth'
21 |     cfg.load(cfg_path, recursive=True)
22 |     seg_model = build_model_from_cfg(cfg.model).cuda()
23 |     load_checkpoint(seg_model, pretrained_path=pretrain_path)
24 |     seg_model.eval()
25 | 
26 |     cfg = EasyConfig()
27 |     cfg_path = './modelnet40_pointnext-s.yaml'
28 |     pretrain_path = './modelnet40ply2048-train-pointnext-s-ngpus1-seed6848-model.encoder_args.width=64-20220525-145053-7tGhBV9xR9yQEBtN4GPcSc_ckpt_best.pth'
29 |     cfg.load(cfg_path, recursive=True)
30 |     clf_model = build_model_from_cfg(cfg.model).cuda()
31 |     load_checkpoint(clf_model, pretrained_path=pretrain_path)
32 |     clf_model.eval()
33 | 
34 |     objects_path = '/home/gdk/Repositories/DualArmManipulation/demo/demo_objects'
35 | 
36 |     object_names = ['banana', 'monitor', 'pill_bottle', 'plastic_hammer', 'hammer']
37 | 
38 |     object_name = object_names[4]
39 | 
40 |     mesh = trimesh.load(f'{objects_path}/{object_name}/{object_name}.obj')
41 | 
42 |     # rescale to [-1, 1] box
43 |     rescale = max(mesh.extents) / 2.
44 |     tform = [
45 |         -(mesh.bounds[1][i] + mesh.bounds[0][i]) / 2.
46 |         for i in range(3)
47 |     ]
48 |     matrix = np.eye(4)
49 |     matrix[:3, 3] = tform
50 |     # mesh.apply_transform(matrix)
51 |     transform = matrix
52 |     matrix = np.eye(4)
53 |     matrix[:3, :3] /= rescale
54 |     transform = np.dot(matrix, transform)
55 |     # mesh.apply_transform(matrix)
56 |     mesh.apply_transform(transform)
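# The two 4x4 matrices above compose into a single similarity transform: first
# translate the bounding-box center to the origin, then scale by 1/rescale, so
# the mesh fits inside the [-1, 1] box the encoders expect. Equivalently:
#
#   transform = scale_matrix @ translation_matrix   # applied right-to-left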
57 | 
58 |     points, idx = trimesh.sample.sample_surface(mesh, 2048)
59 |     normals = mesh.face_normals[idx]
60 | 
61 |     heights = points[:, 2] - points[:, 2].min()
62 |     pos = [points]
63 |     x = [np.concatenate([points, normals, heights[:, np.newaxis]], axis=1).T]
64 | 
65 |     pos = torch.Tensor(np.array(pos)).cuda().contiguous()
66 |     x = torch.Tensor(np.array(x)).cuda().contiguous()
67 | 
68 |     print(pos.shape, x.shape)
69 | 
70 |     inp = {'pos': pos,
71 |            'x': x,
72 |            'cls': torch.zeros(1, 16).long().cuda(),
73 |            }
74 | 
75 |     local_features = seg_model(inp).detach().cpu().numpy()
76 |     global_features = clf_model(inp['pos']).detach().cpu().numpy()
77 | 
78 |     print(local_features.shape, global_features.shape)
79 | 
80 |     print(local_features[0, :, :4])
81 |     print(global_features[0])
82 | 
83 |     vision_features = {
84 |         'points': points,
85 |         'local_features': local_features,
86 |         'global_features': global_features,
87 |         'transform': transform,
88 |     }
89 | 
90 |     pickle.dump(vision_features, open(f'{objects_path}/{object_name}/vision_features.pkl', 'wb'))
--------------------------------------------------------------------------------
/vision_encoder/modelnet40_pointnext-s.yaml:
--------------------------------------------------------------------------------
1 | # GFLOPs  GMACs  Params.(M)
2 | # 1.64    0.81   1.374
3 | 
4 | # C=64
5 | # GFLOPs  GMACs  Params.(M)
6 | # 6.49    3.23   4.523
7 | # Throughput (ins./s): 2032.9397323777052
8 | 
9 | model:
10 |   NAME: BaseCls
11 |   encoder_args:
12 |     NAME: PointNextEncoder
13 |     blocks: [1, 1, 1, 1, 1, 1]
14 |     strides: [1, 2, 2, 2, 2, 1]
15 |     width: 64
16 |     in_channels: 3
17 |     radius: 0.15
18 |     radius_scaling: 1.5
19 |     sa_layers: 2
20 |     sa_use_res: True
21 |     nsample: 32
22 |     expansion: 4
23 |     aggr_args:
24 |       feature_type: 'dp_fj'
25 |       reduction: 'max'
26 |     group_args:
27 |       NAME: 'ballquery'
28 |       normalize_dp: True
29 |     conv_args:
30 |       order: conv-norm-act
31 |     act_args:
32 |       act: 'relu'
33 |     norm_args:
34 |       norm: 'bn'
35 | #  cls_args:
36 | #    NAME: ClsHead
37 | #    num_classes: 40
38 | #    mlps: [512, 256]
39 | #    norm_args:
40 | #      norm: 'bn1d'
--------------------------------------------------------------------------------
/vision_encoder/modelnet40ply2048-train-pointnext-s-ngpus1-seed6848-model.encoder_args.width=64-20220525-145053-7tGhBV9xR9yQEBtN4GPcSc_ckpt_best.pth:
--------------------------------------------------------------------------------
1 | /home/gdk/Data/bimanual/modelnet40ply2048-train-pointnext-s-ngpus1-seed6848-model.encoder_args.width=64-20220525-145053-7tGhBV9xR9yQEBtN4GPcSc_ckpt_best.pth
--------------------------------------------------------------------------------
/vision_encoder/shapenetpart-train-pointnext-s_c64-ngpus4-seed7798-20220822-024210-ZcJ8JwCgc7yysEBWzkyAaE_ckpt_best.pth:
--------------------------------------------------------------------------------
1 | /home/gdk/Data/bimanual/shapenetpart-train-pointnext-s_c64-ngpus4-seed7798-20220822-024210-ZcJ8JwCgc7yysEBWzkyAaE_ckpt_best.pth
--------------------------------------------------------------------------------
/vision_encoder/shapenetpart_pointnext-s.yaml:
--------------------------------------------------------------------------------
1 | # ===>loading from cfgs/shapenetpart/pointnext-s.yaml
2 | # Number of params: 0.9817 M
3 | # test input size: ((torch.Size([1, 2048, 3]), torch.Size([1, 3, 2048])))
4 | # Batches  npoints  Params.(M)  GFLOPs
5 | #  64      2048     0.982       4.52
6 | model:
7 |   NAME: BasePartSeg
8 |   encoder_args:
9 |     NAME: PointNextEncoder
10 |     blocks: [ 1, 1, 1, 1, 1 ]  # 1, 1, 1, 2, 1 is better, but not the main focus of this paper
11 |     strides: [ 1, 2, 2, 2, 2 ]
12 |     width: 32
13 |     in_channels: 7  # better than 4, 6
14 |     sa_layers: 3  # better than 2
15 |     sa_use_res: True
16 |     radius: 0.1
17 |     radius_scaling: 2.5
18 |     nsample: 32  # will not improve performance.
19 |     expansion: 4
20 |     aggr_args:
21 |       feature_type: 'dp_fj'
22 |       reduction: 'max'
23 |     group_args:
24 |       NAME: 'ballquery'
25 |       normalize_dp: True
26 |     conv_args:
27 |       order: conv-norm-act
28 |     act_args:
29 |       act: 'relu'  # LeakyReLU makes training unstable.
30 |     norm_args:
31 |       norm: 'bn'  # ln makes training unstable
32 |   decoder_args:
33 |     NAME: PointNextPartDecoder
34 |     cls_map: curvenet
35 | #  cls_args:
36 | #    NAME: SegHead
37 | #    global_feat: max,avg  # append global feature to each point feature
38 | #    num_classes: 50
39 | #    in_channels: null
40 | #    norm_args:
41 | #      norm: 'bn'
42 | 
43 | 
44 | # ---------------------------------------------------------------------------- #
45 | # Training cfgs
46 | # ---------------------------------------------------------------------------- #
47 | lr: 0.001
48 | min_lr: null
49 | optimizer:
50 |   NAME: adamw
51 |   weight_decay: 1.0e-4  # the best
52 | 
53 | criterion_args:
54 |   NAME: Poly1FocalLoss
55 | 
56 | # scheduler
57 | epochs: 300
58 | sched: multistep
59 | decay_epochs: [210, 270]
60 | decay_rate: 0.1
61 | warmup_epochs: 0
62 | 
63 | datatransforms:
64 |   train: [PointsToTensor, PointCloudScaling, PointCloudCenterAndNormalize, PointCloudJitter, ChromaticDropGPU]
65 |   val: [PointsToTensor, PointCloudCenterAndNormalize]
66 |   kwargs:
67 |     jitter_sigma: 0.001
68 |     jitter_clip: 0.005
69 |     scale: [0.8, 1.2]
70 |     gravity_dim: 1
71 |     angle: [0, 1.0, 0]
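# Note: the cls_args heads stay commented out in these configs because this
# repository uses the PointNeXt backbones as feature extractors (see
# demo_extract_vision_feature.py), presumably so the models return per-point and
# global features rather than class logits.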
--------------------------------------------------------------------------------
/vision_encoder/shapenetpart_pointnext-s_c160.yaml:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=0 python examples/profile.py --cfg cfgs/shapenetpart/pointnext-s_c160.yaml batch_size=64 num_points=2048 timing=True flops=True
2 | # ------------
3 | # ===>loading from cfgs/shapenetpart/pointnext-s_c160.yaml
4 | # Number of params: 22.4999 M
5 | # test input size: ((torch.Size([1, 2048, 3]), torch.Size([1, 3, 2048])))
6 | # Batches  npoints  Params.(M)  GFLOPs
7 | #  64      2048     22.500      110.18
8 | # Throughput (ins./s): 76.0178306799044
9 | 
10 | model:
11 |   NAME: BasePartSeg
12 |   encoder_args:
13 |     NAME: PointNextEncoder
14 |     blocks: [ 1, 1, 1, 1, 1 ]  # 1, 1, 1, 2, 1 is better, but not the main focus of this paper
15 |     strides: [ 1, 2, 2, 2, 2 ]
16 |     width: 160
17 |     in_channels: 7  # better than 4, 6
18 |     sa_layers: 3  # better than 2
19 |     sa_use_res: True
20 |     radius: 0.1
21 |     radius_scaling: 2.5
22 |     nsample: 32  # will not improve performance.
23 |     expansion: 4
24 |     aggr_args:
25 |       feature_type: 'dp_fj'
26 |       reduction: 'max'
27 |     group_args:
28 |       NAME: 'ballquery'
29 |       normalize_dp: True
30 |     conv_args:
31 |       order: conv-norm-act
32 |     act_args:
33 |       act: 'relu'  # LeakyReLU makes training unstable.
34 |     norm_args:
35 |       norm: 'bn'  # ln makes training unstable
36 |   decoder_args:
37 |     NAME: PointNextPartDecoder
38 |     cls_map: curvenet
39 | #  cls_args:
40 | #    NAME: SegHead
41 | #    global_feat: max,avg  # append global feature to each point feature
42 | #    num_classes: 50
43 | #    in_channels: null
44 | #    norm_args:
45 | #      norm: 'bn'
46 | 
47 | 
48 | # ---------------------------------------------------------------------------- #
49 | # Training cfgs
50 | # ---------------------------------------------------------------------------- #
51 | lr: 0.001
52 | min_lr: null
53 | optimizer:
54 |   NAME: adamw
55 |   weight_decay: 1.0e-4  # the best
56 | 
57 | criterion_args:
58 |   NAME: Poly1FocalLoss
59 | 
60 | # scheduler
61 | epochs: 300
62 | sched: multistep
63 | decay_epochs: [210, 270]
64 | decay_rate: 0.1
65 | warmup_epochs: 0
66 | 
67 | datatransforms:
68 |   train: [PointsToTensor, PointCloudScaling, PointCloudCenterAndNormalize, PointCloudJitter, ChromaticDropGPU]
69 |   val: [PointsToTensor, PointCloudCenterAndNormalize]
70 |   kwargs:
71 |     jitter_sigma: 0.001
72 |     jitter_clip: 0.005
73 |     scale: [0.8, 1.2]
74 |     gravity_dim: 1
75 |     angle: [0, 1.0, 0]
76 | 
77 | 
--------------------------------------------------------------------------------
/vision_encoder/shapenetpart_pointnext-s_c64.yaml:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=0 python examples/profile.py --cfg cfgs/shapenetpart/pointnext-s_c64.yaml batch_size=64 num_points=2048 timing=True flops=True
2 | # Batches  npoints  Params.(M)  GFLOPs
3 | #  64      2048     3.722       17.80
4 | # Throughput (ins./s): 330.9890643832901
5 | 
6 | model:
7 |   NAME: BasePartSeg
8 |   encoder_args:
9 |     NAME: PointNextEncoder
10 |     blocks: [ 1, 1, 1, 1, 1 ]  # 1, 1, 1, 2, 1 is better, but not the main focus of this paper
11 |     strides: [ 1, 2, 2, 2, 2 ]
12 |     width: 64
13 |     in_channels: 7  # better than 4, 6
14 |     sa_layers: 3  # better than 2
15 |     sa_use_res: True
16 |     radius: 0.1
17 |     radius_scaling: 2.5
18 |     nsample: 32  # will not improve performance.
19 |     expansion: 4
20 |     aggr_args:
21 |       feature_type: 'dp_fj'
22 |       reduction: 'max'
23 |     group_args:
24 |       NAME: 'ballquery'
25 |       normalize_dp: True
26 |     conv_args:
27 |       order: conv-norm-act
28 |     act_args:
29 |       act: 'relu'  # LeakyReLU makes training unstable.
30 |     norm_args:
31 |       norm: 'bn'  # ln makes training unstable
32 |   decoder_args:
33 |     NAME: PointNextPartDecoder
34 |     cls_map: curvenet
35 | #  cls_args:
36 | #    NAME: SegHead
37 | #    global_feat: max,avg  # append global feature to each point feature
38 | #    num_classes: 50
39 | #    in_channels: null
40 | #    norm_args:
41 | #      norm: 'bn'
42 | 
43 | # ---------------------------------------------------------------------------- #
44 | # Training cfgs
45 | # ---------------------------------------------------------------------------- #
46 | lr: 0.001
47 | min_lr: null
48 | optimizer:
49 |   NAME: adamw
50 |   weight_decay: 1.0e-4  # the best
51 | 
52 | criterion_args:
53 |   NAME: Poly1FocalLoss
54 | 
55 | # scheduler
56 | epochs: 300
57 | sched: multistep
58 | decay_epochs: [210, 270]
59 | decay_rate: 0.1
60 | warmup_epochs: 0
61 | 
62 | datatransforms:
63 |   train: [PointsToTensor, PointCloudScaling, PointCloudCenterAndNormalize, PointCloudJitter, ChromaticDropGPU]
64 |   val: [PointsToTensor, PointCloudCenterAndNormalize]
65 |   kwargs:
66 |     jitter_sigma: 0.001
67 |     jitter_clip: 0.005
68 |     scale: [0.8, 1.2]
69 |     gravity_dim: 1
70 |     angle: [0, 1.0, 0]
--------------------------------------------------------------------------------
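# The shapenetpart configs above all declare in_channels: 7; in this repository
# the seven per-point input channels are built in demo_extract_vision_feature.py
# as xyz (3) + face normals (3) + height above the lowest point (1):
#
#   heights = points[:, 2] - points[:, 2].min()
#   x = np.concatenate([points, normals, heights[:, np.newaxis]], axis=1).T
#
# (transposed to channels-first before being wrapped in a torch tensor).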