├── .gitignore ├── README.md ├── config └── Denoiser.py ├── dataset ├── Grasp6DDataset.py └── __init__.py ├── demo └── intro.png ├── eval.py ├── generate.py ├── models ├── __init__.py ├── denoiser.py ├── loss.py ├── noise_predictor.py ├── pointnet_utils.py └── utils.py ├── requirements.txt ├── robot_exp.py ├── train.py ├── utils ├── __init__.py ├── builder.py ├── config_utils.py ├── test_utils.py └── trainer.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | __pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Language-Driven 6-DoF Grasp Detection Using Negative Prompt Guidance 4 | 5 | [![Paper](https://img.shields.io/badge/Paper-arxiv.2407.13842-FF6B6B.svg)](https://arxiv.org/abs/2407.13842) 6 | [![Homepage](https://img.shields.io/badge/Homepage-Grasp--Anything_Project-5FF66B.svg)](https://airvlab.github.io/grasp-anything/) 7 | 8 |

ECCV 2024 Oral

9 | 10 | 11 |
12 | We address the task of language-driven 6-DoF grasp detection in cluttered point clouds. We introduce a novel diffusion model incorporating the new concept of negative prompt guidance learning. Our proposed negative prompt guidance assists in tackling the fine-grained challenge of the language-driven grasp detection task, directing the detection process toward the desired object by steering away from undesired ones. 13 |
14 | 15 |
16 | 17 | 18 | ## 1. Setup 19 | Create a new conda environment and install the necessary packages 20 | 21 | conda create -n l6gd python=3.9 22 | conda activate l6gd 23 | conda install pip 24 | pip install -r requirements.txt 25 | 26 | ## 2. Download Grasp-Anything-6D dataset 27 | You can request access to our HuggingFace dataset at [our project page](https://airvlab.github.io/grasp-anything/). 28 | 29 | ## 3. Training 30 | To start training the model, run 31 | 32 | python3 train.py --config <path_to_config_file> 33 | Config files are stored in `./config`. Remember to change `dataset_path` in the config files after downloading the dataset. After training, log files and model weights will be saved to `./log`. 34 | 35 | ## 4. Detecting grasps 36 | To detect grasps for the test data, run 37 | 38 | python3 generate.py --config <path_to_config_file> --checkpoint <path_to_checkpoint> --data_path <path_to_dataset> --n_sample 64 39 | 40 | The detected grasp poses will be saved to an `all_data.pkl` file in the corresponding log directory. 41 | 42 | ## 5. Evaluation 43 | For evaluation, execute 44 | 45 | python eval.py --data <path_to_all_data.pkl> 46 | 47 | where `<path_to_all_data.pkl>` is the path to the file `all_data.pkl` generated in the grasp detection step. 48 | 49 | ## 6. Citation 50 | If you find our work interesting or helpful for your research, please consider citing our paper as 51 | 52 | @inproceedings{nguyen2024language, 53 | title={Language-driven 6-dof grasp detection using negative prompt guidance}, 54 | author={Nguyen, Toan and Vu, Minh Nhat and Huang, Baoru and Vuong, An and Vuong, Quan and Le, Ngan and Vo, Thieu and Nguyen, Anh}, 55 | booktitle={ECCV}, 56 | year={2024} 57 | } 58 | -------------------------------------------------------------------------------- /config/Denoiser.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | exp_name = "Denoiser" 5 | seed = 1 6 | log_dir = os.path.join("./log/", exp_name) 7 | try: 8 | os.makedirs(log_dir) 9 | except FileExistsError: 10 | print("Logging directory already exists!") 11 | 12 | optimizer = dict( 13 | type="adam", 14 | lr=0.001, 15 | betas=(0.9, 0.999), 16 | eps=1e-08, 17 | weight_decay=1e-4, 18 | ) 19 | 20 | model = dict( 21 | type="Denoiser", 22 | betas=[1e-4, 2e-2], 23 | n_T=200, 24 | drop_prob=0.1, 25 | ) 26 | 27 | training_cfg = dict( 28 | model=model, 29 | batch_size=128, 30 | epoch=200, 31 | gamma=0.9, # used by the loss function 32 | workflow=dict( 33 | train=1, 34 | ), 35 | ) 36 | 37 | data = dict( 38 | dataset_path="/cm/shared/toannt28/grasp-anything", 39 | num_neg_prompts=4 40 | ) -------------------------------------------------------------------------------- /dataset/Grasp6DDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from pytorchse3.se3 import se3_log_map 7 | 8 | MAX_WIDTH = 0.202 # maximum width of gripper 2F-140 9 | 10 | 11 | class Grasp6DDataset_Train(Dataset): 12 | """ 13 | Data loading class for training.
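Each item pairs a single grasp with its scene: the scene id, the point cloud (N x 6: xyz coordinates plus RGB colors), the positive prompt of the target object, a list of num_neg_prompts negative prompts (padded or truncated as needed), the grasp pose Rt, and the gripper width w rescaled to [-1, 1] via 2*w/MAX_WIDTH - 1.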
14 | """ 15 | def __init__(self, dataset_path: str, num_neg_prompts=4): 16 | """ 17 | dataset_path (str): path to the dataset 18 | num_neg_prompts: number of negative prompts used in training 19 | """ 20 | super().__init__() 21 | self.dataset_path = dataset_path 22 | self.num_neg_prompts = num_neg_prompts 23 | self._load() 24 | 25 | def _load(self): 26 | self.all_data = [] 27 | filenames = sorted(os.listdir(f"{self.dataset_path}/pc")) 28 | 29 | print("Processing dataset for training!") 30 | filenames = filenames[:int(len(filenames)*4/5)] # 80% scenes for training 31 | 32 | for filename in filenames: 33 | scene, _ = os.path.splitext(filename) 34 | pc = np.load(f"{self.dataset_path}/pc/{scene}.npy") 35 | try: 36 | with open(f"{self.dataset_path}/grasp_prompt/{scene}.pkl", "rb") as f: 37 | prompts = pickle.load(f) 38 | except: 39 | continue 40 | num_objects = len(prompts) 41 | for i in range(num_objects): 42 | try: 43 | with open(f"{self.dataset_path}/grasp/{scene}_{i}.pkl", "rb") as f: 44 | Rts, ws = pickle.load(f) 45 | except: 46 | continue 47 | pos_prompt = prompts[i] # positive prompt 48 | neg_prompts = prompts[:i] + prompts[i + 1:] # negative prompts 49 | real_num_neg_prompts = len(neg_prompts) 50 | if 0 < real_num_neg_prompts < self.num_neg_prompts: 51 | neg_prompts = neg_prompts + [neg_prompts[-1]] * (self.num_neg_prompts - real_num_neg_prompts) # pad with last negative prompt 52 | elif real_num_neg_prompts == 0: # if no negative prompt 53 | neg_prompts = [""] * self.num_neg_prompts # then use empty strings 54 | else: # if the real number of negative prompts exceeds self.num_neg_prompts 55 | neg_prompts = neg_prompts[:self.num_neg_prompts] # truncate to the first self.num_neg_prompts prompts 56 | 57 | self.all_data.extend([{"scene": scene, "pc": pc, "pos_prompt": pos_prompt, "neg_prompts": neg_prompts,\ 58 | "Rt": Rt, "w": 2*w/MAX_WIDTH-1.0} for Rt, w in zip(Rts, ws)]) 59 | 60 | return self.all_data 61 | 62 | def __getitem__(self, index): 63 | """ 64 | index (int): the element index 65 | """ 66 | element = self.all_data[index] 67 | return element["scene"], element["pc"], element["pos_prompt"], element["neg_prompts"], element["Rt"], element["w"] 68 | 69 | def __len__(self): 70 | return len(self.all_data) 71 | 72 | 73 | class Grasp6DDataset_Test(Dataset): 74 | """ 75 | Data loading class for testing.
76 | """ 77 | def __init__(self, dataset_path: str): 78 | """ 79 | dataset_path (str): path to the dataset 80 | """ 81 | super().__init__() 82 | self.dataset_path = dataset_path 83 | self._load() 84 | 85 | def _load(self): 86 | self.all_data = [] 87 | filenames = sorted(os.listdir(f"{self.dataset_path}/pc")) 88 | 89 | print("Processing dataset for testing!") 90 | filenames = filenames[int(len(filenames)*4/5):] # 20% scenes for testing 91 | 92 | for filename in filenames: 93 | scene, _ = os.path.splitext(filename) 94 | pc = np.load(f"{self.dataset_path}/pc/{scene}.npy") 95 | try: 96 | with open(f"{self.dataset_path}/grasp_prompt/{scene}.pkl", "rb") as f: 97 | prompts = pickle.load(f) 98 | except: 99 | continue 100 | num_objects = len(prompts) 101 | for i in range(num_objects): 102 | try: 103 | with open(f"{self.dataset_path}/grasp/{scene}_{i}.pkl", "rb") as f: 104 | Rts, ws = pickle.load(f) 105 | except: 106 | continue 107 | pos_prompt = prompts[i] # positive prompt 108 | gs = np.concatenate((se3_log_map(torch.from_numpy(Rts)).numpy(), 2*ws[:, None]/MAX_WIDTH-1.0), axis=-1) 109 | self.all_data.append({"scene": scene, "pc": pc, "pos_prompt": pos_prompt, "gs": gs}) 110 | 111 | return self.all_data 112 | 113 | def __getitem__(self, index): 114 | """ 115 | index (int): the element index 116 | """ 117 | element = self.all_data[index] 118 | return element["scene"], element["pc"], element["pos_prompt"] , element["gs"] 119 | 120 | def __len__(self): 121 | return len(self.all_data) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .Grasp6DDataset import Grasp6DDataset_Train, Grasp6DDataset_Test 2 | 3 | 4 | __all__ = ['Grasp6DDataset_Train', 'Grasp6DDataset_Test'] -------------------------------------------------------------------------------- /demo/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Driven-6-DoF-Grasp-Detection-Using-Negative-Prompt-Guidance/ed313481467525d6095d5682945b24cb94732a12/demo/intro.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from utils import * 4 | import argparse 5 | from tqdm import tqdm 6 | from utils.test_utils import earth_movers_distance, coverage_rate, collision_free_rate 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Evaluate a model") 11 | parser.add_argument("--data", type=str, help="path to the data") 12 | args = parser.parse_args() 13 | return args 14 | 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | with open(args.data, "rb") as f: 19 | generated_data = pickle.load(f) 20 | cvr = np.array([coverage_rate(datapoint["gs"], datapoint["gen_grasps"])\ 21 | for datapoint in tqdm(generated_data)\ 22 | if coverage_rate(datapoint["gs"], datapoint["gen_grasps"]) is not None]) 23 | print(f"Average CR: {np.mean(cvr)}") 24 | emds = np.array([earth_movers_distance(datapoint["gs"], datapoint["gen_grasps"])\ 25 | for datapoint in tqdm(generated_data)\ 26 | if earth_movers_distance(datapoint["gs"], datapoint["gen_grasps"]) is not None]) 27 | print(f"Average EMD: {np.mean(emds)}") 28 | cfr = np.array([collision_free_rate(datapoint["pc"], datapoint["gen_grasps"])\ 29 | for datapoint in tqdm(generated_data)\ 30 | if 
collision_free_rate(datapoint["pc"], datapoint["gen_grasps"]) is not None]) 31 | print(f"Average CFR: {np.mean(cfr)}") 32 | 33 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | from gorilla.config import Config 5 | from utils import * 6 | import argparse 7 | from tqdm import tqdm 8 | from models.utils import PosNegTextEncoder 9 | from dataset import Grasp6DDataset_Test 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description="Test a model by generating grasps") 14 | parser.add_argument("--config", type=str, help="config file path") 15 | parser.add_argument("--checkpoint", type=str, help="path to checkpoint model") 16 | parser.add_argument("--data_path", type=str, help="path to test dataset") 17 | parser.add_argument("--n_sample", type=int, help="number of samples to generate for a\ 18 | point cloud-text pair") 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == "__main__": 24 | args = parse_args() 25 | cfg = Config.fromfile(args.config) 26 | if cfg.get("seed") is not None: 27 | set_random_seed(cfg.seed) 28 | model = build_model(cfg) 29 | model = model.to("cuda") 30 | model = torch.nn.DataParallel(model) 31 | 32 | # Load test data 33 | test_data = Grasp6DDataset_Test(args.data_path).all_data 34 | print("Loading checkpoint...") 35 | model.load_state_dict(torch.load(args.checkpoint)) 36 | posneg_text_encoder = PosNegTextEncoder(device=torch.device("cuda")) 37 | n_sample = args.n_sample 38 | 39 | print("Generating...") 40 | model.eval() 41 | for datapoint in tqdm(test_data): 42 | """ 43 | Each datapoint includes a point cloud, a positive prompt, 44 | and a set of corresponding grasp poses. 45 | """ 46 | pc = torch.from_numpy(datapoint["pc"]) 47 | pos_prompt = datapoint["pos_prompt"] 48 | pc = pc.unsqueeze(0).repeat(n_sample, 1, 1).float().to("cuda") 49 | with torch.no_grad(): 50 | pos_prompt_embedding = posneg_text_encoder(pos_prompt, type="pos").repeat(n_sample, 1) 51 | generated_grasps = model.module.generate(pc, pos_prompt_embedding, w=0.2).cpu().detach().numpy() # use 1 GPU only 52 | datapoint["gen_grasps"] = generated_grasps 53 | 54 | with open(os.path.join(cfg.log_dir, "all_data.pkl"), 'wb') as f: 55 | pickle.dump(test_data, f) 56 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | 3 | 4 | __all__ = ["Denoiser"] -------------------------------------------------------------------------------- /models/denoiser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorchse3.se3 import se3_log_map as pytorchse3_log_map 4 | from utils import * 5 | from models.noise_predictor import NoisePredictor 6 | from models.utils import linear_diffusion_schedule 7 | 8 | 9 | class Denoiser(nn.Module): 10 | """ 11 | This model is for the Euclidean-based distance. 
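Grasps are handled as 7-D vectors: the 6-D se(3) log map of the pose Rt concatenated with the normalized width. The forward pass predicts the noise added to such a vector at a random timestep; generate runs the reverse diffusion, combining the text-free, positive-prompt, and negative-prompt noise predictions as eps1 + w * (eps2 - eps3) (classifier-free-guidance style, with guidance scale w).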
12 | """ 13 | def __init__(self, n_T, betas, drop_prob=0.1): # drop_prob: probability to drop the text prompt 14 | super(Denoiser, self).__init__() 15 | self.noise_predictor = NoisePredictor(time_emb_dim=64) 16 | for k, v in linear_diffusion_schedule(betas, n_T).items(): 17 | self.register_buffer(k, v) 18 | self.n_T = n_T 19 | self.drop_prob = drop_prob 20 | 21 | def forward(self, Rt, w, pc, pos_prompt_embedding, neg_prompt_embeddings, noise): 22 | B = pc.shape[0] 23 | time = torch.randint(1, self.n_T + 1, (B,)).to("cuda") 24 | g = torch.cat((pytorchse3_log_map(Rt), w[:, None]), dim=-1) 25 | g_t = (self.sqrtab[time, None] * g + self.sqrtmab[time, None] * noise).float() 26 | text_mask = torch.bernoulli(torch.zeros(B, 1) + 1 - self.drop_prob).to("cuda") 27 | predicted_noise, neg_prompt_pred, neg_prompt_embedding = self.noise_predictor(g_t, pc, pos_prompt_embedding, neg_prompt_embeddings, text_mask, time) 28 | 29 | return predicted_noise, neg_prompt_pred, neg_prompt_embedding 30 | 31 | def generate(self, pc, pos_prompt_embedding, w=1.0): 32 | """ 33 | pc's size: n_sample x 8192 x 6. 34 | pos_prompt_embedding's size: n_sample x 512 35 | """ 36 | # Pre-compute the scene tokens 37 | pc = pc.permute(0, 2, 1) # B x D x N 38 | scene_tokens = self.noise_predictor.scene_encoder(pc).permute(0, 2, 1) 39 | 40 | # Pre-compute the negative prompt guidance 41 | scene_embedding = torch.mean(scene_tokens, dim=1) # B x 512, the embeddings for the entire scene 42 | neg_prompt_embedding = self.noise_predictor.negative_net(scene_embedding - pos_prompt_embedding) 43 | text_embedding = torch.cat((pos_prompt_embedding.repeat(2, 1), neg_prompt_embedding), axis=0) 44 | scene_tokens = scene_tokens.repeat(3, 1, 1) 45 | 46 | n_sample = pc.shape[0] 47 | g_i = torch.randn(n_sample, 7).to("cuda") 48 | text_mask = torch.ones_like(pos_prompt_embedding).to("cuda") 49 | text_mask = text_mask.repeat(3, 1) 50 | text_mask[:n_sample] = 0.
# make the first part text-free 51 | for j in range(self.n_T, 0, -1): 52 | z = torch.randn(n_sample, 7) if j > 1 else torch.zeros((n_sample, 7)).float() 53 | z = z.to("cuda") 54 | 55 | g_i = g_i.repeat(3, 1) 56 | time = torch.tensor([j]).repeat(3*n_sample).to("cuda") 57 | eps = self.noise_predictor.forward_precomputing(g_i, scene_tokens, text_embedding, text_mask, time) 58 | eps1 = eps[:n_sample] 59 | eps2 = eps[n_sample:2*n_sample] 60 | eps3 = eps[2*n_sample:] 61 | eps = eps1 + w * (eps2 - eps3) 62 | eps = torch.clamp(eps, -1.0, 1.0) 63 | g_i = g_i[:n_sample] 64 | g_i = self.oneover_sqrta[j] * (g_i - eps * self.mab_over_sqrtmab[j]) + self.sqrt_beta_t[j] * z 65 | 66 | return g_i -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DenoiserEuclideanLoss(nn.Module): 7 | def __init__(self): 8 | super(DenoiserEuclideanLoss, self).__init__() 9 | 10 | def forward(self, predicted_noise, noise, neg_prompt_pred, neg_prompt_embeddings): 11 | """ 12 | neg_prompt_pred's size is B x 512 13 | neg_prompt_embeddings' size is B x num_neg_prompts x 512 14 | """ 15 | mse_loss = F.mse_loss(predicted_noise, noise) 16 | neg_prompt_pred = neg_prompt_pred.unsqueeze(1).expand_as(neg_prompt_embeddings) # B x num_neg_prompts x 512 17 | paired_distances = torch.sqrt(torch.sum((neg_prompt_pred - neg_prompt_embeddings)**2, dim=2)) 18 | neg_loss = torch.mean(torch.min(paired_distances, dim=1)[0]) # use minimum 19 | 20 | return mse_loss, neg_loss -------------------------------------------------------------------------------- /models/noise_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils import SceneEncoderPointNetPlusPlus, GraspNet, SinusoidalPositionEmbeddings 4 | 5 | 6 | class NoisePredictor(nn.Module): 7 | """ 8 | This model uses PointNet++. 9 | This model uses a text mask, which is used to drop the text guidance. 10 | This model also predicts the negative prompt guidance. 11 | This model is for the Euclidean distance-based loss function for the learning of negative prompt guidance. 12 | """ 13 | def __init__(self, time_emb_dim): 14 | """ 15 | This model uses PointNet++. 16 | time_emb_dim: dimension of the time embedding. 17 | """ 18 | super(NoisePredictor, self).__init__() 19 | self.scene_encoder = SceneEncoderPointNetPlusPlus(additional_channel=3) 20 | self.time_encoder = SinusoidalPositionEmbeddings(time_emb_dim) 21 | self.grasp_encoder = GraspNet() 22 | self.negative_net = nn.Sequential( # this module outputs the predicted negative prompt embedding 23 | nn.Linear(512, 512), 24 | nn.LayerNorm(512), 25 | nn.ReLU(), 26 | nn.Linear(512, 512) 27 | ) 28 | 29 | self.grasp_pos_prompt_time_net = nn.Sequential( # this module embeds the concatenation of the grasp + positive prompt embedding and the time embedding.
30 | nn.Linear(512 + time_emb_dim, 512), 31 | nn.LayerNorm(512), 32 | nn.ReLU(), 33 | nn.Linear(512, 512) 34 | ) 35 | 36 | self.cross_attention = nn.MultiheadAttention(embed_dim=512, num_heads=4, batch_first=True) 37 | self.final_net = nn.Sequential( 38 | nn.Linear(512, 128), 39 | nn.LayerNorm(128), 40 | nn.ReLU(), 41 | nn.Linear(128, 7) # 7 = 6 (for Rt) + 1 (for w) 42 | ) 43 | 44 | def forward(self, g, pc, pos_prompt_embedding, neg_prompt_embeddings, text_mask, time): 45 | """ 46 | pc's size is B x N x D 47 | text_mask's size is B x 1. The text_mask is used for positive prompt only. No mask for the pc. 48 | time's size is B 49 | g is in se(3) 50 | """ 51 | pc = pc.permute(0, 2, 1) 52 | scene_tokens = self.scene_encoder(pc) # B x 512 x 128, for the scenes of 8192 points 53 | scene_tokens = scene_tokens.permute(0, 2, 1) # B x 128 x 512 54 | scene_embedding = torch.mean(scene_tokens, dim=1) # B x 512, the embeddings for entire scene 55 | 56 | neg_prompt_pred = self.negative_net(scene_embedding - pos_prompt_embedding) # predict the negative prompt 57 | masked_pos_prompt_embedding = pos_prompt_embedding * text_mask # B x 512, drop the positive prompt using the text_mask 58 | 59 | grasp_embedding = self.grasp_encoder(g) # B x 512 60 | grasp_pos_prompt_embedding = grasp_embedding + masked_pos_prompt_embedding # B x 512 61 | time_embedding = self.time_encoder(time) # B x 64, get the time positional embedding 62 | grasp_pos_prompt_time_embedding = self.grasp_pos_prompt_time_net(torch.cat((grasp_pos_prompt_embedding, time_embedding), dim=1)).unsqueeze(1) # B x 1 x 512 63 | 64 | e, _ = self.cross_attention(query=grasp_pos_prompt_time_embedding, key=scene_tokens, value=scene_tokens) # B x 1 x 512 65 | predicted_noise = self.final_net(e.squeeze(1) + grasp_pos_prompt_time_embedding.squeeze(1)) # residual connection 66 | 67 | return predicted_noise, neg_prompt_pred, neg_prompt_embeddings 68 | 69 | def forward_precomputing(self, g, scene_tokens, pos_prompt_embedding, text_mask, time): 70 | """ 71 | This performs given pre-computed scene_tokens. 72 | The neg_prompt_pred is not used. 73 | """ 74 | masked_pos_prompt_embedding = pos_prompt_embedding * text_mask # B x 512, drop the positive prompt using the text_mask 75 | 76 | grasp_embedding = self.grasp_encoder(g) # B x 512 77 | grasp_pos_prompt_embedding = grasp_embedding + masked_pos_prompt_embedding # B x 512 78 | time_embedding = self.time_encoder(time) # B x 64, get the time positional embedding 79 | grasp_pos_prompt_time_embedding = self.grasp_pos_prompt_time_net(torch.cat((grasp_pos_prompt_embedding, time_embedding), dim=1)).unsqueeze(1) # B x 1 x 512 80 | 81 | e, _ = self.cross_attention(query=grasp_pos_prompt_time_embedding, key=scene_tokens, value=scene_tokens) # B x 1 x 512 82 | predicted_noise = self.final_net(e.squeeze(1) + grasp_pos_prompt_time_embedding.squeeze(1)) # residual connection 83 | 84 | return predicted_noise 85 | -------------------------------------------------------------------------------- /models/pointnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def square_distance(src, dst): 8 | """ 9 | Calculate Euclid distance between each two points. 
10 | 11 | src^T * dst = xn * xm + yn * ym + zn * zm; 12 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 13 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 14 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 15 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 16 | 17 | Input: 18 | src: source points, [B, N, C] 19 | dst: target points, [B, M, C] 20 | Output: 21 | dist: per-point square distance, [B, N, M] 22 | """ 23 | B, N, _ = src.shape 24 | _, M, _ = dst.shape 25 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 26 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 27 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 28 | return dist 29 | 30 | 31 | def index_points(points, idx): 32 | """ 33 | 34 | Input: 35 | points: input points data, [B, N, C] 36 | idx: sample index data, [B, S] 37 | Return: 38 | new_points:, indexed points data, [B, S, C] 39 | """ 40 | device = points.device 41 | B = points.shape[0] 42 | view_shape = list(idx.shape) 43 | view_shape[1:] = [1] * (len(view_shape) - 1) 44 | repeat_shape = list(idx.shape) 45 | repeat_shape[0] = 1 46 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 47 | new_points = points[batch_indices, idx, :] 48 | return new_points 49 | 50 | 51 | def farthest_point_sample(xyz, npoint): 52 | """ 53 | Input: 54 | xyz: pointcloud data, [B, N, 3] 55 | npoint: number of samples 56 | Return: 57 | centroids: sampled pointcloud index, [B, npoint] 58 | """ 59 | device = xyz.device 60 | B, N, C = xyz.shape 61 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 62 | distance = torch.ones(B, N).to(device) * 1e10 63 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 64 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 65 | for i in range(npoint): 66 | centroids[:, i] = farthest 67 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 68 | dist = torch.sum((xyz - centroid) ** 2, -1) 69 | mask = dist < distance 70 | distance[mask] = dist[mask] 71 | farthest = torch.max(distance, -1)[1] 72 | return centroids 73 | 74 | 75 | def query_ball_point(radius, nsample, xyz, new_xyz): 76 | """ 77 | Input: 78 | radius: local region radius 79 | nsample: max sample number in local region 80 | xyz: all points, [B, N, 3] 81 | new_xyz: query points, [B, S, 3] 82 | Return: 83 | group_idx: grouped points index, [B, S, nsample] 84 | """ 85 | device = xyz.device 86 | B, N, C = xyz.shape 87 | _, S, _ = new_xyz.shape 88 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 89 | sqrdists = square_distance(new_xyz, xyz) 90 | group_idx[sqrdists > radius ** 2] = N 91 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 92 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 93 | mask = group_idx == N 94 | group_idx[mask] = group_first[mask] 95 | return group_idx 96 | 97 | 98 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 99 | """ 100 | Input: 101 | npoint: 102 | radius: 103 | nsample: 104 | xyz: input points position data, [B, N, 3] 105 | points: input points data, [B, N, D] 106 | Return: 107 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 108 | new_points: sampled points data, [B, npoint, nsample, 3+D] 109 | """ 110 | B, N, C = xyz.shape 111 | S = npoint 112 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 113 | new_xyz = index_points(xyz, fps_idx) 114 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 115 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 
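# center each local neighborhood on its FPS-sampled point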
116 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 117 | 118 | if points is not None: 119 | grouped_points = index_points(points, idx) 120 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 121 | else: 122 | new_points = grouped_xyz_norm 123 | if returnfps: 124 | return new_xyz, new_points, grouped_xyz, fps_idx 125 | else: 126 | return new_xyz, new_points 127 | 128 | 129 | def sample_and_group_all(xyz, points): 130 | """ 131 | Input: 132 | xyz: input points position data, [B, N, 3] 133 | points: input points data, [B, N, D] 134 | Return: 135 | new_xyz: sampled points position data, [B, 1, 3] 136 | new_points: sampled points data, [B, 1, N, 3+D] 137 | """ 138 | device = xyz.device 139 | B, N, C = xyz.shape 140 | new_xyz = torch.zeros(B, 1, C).to(device) 141 | grouped_xyz = xyz.view(B, 1, N, C) 142 | if points is not None: 143 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 144 | else: 145 | new_points = grouped_xyz 146 | return new_xyz, new_points 147 | 148 | 149 | class PointNetSetAbstraction(nn.Module): 150 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 151 | super(PointNetSetAbstraction, self).__init__() 152 | self.npoint = npoint 153 | self.radius = radius 154 | self.nsample = nsample 155 | self.mlp_convs = nn.ModuleList() 156 | self.mlp_bns = nn.ModuleList() 157 | last_channel = in_channel 158 | for out_channel in mlp: 159 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 160 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 161 | last_channel = out_channel 162 | self.group_all = group_all 163 | 164 | def forward(self, xyz, points): 165 | """ 166 | Input: 167 | xyz: input points position data, [B, C, N] 168 | points: input points data, [B, D, N] 169 | Return: 170 | new_xyz: sampled points position data, [B, C, S] 171 | new_points_concat: sample points feature data, [B, D', S] 172 | """ 173 | xyz = xyz.permute(0, 2, 1) 174 | if points is not None: 175 | points = points.permute(0, 2, 1) 176 | 177 | if self.group_all: 178 | new_xyz, new_points = sample_and_group_all(xyz, points) 179 | else: 180 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 181 | # new_xyz: sampled points position data, [B, npoint, C] 182 | # new_points: sampled points data, [B, npoint, nsample, C+D] 183 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 184 | for i, conv in enumerate(self.mlp_convs): 185 | bn = self.mlp_bns[i] 186 | new_points = F.relu(bn(conv(new_points))) 187 | 188 | new_points = torch.max(new_points, 2)[0] 189 | new_xyz = new_xyz.permute(0, 2, 1) 190 | return new_xyz, new_points 191 | 192 | 193 | class PointNetSetAbstractionMsg(nn.Module): 194 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 195 | super(PointNetSetAbstractionMsg, self).__init__() 196 | self.npoint = npoint 197 | self.radius_list = radius_list 198 | self.nsample_list = nsample_list 199 | self.conv_blocks = nn.ModuleList() 200 | self.bn_blocks = nn.ModuleList() 201 | for i in range(len(mlp_list)): 202 | convs = nn.ModuleList() 203 | bns = nn.ModuleList() 204 | last_channel = in_channel + 3 205 | for out_channel in mlp_list[i]: 206 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 207 | bns.append(nn.BatchNorm2d(out_channel)) 208 | last_channel = out_channel 209 | self.conv_blocks.append(convs) 210 | self.bn_blocks.append(bns) 211 | 212 | def forward(self, xyz, points): 213 | """ 214 | Input: 215 | xyz: 
input points position data, [B, C, N] 216 | points: input points data, [B, D, N] 217 | Return: 218 | new_xyz: sampled points position data, [B, C, S] 219 | new_points_concat: sample points feature data, [B, D', S] 220 | """ 221 | xyz = xyz.permute(0, 2, 1) 222 | if points is not None: 223 | points = points.permute(0, 2, 1) 224 | 225 | B, N, C = xyz.shape 226 | S = self.npoint 227 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 228 | new_points_list = [] 229 | for i, radius in enumerate(self.radius_list): 230 | K = self.nsample_list[i] 231 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 232 | grouped_xyz = index_points(xyz, group_idx) 233 | grouped_xyz -= new_xyz.view(B, S, 1, C) 234 | if points is not None: 235 | grouped_points = index_points(points, group_idx) 236 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 237 | else: 238 | grouped_points = grouped_xyz 239 | 240 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 241 | for j in range(len(self.conv_blocks[i])): 242 | conv = self.conv_blocks[i][j] 243 | bn = self.bn_blocks[i][j] 244 | grouped_points = F.relu(bn(conv(grouped_points))) 245 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 246 | new_points_list.append(new_points) 247 | 248 | new_xyz = new_xyz.permute(0, 2, 1) 249 | new_points_concat = torch.cat(new_points_list, dim=1) 250 | return new_xyz, new_points_concat 251 | 252 | 253 | class PointNetFeaturePropagation(nn.Module): 254 | def __init__(self, in_channel, mlp): 255 | super(PointNetFeaturePropagation, self).__init__() 256 | self.mlp_convs = nn.ModuleList() 257 | self.mlp_bns = nn.ModuleList() 258 | last_channel = in_channel 259 | for out_channel in mlp: 260 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 261 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 262 | last_channel = out_channel 263 | 264 | def forward(self, xyz1, xyz2, points1, points2): 265 | """ 266 | Input: 267 | xyz1: input points position data, [B, C, N] 268 | xyz2: sampled input points position data, [B, C, S] 269 | points1: input points data, [B, D, N] 270 | points2: input points data, [B, D, S] 271 | Return: 272 | new_points: upsampled points data, [B, D', N] 273 | """ 274 | xyz1 = xyz1.permute(0, 2, 1) 275 | xyz2 = xyz2.permute(0, 2, 1) 276 | 277 | points2 = points2.permute(0, 2, 1) 278 | B, N, C = xyz1.shape 279 | _, S, _ = xyz2.shape 280 | 281 | if S == 1: 282 | interpolated_points = points2.repeat(1, N, 1) 283 | else: 284 | dists = square_distance(xyz1, xyz2) 285 | dists, idx = dists.sort(dim=-1) 286 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 287 | 288 | dist_recip = 1.0 / (dists + 1e-8) 289 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 290 | weight = dist_recip / norm 291 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 292 | 293 | if points1 is not None: 294 | points1 = points1.permute(0, 2, 1) 295 | new_points = torch.cat([points1, interpolated_points], dim=-1) 296 | else: 297 | new_points = interpolated_points 298 | 299 | new_points = new_points.permute(0, 2, 1) 300 | for i, conv in enumerate(self.mlp_convs): 301 | bn = self.mlp_bns[i] 302 | new_points = F.relu(bn(conv(new_points))) 303 | return new_points 304 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import open_clip 5 | import 
torch.nn.functional as F 6 | from .pointnet_utils import PointNetSetAbstractionMsg 7 | 8 | 9 | def linear_diffusion_schedule(betas, n_T): 10 | """ 11 | Linear scheduler for sampling in training. 12 | """ 13 | beta_t = (betas[1] - betas[0]) * torch.arange(0, n_T + 1, dtype=torch.float64) / n_T + betas[0] 14 | sqrt_beta_t = torch.sqrt(beta_t) 15 | alpha_t = 1 - beta_t 16 | log_alpha_t = torch.log(alpha_t) 17 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 18 | 19 | sqrtab = torch.sqrt(alphabar_t) 20 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 21 | 22 | sqrtmab = torch.sqrt(1 - alphabar_t) 23 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab 24 | 25 | return { 26 | "alpha_t": alpha_t, # \alpha_t 27 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t} 28 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t} 29 | "alphabar_t": alphabar_t, # \bar{\alpha_t} 30 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} 31 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} 32 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}} 33 | } 34 | 35 | 36 | def cosine_diffusion_schedule(cosine_s, n_T): 37 | """ 38 | Cosine scheduling for sampling in training. 39 | """ 40 | timesteps = ( 41 | torch.arange(n_T + 1, dtype=torch.float64) / n_T + cosine_s 42 | ) 43 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 44 | alphas = torch.cos(alphas).pow(2) 45 | alphas = alphas / alphas[0] 46 | beta_t = 1 - alphas[1:] / alphas[:-1] 47 | beta_t = beta_t.clamp(max=0.999) 48 | 49 | sqrt_beta_t = torch.sqrt(beta_t) 50 | alpha_t = 1 - beta_t 51 | log_alpha_t = torch.log(alpha_t) 52 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 53 | 54 | sqrtab = torch.sqrt(alphabar_t) 55 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 56 | 57 | sqrtmab = torch.sqrt(1 - alphabar_t) 58 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab 59 | 60 | return { 61 | "alpha_t": alpha_t, # \alpha_t 62 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t} 63 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t} 64 | "alphabar_t": alphabar_t, # \bar{\alpha_t} 65 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} 66 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} 67 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}} 68 | } 69 | 70 | 71 | class SinusoidalPositionEmbeddings(nn.Module): 72 | """ 73 | Sinusoidal embedding for time step. 74 | """ 75 | def __init__(self, dim, scale=1.0): 76 | super().__init__() 77 | self.dim = dim 78 | self.scale = scale 79 | 80 | def forward(self, time): 81 | time = time * self.scale 82 | device = time.device 83 | half_dim = self.dim // 2 84 | embeddings = math.log(10000) / (half_dim - 1) 85 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 86 | embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0) 87 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 88 | return embeddings 89 | 90 | def __len__(self): 91 | return self.dim 92 | 93 | 94 | class TextEncoder(nn.Module): 95 | """ 96 | Text Encoder to encode the text prompt. 97 | """ 98 | def __init__(self, device): 99 | super(TextEncoder, self).__init__() 100 | self.device = device 101 | self.clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k", 102 | device=self.device) 103 | 104 | def forward(self, texts): 105 | """ 106 | texts can be a single string or a list of strings. 
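Returns one 512-dimensional CLIP (ViT-B-32) text feature per input string.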
107 | """ 108 | tokenizer = open_clip.get_tokenizer("ViT-B-32") 109 | tokens = tokenizer(texts).to(self.device) 110 | text_features = self.clip_model.encode_text(tokens).to(self.device) 111 | return text_features 112 | 113 | 114 | class PosNegTextEncoder(nn.Module): 115 | """ 116 | Text encoder that performs differently for positive and negative prompts. 117 | """ 118 | def __init__(self, device): 119 | super(PosNegTextEncoder, self).__init__() 120 | self.text_encoder = TextEncoder(device=device) 121 | 122 | def forward(self, texts, type): 123 | if type == "pos": # if positive prompt 124 | return self.text_encoder(texts) 125 | elif type == "neg": # if negative prompts 126 | B = len(texts[0]) 127 | l = list(zip(*texts)) 128 | l = [item for sublist in l for item in sublist] 129 | embeddings = self.text_encoder(l) # (B x 4) x 512 130 | embeddings = embeddings.reshape(B, -1, embeddings.shape[-1]) 131 | return embeddings 132 | 133 | 134 | class SceneEncoderPointNetPlusPlus(nn.Module): 135 | """ 136 | Scene encoder based on PointNet++, returns scene tokens. 137 | """ 138 | def __init__(self, additional_channel): 139 | super(SceneEncoderPointNetPlusPlus, self).__init__() 140 | self.sa1 = PointNetSetAbstractionMsg(2048, [0.05, 0.1, 0.2], [ 141 | 32, 64, 128], 3+additional_channel, [[16, 16, 32], [32, 32, 64], [32, 48, 64]]) 142 | self.sa2 = PointNetSetAbstractionMsg( 143 | 512, [0.2, 0.4], [64, 128], 64+64+32, [[64, 64, 128], [64, 96, 128]]) 144 | self.sa3 = PointNetSetAbstractionMsg( 145 | 128, [0.4, 0.8], [128, 256], 128+128, [[128, 128, 256], [128, 196, 256]]) 146 | 147 | def forward(self, xyz): 148 | """ 149 | Return point cloud embedding. 150 | """ 151 | # Set Abstraction layers 152 | xyz = xyz.contiguous() 153 | 154 | l0_xyz = xyz[:, :3, :] 155 | l0_points = xyz 156 | 157 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 158 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 159 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 160 | return l3_points 161 | 162 | 163 | class GraspNet(nn.Module): 164 | """ 165 | Class to encoder the grasping pose. 166 | """ 167 | def __init__(self): 168 | super(GraspNet, self).__init__() 169 | self.net = nn.Sequential( 170 | nn.Linear(7, 64), 171 | nn.ReLU(), 172 | nn.Linear(64, 256), 173 | nn.ReLU(), 174 | nn.Linear(256, 512) 175 | ) 176 | 177 | def forward(self, g): 178 | return self.net(g) 179 | 180 | 181 | class BNMomentum(object): 182 | """ 183 | Class for BatchNormMomentum. 
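The momentum at a given epoch is origin_m * m_decay ** (epoch // step), clamped below at 0.01, and is applied to every BatchNorm1d/BatchNorm2d module passed to __call__.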
184 | """ 185 | def __init__(self, origin_m, m_decay, step): 186 | super().__init__() 187 | self.origin_m = origin_m 188 | self.m_decay = m_decay 189 | self.step = step 190 | return 191 | 192 | def __call__(self, m, epoch): 193 | momentum = self.origin_m * (self.m_decay**(epoch//self.step)) 194 | if momentum < 0.01: 195 | momentum = 0.01 196 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 197 | m.momentum = momentum 198 | return -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | torchaudio 5 | gorilla-core 6 | open_clip_torch 7 | pytorchse3 -------------------------------------------------------------------------------- /robot_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | from gorilla.config import Config 5 | from utils import * 6 | import argparse 7 | from models.utils import PosNegTextEncoder 8 | from pytorchse3.se3 import se3_exp_map 9 | 10 | # LOAD YOUR POINT CLOUD HERE, make sure its size is N x (3+3), 3 for coordinate and 3 for color, 11 | # N is the number of points, can be varied, but preferred to be 8192 12 | pc = None 13 | # SPECIFY YOUR TEXT, for example, "Grasp me the pencil." 14 | text = None 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="Robot experiments") 18 | parser.add_argument("--config", type=str, help="config file path") 19 | parser.add_argument("--checkpoint", type=str, help="path to checkpoint model") 20 | parser.add_argument("--n_sample", type=int, help="number of samples to generate for the\ 21 | point cloud-text pair") 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | if __name__ == "__main__": 27 | args = parse_args() 28 | cfg = Config.fromfile(args.config) 29 | if cfg.get("seed") is not None: 30 | set_random_seed(cfg.seed) 31 | 32 | # Build the model 33 | model = build_model(cfg) 34 | model = model.to("cuda") 35 | model = torch.nn.DataParallel(model) 36 | model.load_state_dict(torch.load(args.checkpoint)) 37 | posneg_text_encoder = PosNegTextEncoder(device=torch.device("cuda")) 38 | n_sample = args.n_sample 39 | 40 | model.eval() 41 | 42 | pc = pc.unsqueeze(0).repeat(n_sample, 1, 1).float().to("cuda") 43 | with torch.no_grad(): 44 | text_embedding = posneg_text_encoder(text, type="pos").repeat(n_sample, 1) 45 | generated_grasps = se3_exp_map(model.module.generate(pc, text_embedding)).cpu().detach().numpy() # use 1 GPU only 46 | 47 | # Save generated grasps to file 48 | with open(os.path.join(cfg.log_dir, "generated_grasps.pkl"), 'wb') as f: 49 | pickle.dump(generated_grasps, f) 50 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from gorilla.config import Config 3 | from utils import * 4 | import argparse 5 | import torch 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description="Train a model") 10 | parser.add_argument("--config", help="train config file path") 11 | args = parser.parse_args() 12 | return args 13 | 14 | 15 | if __name__ == "__main__": 16 | args = parse_args() 17 | cfg = Config.fromfile(args.config) 18 | 19 | logger = IOStream(os.path.join(cfg.log_dir, "run.log")) 20 | if cfg.get("seed") is not None: 21 | set_random_seed(cfg.seed) 22 | 
logger.cprint("Set seed to %d" % cfg.seed) 23 | model = build_model(cfg) 24 | model = model.to("cuda") 25 | model = torch.nn.DataParallel(model) 26 | 27 | print("Training from scratch!") 28 | 29 | dataset_dict = build_dataset(cfg) 30 | loader_dict = build_loader(cfg, dataset_dict) 31 | optim_dict = build_optimizer(cfg, model) 32 | 33 | training = dict( 34 | model=model, 35 | dataset_dict=dataset_dict, 36 | loader_dict=loader_dict, 37 | optim_dict=optim_dict, 38 | logger=logger 39 | ) 40 | 41 | task_trainer = Trainer(cfg, training) 42 | task_trainer.run() 43 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_utils import set_random_seed, simple_weights_init, IOStream 2 | from .builder import build_model, build_dataset, build_loader, build_optimizer 3 | from .trainer import Trainer 4 | 5 | 6 | __all__ = ["set_random_seed", "simple_weights_init", "IOStream", 7 | "build_model", "build_dataset", "build_loader", "build_optimizer", "Trainer"] -------------------------------------------------------------------------------- /utils/builder.py: -------------------------------------------------------------------------------- 1 | from dataset import Grasp6DDataset_Train 2 | from models import * 3 | from utils.config_utils import simple_weights_init 4 | from torch.utils.data import DataLoader 5 | from torch.optim import Adam 6 | 7 | 8 | model_pool = { 9 | "Denoiser": Denoiser 10 | } 11 | 12 | optimizer_pool = { 13 | "adam": Adam 14 | } 15 | 16 | init_pool = { 17 | "simple_weights_init": simple_weights_init 18 | } 19 | 20 | 21 | def build_dataset(cfg): 22 | """ 23 | Function to build the dataset. 24 | """ 25 | if hasattr(cfg, "data"): 26 | data_info = cfg.data 27 | dataset_path = data_info.dataset_path # get the path to the dataset 28 | num_neg_prompts = data_info.num_neg_prompts # get the maximum number of negative prompts 29 | train_set = Grasp6DDataset_Train(dataset_path, num_neg_prompts=num_neg_prompts) # the training set 30 | dataset_dict = dict( 31 | train_set=train_set, 32 | ) 33 | return dataset_dict 34 | else: 35 | raise ValueError("Configuration does not have data config!") 36 | 37 | 38 | def build_loader(cfg, dataset_dict): 39 | """ 40 | Function to build the loader 41 | """ 42 | train_set = dataset_dict["train_set"] 43 | train_loader = DataLoader(train_set, batch_size=cfg.training_cfg.batch_size, shuffle=True, drop_last=False, num_workers=8) 44 | loader_dict = dict( 45 | train_loader=train_loader, 46 | ) 47 | 48 | return loader_dict 49 | 50 | 51 | def build_model(cfg): 52 | """ 53 | Function to build the model. 54 | """ 55 | if hasattr(cfg, "model"): 56 | model_info = cfg.model 57 | weights_init = model_info.get("weights_init", None) 58 | model_name = model_info.type 59 | model_cls = model_pool[model_name] 60 | 61 | if model_name in ["Denoiser"]: 62 | betas = model_info.get("betas") 63 | n_T = model_info.get("n_T") 64 | drop_prob = model_info.get("drop_prob") 65 | model = model_cls(n_T, betas, drop_prob) 66 | else: 67 | raise ValueError("Name of model does not exist!") 68 | if weights_init is not None: 69 | init_fn = init_pool[weights_init] 70 | model.apply(init_fn) 71 | return model 72 | else: 73 | raise ValueError("Configuration does not have model config!") 74 | 75 | 76 | def build_optimizer(cfg, model): 77 | """ 78 | Function to build the optimizer. 
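Reads cfg.optimizer, pops its "type" key to select the optimizer class (e.g. adam -> Adam), and forwards the remaining entries (lr, betas, eps, weight_decay) as keyword arguments.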
79 | """ 80 | if hasattr(cfg, "optimizer"): 81 | optimizer_info = cfg.optimizer 82 | optimizer_type = optimizer_info.type 83 | optimizer_info.pop("type") 84 | optimizer_cls = optimizer_pool[optimizer_type] 85 | optimizer = optimizer_cls(model.parameters(), **optimizer_info) 86 | optim_dict = dict( 87 | optimizer=optimizer 88 | ) 89 | return optim_dict 90 | else: 91 | raise ValueError("Configuration does not have optimizer config!") 92 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | class IOStream(): 7 | def __init__(self, path): 8 | self.f = open(path, "a") 9 | 10 | def cprint(self, text): 11 | print(text) 12 | self.f.write(text+"\n") 13 | self.f.flush() 14 | 15 | def close(self): 16 | self.f.close() 17 | 18 | 19 | def set_random_seed(seed): 20 | """ 21 | Function to set seed for reproducibility. 22 | """ 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | 28 | 29 | def simple_weights_init(m): 30 | """ 31 | Function to initialize weights. 32 | """ 33 | classname = m.__class__.__name__ 34 | if classname.find("Conv2d") != -1: 35 | torch.nn.init.xavier_normal_(m.weight.data) 36 | if m.state_dict().get("bias") is not None: 37 | torch.nn.init.constant_(m.bias.data, 0.0) 38 | elif classname.find("Linear") != -1: 39 | torch.nn.init.xavier_normal_(m.weight.data) 40 | if m.state_dict().get("bias") is not None: 41 | torch.nn.init.constant_(m.bias.data, 0.0) -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchse3.se3 import se3_log_map, se3_exp_map 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment 5 | from scipy.spatial.distance import cdist 6 | import trimesh 7 | 8 | MAX_WIDTH = 0.202 9 | NORMAL_WIDTH = 0.140 10 | 11 | 12 | def create_gripper_marker(width_scale=1.0, color=[0, 255, 0], tube_radius=0.005, sections=6): 13 | """Create a 3D mesh visualizing a parallel yaw gripper. It consists of four cylinders. 14 | 15 | Args: 16 | width_scale (float, optional): Scale of the grasp with w.r.t. the normal width of 140mm, i.e., 0.14. 17 | color (list, optional): RGB values of marker. 18 | tube_radius (float, optional): Radius of cylinders. 19 | sections (int, optional): Number of sections of each cylinder. 20 | 21 | Returns: 22 | trimesh.Trimesh: A mesh that represents a simple parallel yaw gripper. 
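For example, create_gripper_marker(width_scale=w/0.140) draws a gripper opened to an actual width w, as done in collision_free_rate and visualize.py.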
23 | """ 24 | cfl = trimesh.creation.cylinder( 25 | radius=tube_radius, 26 | sections=sections, 27 | segment=[ 28 | [7.10000000e-02*width_scale, -7.27595772e-12, 1.154999996e-01], 29 | [7.10000000e-02*width_scale, -7.27595772e-12, 1.959999998e-01], 30 | ], 31 | ) 32 | cfr = trimesh.creation.cylinder( 33 | radius=tube_radius, 34 | sections=sections, 35 | segment=[ 36 | [-7.100000e-02*width_scale, -7.27595772e-12, 1.154999996e-01], 37 | [-7.100000e-02*width_scale, -7.27595772e-12, 1.959999998e-01], 38 | ], 39 | ) 40 | cb1 = trimesh.creation.cylinder( 41 | radius=tube_radius, sections=sections, segment=[[0, 0, 0], [0, 0, 1.154999996e-01]] 42 | ) 43 | cb2 = trimesh.creation.cylinder( 44 | radius=tube_radius, 45 | sections=sections, 46 | segment=[[-7.100000e-02*width_scale, 0, 1.154999996e-01], [7.100000e-02*width_scale, 0, 1.154999996e-01]], 47 | ) 48 | 49 | tmp = trimesh.util.concatenate([cb1, cb2, cfr, cfl]) 50 | tmp.visual.face_colors = color 51 | 52 | return tmp 53 | 54 | 55 | def earth_movers_distance(train_grasps, gen_grasps): 56 | """ 57 | Compute Earth Mover's Distance between two sets of vectors. 58 | """ 59 | # Ensure the input sets have the same dimensionality 60 | assert train_grasps.shape[1] == gen_grasps.shape[1] 61 | if np.isnan(train_grasps).any() or np.isnan(gen_grasps).any(): 62 | raise ValueError("NaN values exist!") 63 | # Calculate pairwise distances between vectors in the sets 64 | distances = np.linalg.norm(train_grasps[:, np.newaxis] - gen_grasps, axis=-1) 65 | # Solve the linear sum assignment problem 66 | _, assignment = linear_sum_assignment(distances) 67 | # Compute the total Earth Mover's Distance 68 | emd = distances[np.arange(len(assignment)), assignment].mean() 69 | return emd 70 | 71 | 72 | def coverage_rate(train_grasps, gen_grasps): 73 | """ 74 | Function to compute the coverage rate metric. 75 | """ 76 | assert train_grasps.shape[1] == gen_grasps.shape[1] 77 | if np.isnan(train_grasps).any() or np.isnan(gen_grasps).any(): 78 | raise ValueError("NaN values exist!") 79 | dist = cdist(train_grasps, gen_grasps) 80 | rate = np.sum(np.any(dist <= 0.4, axis=1)) / train_grasps.shape[0] 81 | return rate 82 | 83 | 84 | def collision_check(pc, gripper): 85 | """ 86 | Function to check collision between a point cloud and a gripper. 87 | """ 88 | return np.sum(gripper.contains(pc)) > 0 89 | 90 | 91 | def collision_free_rate(pc, grasps): 92 | """ 93 | Function to compute the collision rate metric. 
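grasps' size: M x 7, each row being the 6-D se(3) log map of a pose followed by a width normalized to [-1, 1].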
94 | pc's size: N x 6 95 | """ 96 | if np.isnan(grasps).any(): 97 | return None 98 | pc = pc[:, :3] # use only the coordinates of the point cloud 99 | Rts, ws = se3_exp_map(torch.from_numpy(grasps[:, :-1])).numpy(), (grasps[:, -1] + 1.0)*MAX_WIDTH/2 100 | grippers = [create_gripper_marker(width_scale=w/NORMAL_WIDTH).apply_transform(Rt) for w, Rt in zip(ws, Rts)] 101 | collision_free_rate = 1.0 - np.mean(np.array([collision_check(pc, gripper) for gripper in grippers])) 102 | return collision_free_rate -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from utils import * 5 | from models.utils import PosNegTextEncoder 6 | from models.loss import DenoiserEuclideanLoss 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, cfg, running): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.logger = running["logger"] 14 | self.model = running["model"] 15 | self.dataset_dict = running["dataset_dict"] 16 | self.loader_dict = running["loader_dict"] 17 | self.train_loader = self.loader_dict.get("train_loader", None) 18 | self.optimizer_dict = running["optim_dict"] 19 | self.optimizer = self.optimizer_dict.get("optimizer", None) 20 | self.epoch = 0 21 | 22 | self.gamma = cfg.training_cfg.get("gamma", 0.9) # gamma for loss functions 23 | 24 | # define text encoder 25 | self.posneg_text_encoder = PosNegTextEncoder(device=torch.device("cuda")) 26 | 27 | def train(self): 28 | denoiser_euclidean_loss = DenoiserEuclideanLoss() 29 | self.model.train() 30 | self.logger.cprint("Epoch(%d) begins training........" % self.epoch) 31 | pbar = tqdm(self.train_loader) 32 | if self.epoch > 100: # freeze the scene encoder after 100 epochs to accelerate the training.
33 | for p in self.model.module.noise_predictor.scene_encoder.parameters(): # go through .module since the model is wrapped in nn.DataParallel 34 | p.requires_grad = False 35 | 36 | for _, pc, pos_prompt, neg_prompts, Rt, w in pbar: 37 | B = pc.shape[0] 38 | pc = pc.float().to("cuda") 39 | Rt, w = Rt.float().to("cuda"), w.float().to("cuda") 40 | noise = torch.randn(B, 7).to("cuda") 41 | with torch.no_grad(): 42 | pos_prompt_embedding = self.posneg_text_encoder(pos_prompt, type="pos") 43 | neg_prompt_embeddings = self.posneg_text_encoder(neg_prompts, type="neg") 44 | predicted_noise, neg_prompt_pred, neg_prompt_embeddings = self.model(Rt, w, pc, pos_prompt_embedding, neg_prompt_embeddings, noise) 45 | mse_loss, neg_loss = denoiser_euclidean_loss(predicted_noise, noise, neg_prompt_pred, neg_prompt_embeddings) 46 | pbar.set_description(f"MSE loss: {mse_loss.item():.5f}, Neg loss: {neg_loss.item():.5f}") 47 | 48 | loss = self.gamma * mse_loss + (1 - self.gamma) * neg_loss 49 | loss.backward() 50 | self.optimizer.step() 51 | self.optimizer.zero_grad() 52 | 53 | self.logger.cprint(f"\nEpoch {self.epoch}, Real-time mse loss: {mse_loss.item():.5f},\ 54 | Real-time neg loss: {neg_loss.item():.5f}") 55 | print("Saving checkpoint\n----------------------------------------\n") 56 | torch.save(self.model.state_dict(), os.path.join(self.cfg.log_dir, "current_model.t7")) 57 | self.epoch += 1 58 | 59 | def val(self): 60 | raise NotImplementedError 61 | 62 | def run(self): 63 | EPOCH = self.cfg.training_cfg.epoch 64 | workflow = self.cfg.training_cfg.workflow 65 | while self.epoch < EPOCH: 66 | for key, running_epoch in workflow.items(): 67 | epoch_runner = getattr(self, key) 68 | for _ in range(running_epoch): 69 | epoch_runner() 70 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample code for point cloud and grasps visualization using trimesh. 3 | """ 4 | 5 | 6 | import pickle 7 | import trimesh 8 | import numpy as np 9 | from utils.test_utils import create_gripper_marker 10 | 11 | 12 | id = "007d90628c0eb878feeb4c44c1a7b717cec5e4cf10c160689198bd908a1e9d04" 13 | pc = np.load(f"/cm/shared/toannt28/grasp-anything/pc/{id}.npy") 14 | pc = trimesh.points.PointCloud(vertices=pc[:, :3], colors=pc[:, 3:]) 15 | 16 | with open(f"/cm/shared/toannt28/grasp-anything/grasp/{id}_0", "rb") as file: 17 | grasp = pickle.load(file) 18 | 19 | grasp = [create_gripper_marker(w/0.14).apply_transform(Rt) for (Rt, w) in zip(grasp[0], grasp[1])] 20 | scene = trimesh.scene.Scene([pc, grasp]) 21 | scene.show(line_settings={'point_size':20}) 22 | --------------------------------------------------------------------------------