├── .gitignore ├── dataset ├── __init__.py └── ThreeDAPDataset.py ├── .gitmodules ├── assets ├── intro.png ├── method.png └── visualization.png ├── requirements.txt ├── models ├── __init__.py ├── weights_init.py ├── main_nets.py ├── components.py └── pointnet_util.py ├── utils ├── __init__.py ├── utils.py ├── visualization.py ├── trainer.py ├── eval.py └── builder.py ├── test.py ├── LICENSE ├── config └── detectiondiffusion.py ├── train.py ├── detect.py ├── visualize.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | log/ -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .ThreeDAPDataset import ThreeDAPDataset 2 | 3 | 4 | __all__ = ['ThreeDAPDataset'] -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorchse3"] 2 | path = pytorchse3 3 | url = https://github.com/eigenvivek/pytorchse3 4 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/HEAD/assets/intro.png -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/HEAD/assets/method.png -------------------------------------------------------------------------------- /assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/HEAD/assets/visualization.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | h5py 4 | scikit_learn==1.3.0 5 | gorilla-core==0.2.7.8 6 | torch==2.0.1 7 | scipy==1.11.1 8 | trimesh==4.0.7 9 | open_clip_torch -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .main_nets import DetectionDiffusion 2 | from .weights_init import weights_init 3 | 4 | 5 | __all__ = ['DetectionDiffusion', 'weights_init'] -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_optimizer, build_dataset, build_loader, build_model 2 | from .trainer import Trainer 3 | from .utils import set_random_seed, IOStream, PN2_BNMomentum, PN2_Scheduler 4 | 5 | __all__ = ['build_optimizer', 'build_dataset', 'build_loader', 'build_model', 6 | 'Trainer', 'set_random_seed', 'IOStream', 'PN2_BNMomentum', 'PN2_Scheduler'] 7 | -------------------------------------------------------------------------------- /models/weights_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def 
weights_init(m): 4 | """_summary_ 5 | Weights initialization 6 | """ 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv2d') != -1: 9 | torch.nn.init.xavier_normal_(m.weight.data) 10 | if m.state_dict().get('bias') != None: 11 | torch.nn.init.constant_(m.bias.data, 0.0) 12 | elif classname.find('Linear') != -1: 13 | torch.nn.init.xavier_normal_(m.weight.data) 14 | if m.state_dict().get('bias') != None: 15 | torch.nn.init.constant_(m.bias.data, 0.0) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.eval import affordance_eval, pose_eval 2 | import argparse 3 | import pickle 4 | 5 | 6 | AFFORDANCE_LIST = ['grasp to pour', 'grasp to stab', 'stab', 'pourable', 'lift', 'wrap_grasp', 'listen', 'contain', 'displaY', 'grasp to cut', 'cut', 'wear', 'openable', 'grasp'] 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Test a model") 11 | parser.add_argument("--result", help="result file") 12 | args = parser.parse_args() 13 | return args 14 | 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | with open(args.result, 'rb') as f: 19 | result = pickle.load(f) 20 | mIoU, Acc, mAcc = affordance_eval(AFFORDANCE_LIST, result) 21 | print(f'mIoU: {mIoU}, Acc: {Acc}, mAcc: {mAcc}') 22 | 23 | mESM, mCR = pose_eval(result) 24 | print(f'mESM: {mESM}, mCR: {mCR}') 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Toan Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /config/detectiondiffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from os.path import join as opj 4 | from utils import PN2_BNMomentum 5 | 6 | exp_name = 'detectiondiffusion' 7 | seed = 1 8 | log_dir = opj("./log/", exp_name) 9 | try: 10 | os.makedirs(log_dir) 11 | except: 12 | print('Logging Dir is already existed!') 13 | 14 | # scheduler = dict( 15 | # type='lr_lambda', 16 | # lr_lambda=PN2_Scheduler(init_lr=0.001, step=20, 17 | # decay_rate=0.5, min_lr=1e-5) 18 | # ) 19 | 20 | scheduler = None 21 | 22 | optimizer = dict( 23 | type='adam', 24 | lr=1e-3, 25 | betas=(0.9, 0.999), 26 | eps=1e-08, 27 | weight_decay=1e-5, 28 | ) 29 | 30 | model = dict( 31 | type='detectiondiffusion', 32 | device=torch.device('cuda'), 33 | background_text='none', 34 | betas=[1e-4, 0.02], 35 | n_T=1000, 36 | drop_prob=0.1, 37 | weights_init='default_init', 38 | ) 39 | 40 | training_cfg = dict( 41 | model=model, 42 | batch_size=32, 43 | epoch=200, 44 | gpu='0', 45 | workflow=dict( 46 | train=1, 47 | ), 48 | bn_momentum=PN2_BNMomentum(origin_m=0.1, m_decay=0.5, step=20), 49 | ) 50 | 51 | data = dict( 52 | data_path="../full_shape_release.pkl", 53 | ) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as opj 3 | from gorilla.config import Config 4 | from utils import * 5 | import argparse 6 | import torch 7 | 8 | 9 | # Argument Parser 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Train a model") 12 | parser.add_argument("--config", help="train config file path") 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | cfg = Config.fromfile(args.config) 20 | 21 | logger = IOStream(opj(cfg.log_dir, 'run.log')) 22 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu 23 | num_gpu = len(cfg.training_cfg.gpu.split(',')) # number of GPUs to use 24 | logger.cprint('Use %d GPUs: %s' % (num_gpu, cfg.training_cfg.gpu)) 25 | if cfg.get('seed') != None: # set random seed 26 | set_random_seed(cfg.seed) 27 | logger.cprint('Set seed to %d' % cfg.seed) 28 | model = build_model(cfg).cuda() # build the model from configuration 29 | 30 | print("Training from scratch!") 31 | 32 | dataset_dict = build_dataset(cfg) # build the dataset 33 | loader_dict = build_loader(cfg, dataset_dict) # build the loader 34 | optim_dict = build_optimizer(cfg, model) # build the optimizer 35 | 36 | # construct the training process 37 | training = dict( 38 | model=model, 39 | dataset_dict=dataset_dict, 40 | loader_dict=loader_dict, 41 | optim_dict=optim_dict, 42 | logger=logger 43 | ) 44 | 45 | task_trainer = Trainer(cfg, training) 46 | task_trainer.run() 47 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | 6 | 7 | class IOStream(): 8 | def __init__(self, path): 9 | self.f = open(path, 'a') 10 | 11 | def cprint(self, text): 12 | print(text) 13 | self.f.write(text+'\n') 14 | self.f.flush() 15 | 16 | def close(self): 17 | self.f.close() 18 | 19 | 20 | class PN2_Scheduler(object): 21 | def __init__(self, init_lr, 
step, decay_rate, min_lr): 22 | super().__init__() 23 | self.init_lr = init_lr 24 | self.step = step 25 | self.decay_rate = decay_rate 26 | self.min_lr = min_lr 27 | return 28 | 29 | def __call__(self, epoch): 30 | factor = self.decay_rate**(epoch//self.step) 31 | if self.init_lr*factor < self.min_lr: 32 | factor = self.min_lr / self.init_lr 33 | return factor 34 | 35 | 36 | class PN2_BNMomentum(object): 37 | def __init__(self, origin_m, m_decay, step): 38 | super().__init__() 39 | self.origin_m = origin_m 40 | self.m_decay = m_decay 41 | self.step = step 42 | return 43 | 44 | def __call__(self, m, epoch): 45 | momentum = self.origin_m * (self.m_decay**(epoch//self.step)) 46 | if momentum < 0.01: 47 | momentum = 0.01 48 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 49 | m.momentum = momentum 50 | return 51 | 52 | 53 | def set_random_seed(seed): 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | 3 | 4 | def create_gripper_marker(color=[0, 255, 0], tube_radius=0.002, sections=6): 5 | """Create a 3D mesh visualizing a parallel yaw gripper. It consists of four cylinders. 6 | 7 | Args: 8 | color (list, optional): RGB values of marker. Defaults to [0, 0, 255]. 9 | tube_radius (float, optional): Radius of cylinders. Defaults to 0.001. 10 | sections (int, optional): Number of sections of each cylinder. Defaults to 6. 11 | 12 | Returns: 13 | trimesh.Trimesh: A mesh that represents a simple parallel yaw gripper. 14 | """ 15 | cfl = trimesh.creation.cylinder( 16 | radius=tube_radius, 17 | sections=sections, 18 | segment=[ 19 | [4.10000000e-02, -7.27595772e-12, 6.59999996e-02], 20 | [4.10000000e-02, -7.27595772e-12, 1.12169998e-01], 21 | ], 22 | ) 23 | cfr = trimesh.creation.cylinder( 24 | radius=tube_radius, 25 | sections=sections, 26 | segment=[ 27 | [-4.100000e-02, -7.27595772e-12, 6.59999996e-02], 28 | [-4.100000e-02, -7.27595772e-12, 1.12169998e-01], 29 | ], 30 | ) 31 | cb1 = trimesh.creation.cylinder( 32 | radius=tube_radius, sections=sections, segment=[[0, 0, 0], [0, 0, 6.59999996e-02]] 33 | ) 34 | cb2 = trimesh.creation.cylinder( 35 | radius=tube_radius, 36 | sections=sections, 37 | segment=[[-4.100000e-02, 0, 6.59999996e-02], [4.100000e-02, 0, 6.59999996e-02]], 38 | ) 39 | 40 | tmp = trimesh.util.concatenate([cb1, cb2, cfr, cfl]) 41 | tmp.visual.face_colors = color 42 | 43 | return tmp -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from gorilla.config import Config 4 | from utils import * 5 | import argparse 6 | import pickle 7 | from tqdm import tqdm 8 | import random 9 | 10 | 11 | GUIDE_W = 0.2 12 | DEVICE = torch.device('cuda') 13 | 14 | 15 | # Argument Parser 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="Detect affordance and poses") 18 | parser.add_argument("--config", help="test config file path") 19 | parser.add_argument("--checkpoint", help="path to checkpoint model") 20 | parser.add_argument("--test_data", help="path to test_data") 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | if __name__ == "__main__": 26 | args = parse_args() 27 | cfg = Config.fromfile(args.config) 28 | 
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu 29 | model = build_model(cfg).to(DEVICE) 30 | 31 | if args.checkpoint != None: 32 | print("Loading checkpoint....") 33 | _, exten = os.path.splitext(args.checkpoint) 34 | if exten == '.t7': 35 | model.load_state_dict(torch.load(args.checkpoint)) 36 | elif exten == '.pth': 37 | check = torch.load(args.checkpoint) 38 | model.load_state_dict(check['model_state_dict']) 39 | else: 40 | raise ValueError("Must specify a checkpoint path!") 41 | 42 | if cfg.get('seed') != None: 43 | set_random_seed(cfg.seed) 44 | 45 | with open(args.test_data, 'rb') as f: 46 | shape_data = pickle.load(f) 47 | random.shuffle(shape_data) 48 | shape_data = shape_data[int(0.8 * len(shape_data)):] 49 | 50 | print("Detecting") 51 | model.eval() 52 | with torch.no_grad(): 53 | for shape in tqdm(shape_data): 54 | xyz = torch.from_numpy(shape['full_shape']['coordinate']).unsqueeze(0).float().cuda() 55 | shape['result'] = {text: [*(model.detect_and_sample(xyz, text, 2000, guide_w=GUIDE_W))] for text in shape['affordance']} 56 | 57 | with open(f'{cfg.log_dir}/result.pkl', 'wb') as f: 58 | pickle.dump(shape_data, f) -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | import pickle 4 | from scipy.spatial.transform import Rotation as R 5 | import argparse 6 | from utils.visualization import create_gripper_marker 7 | 8 | color_code_1 = np.array([0, 0, 255]) # color code for affordance region 9 | color_code_2 = np.array([0, 255, 0]) # color code for gripper pose 10 | num_pose = 100 # number of poses to visualize per each object-affordance pair 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Visualize") 15 | parser.add_argument("--result", help="result file") 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | if __name__ == "__main__": 21 | args = parse_args() 22 | result_file = args.result_file 23 | with open(result_file, 'rb') as f: 24 | result = pickle.load(f) 25 | 26 | for i in range(len(result)): 27 | if result[i]['semantic class'] == 'Bottle': 28 | shape_index = i 29 | shape = result[shape_index] 30 | 31 | for affordance in shape['affordance']: 32 | colors = np.transpose(shape['result'][affordance][0]) * color_code_1 33 | point_cloud = trimesh.points.PointCloud(shape['full_shape']['coordinate'], colors=colors) 34 | print(f"Affordance: {affordance}") 35 | T = shape['result'][affordance][1][:num_pose] 36 | rotation = np.concatenate((R.from_quat(T[:, :4]).as_matrix(), np.zeros((num_pose, 1, 3), dtype=np.float32)), axis=1) 37 | translation = np.expand_dims(np.concatenate((T[:, 4:], np.ones((num_pose, 1), dtype=np.float32)), axis=1), axis=2) 38 | T = np.concatenate((rotation, translation), axis=2) 39 | poses = [create_gripper_marker(color=color_code_2).apply_transform(t) for t in T 40 | if np.min(np.linalg.norm(point_cloud - (t @ np.array([0., 0., 6.59999996e-02, 1.]))[:3], axis=1)) <= 0.03] # this line is used to get reliable poses only 41 | scene = trimesh.Scene([point_cloud, poses]) 42 | scene.show(line_settings={'point size': 10}) -------------------------------------------------------------------------------- /dataset/ThreeDAPDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.utils.data import Dataset 3 | import pickle as pkl 4 | from scipy.spatial.transform import Rotation as R 
5 | 6 | 7 | class ThreeDAPDataset(Dataset): 8 | """_summary_ 9 | This class is for the data loading. 10 | """ 11 | def __init__(self, data_path, mode): 12 | """_summary_ 13 | 14 | Args: 15 | data_path (str): path to the dataset 16 | """ 17 | super().__init__() 18 | self.data_path = data_path 19 | self.mode = mode 20 | if self.mode in ["train", "val", "test"]: 21 | self._load_data() 22 | else: 23 | raise ValueError("Mode must be train, val, or test!") 24 | 25 | def _load_data(self): 26 | self.all_data = [] 27 | 28 | with open(self.data_path, "rb") as f: 29 | data = pkl.load(f) 30 | random.shuffle(data) 31 | 32 | if self.mode == "train": data = data[:int(0.7 * len(data))] 33 | elif self.mode == "val": data = data[int(0.7 * len(data)):int(0.8 * len(data))] 34 | else: data = data[int(0.8 * len(data)):] 35 | 36 | for data_point in data: 37 | for affordance in data_point["affordance"]: 38 | for pose in data_point["pose"][affordance]: 39 | new_data_dict = { 40 | "shape_id": data_point["shape_id"], 41 | "semantic class": data_point["semantic class"], 42 | "point cloud": data_point["full_shape"]["coordinate"], 43 | "affordance": affordance, 44 | "affordance label": data_point["full_shape"]["label"][affordance], 45 | "rotation": R.from_matrix(pose[:3, :3]).as_quat(), 46 | "translation": pose[:3, 3] 47 | } 48 | self.all_data.append(new_data_dict) 49 | 50 | def __getitem__(self, index): 51 | """_summary_ 52 | 53 | Args: 54 | index (int): the element index 55 | 56 | Returns: 57 | shape id, semantic class, coordinate, affordance text, affordance label, rotation and translation 58 | """ 59 | data_dict = self.all_data[index] 60 | return data_dict['shape_id'], data_dict['semantic class'], data_dict['point cloud'], data_dict['affordance'], \ 61 | data_dict['affordance label'], data_dict['rotation'], data_dict['translation'] 62 | 63 | def __len__(self): 64 | return len(self.all_data) 65 | 66 | 67 | if __name__ == "__main__": 68 | random.seed(1) 69 | dataset = ThreeDAPDataset(data_path="../full_shape_release.pkl", mode="train") 70 | print(len(dataset)) -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from os.path import join as opj 4 | from utils import * 5 | 6 | 7 | DEVICE = torch.device('cuda') 8 | 9 | 10 | class Trainer(object): 11 | def __init__(self, cfg, running): 12 | super().__init__() 13 | self.cfg = cfg 14 | self.logger = running['logger'] 15 | self.model = running["model"] 16 | self.dataset_dict = running["dataset_dict"] 17 | self.loader_dict = running["loader_dict"] 18 | self.train_loader = self.loader_dict.get("train_loader", None) 19 | self.optimizer_dict = running["optim_dict"] 20 | self.optimizer = self.optimizer_dict.get("optimizer", None) 21 | self.scheduler = self.optimizer_dict.get("scheduler", None) 22 | self.epoch = 0 23 | self.bn_momentum = self.cfg.training_cfg.get('bn_momentum', None) 24 | 25 | def train(self): 26 | self.model.train() 27 | self.logger.cprint("Epoch(%d) begin training........" 
% self.epoch) 28 | pbar = tqdm(self.train_loader) 29 | for _, _, xyz, text, affordance_label, rotation, translation in pbar: 30 | self.optimizer.zero_grad() 31 | xyz = xyz.float() 32 | rotation = rotation.float() 33 | translation = translation.float() 34 | affordance_label = affordance_label.squeeze().long() 35 | 36 | g = torch.cat((rotation, translation), dim=1) 37 | xyz = xyz.to(DEVICE) 38 | affordance_label = affordance_label.to(DEVICE) 39 | g = g.to(DEVICE) 40 | 41 | affordance_loss, pose_loss = self.model(xyz, text, affordance_label, g) 42 | loss = affordance_loss + pose_loss 43 | loss.backward() 44 | 45 | affordance_l = affordance_loss.item() 46 | pose_l = pose_loss.item() 47 | pbar.set_description(f'Affordance loss: {affordance_l:.5f}, Pose loss: {pose_l:.5f}') 48 | self.optimizer.step() 49 | 50 | if self.scheduler != None: 51 | self.scheduler.step() 52 | if self.bn_momentum != None: 53 | self.model.apply(lambda x: self.bn_momentum(x, self.epoch)) 54 | 55 | outstr = f"\nEpoch {self.epoch}, Last Affordance loss: {affordance_l:.5f}, Last Pose loss: {pose_l:.5f}" 56 | self.logger.cprint(outstr) 57 | print('Saving checkpoint') 58 | torch.save(self.model.state_dict(), opj(self.cfg.log_dir, 'current_model.t7')) 59 | self.epoch += 1 60 | 61 | def val(self): 62 | raise NotImplementedError 63 | 64 | def test(self): 65 | raise NotImplementedError 66 | 67 | def run(self): 68 | EPOCH = self.cfg.training_cfg.epoch 69 | workflow = self.cfg.training_cfg.workflow 70 | 71 | while self.epoch < EPOCH: 72 | for key, running_epoch in workflow.items(): 73 | epoch_runner = getattr(self, key) 74 | for _ in range(running_epoch): 75 | epoch_runner() 76 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | from scipy.spatial.transform import Rotation as R 4 | 5 | 6 | def affordance_eval(affordance_list, result): 7 | """_summary_ 8 | This fuction evaluates the affordance detection capability. 9 | `result` is loaded from result.pkl file produced by detect.py. 10 | """ 11 | num_correct = 0 12 | num_all = 0 13 | num_points = {aff: 0 for aff in affordance_list} 14 | num_label_points = {aff: 0 for aff in affordance_list} 15 | num_correct_fg_points = {aff: 0 for aff in affordance_list} 16 | num_correct_bg_points = {aff: 0 for aff in affordance_list} 17 | num_union_points = {aff: 0 for aff in affordance_list} 18 | num_appearances = {aff: 0 for aff in affordance_list} 19 | 20 | for shape in result: 21 | for affordance in shape['affordance']: 22 | label = np.transpose(shape['full_shape']['label'][affordance]) 23 | prediction = shape['result'][affordance][0] 24 | 25 | num_correct += np.sum(label == prediction) 26 | num_all += 2048 27 | num_points[affordance] += 2048 28 | num_label_points[affordance] += np.sum(label == 1.) 29 | num_correct_fg_points[affordance] += np.sum((label == 1.) & (prediction == 1.)) 30 | num_correct_bg_points[affordance] += np.sum((label == 0.) & (prediction == 0.)) 31 | num_union_points[affordance] += np.sum((label == 1.) 
| (prediction == 1.)) 32 | mIoU = np.average(np.array(list(num_correct_fg_points.values())) / np.array(list(num_union_points.values())), 33 | weights=np.array(list(num_appearances.values()))) 34 | Acc = num_correct / num_all 35 | mAcc = np.mean((np.array(list(num_correct_fg_points.values())) + np.array(list(num_correct_bg_points.values()))) / \ 36 | np.array(list(num_points.values()))) 37 | 38 | return mIoU, Acc, mAcc 39 | 40 | 41 | def pose_eval(result): 42 | """_summary_ 43 | This function evaluates the pose detection capability. 44 | `result` is loaded from result.pkl file produced by detect.py. 45 | """ 46 | all_min_dist = [] 47 | all_rate = [] 48 | for object in result: 49 | for affordance in object['affordance']: 50 | gt_poses = np.array([np.concatenate((R.from_matrix(p[:3, :3]).as_quat(), p[:3, 3]), axis=0) for p in object['pose'][affordance]]) 51 | distances = cdist(gt_poses, object['result'][affordance][1]) 52 | rate = np.sum(np.any(distances <= 0.2, axis=1)) / len(object['pose'][affordance]) 53 | all_rate.append(rate) 54 | 55 | g = gt_poses[:, np.newaxis, :] 56 | g_pred = object['result'][affordance][1] 57 | l2_distances = np.sqrt(np.sum((g-g_pred)**2, axis=2)) 58 | min_distance = np.min(l2_distances) 59 | 60 | # discard cases when set of gt poses and set of detected poses too far from each other, to get a stable result 61 | if min_distance <= 1.0: 62 | all_min_dist.append(min_distance) 63 | return (np.mean(np.array(all_min_dist)), np.mean(np.array(all_rate))) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
11 |
12 | We address the task of language-driven affordance-pose detection in 3D point clouds. Our method simultaneously detects open-vocabulary affordances and generates affordance-specific 6-DoF poses.
13 |
14 | 
15 |
16 | We present 3DAPNet, a new method for affordance-pose joint learning. Given the captured 3D point cloud of an object and a set of affordance labels conveyed through natural language texts, our objective is to jointly produce both the relevant affordance regions and the appropriate pose configurations that facilitate the affordances.
17 |
18 |
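As a quick sanity check on the output format, the hedged sketch below (not part of the original README) loads the `result.pkl` file written by `detect.py` and prints the predicted per-point affordance map and the sampled 6-DoF poses for one object. Field names follow `detect.py`, `visualize.py` and `utils/eval.py`; the path assumes the default `detectiondiffusion` config.

```
import pickle
import numpy as np

# Path produced by detect.py with the default config (log_dir = ./log/detectiondiffusion)
with open('log/detectiondiffusion/result.pkl', 'rb') as f:
    result = pickle.load(f)

shape = result[0]
print(shape['semantic class'], shape['affordance'])
for affordance in shape['affordance']:
    affordance_map, poses = shape['result'][affordance]   # per-point prediction, [n_sample, 7] poses
    points = shape['full_shape']['coordinate']            # [2048, 3] point cloud
    quat, trans = poses[:, :4], poses[:, 4:]              # quaternion + translation per pose
    print(affordance, int(np.asarray(affordance_map).sum()), quat.shape, trans.shape)
```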
59 |
60 | ## 6. Citation
61 |
62 | If you find our work useful for your research, please cite:
63 | ```
64 | @inproceedings{Nguyen2024language,
65 | title={Language-Conditioned Affordance-Pose Detection in 3D Point Clouds},
66 | author={Nguyen, Toan and Vu, Minh Nhat and Huang, Baoru and Van Vo, Tuan and Truong, Vy and Le, Ngan and Vo, Thieu and Le, Bac and Nguyen, Anh},
67 | booktitle = {ICRA},
68 | year = {2024}
69 | }
70 | ```
71 | Thank you very much.
72 |
73 | ## 7. Acknowledgement
74 |
75 | Our source code is built on [3D AffordanceNet](https://github.com/Gorilla-Lab-SCUT/AffordanceNet). We sincerely thank its authors for their work.
--------------------------------------------------------------------------------
/utils/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LambdaLR, MultiStepLR
3 | from dataset import *
4 | from models import *
5 | from torch.utils.data import DataLoader
6 | from torch.optim import SGD, Adam
7 |
8 | # Pools of models, optimizers, weights initialization methods, schedulers
9 | model_pool = {
10 | 'detectiondiffusion': DetectionDiffusion,
11 | }
12 |
13 | optimizer_pool = {
14 | 'sgd': SGD,
15 | 'adam': Adam
16 | }
17 |
18 | init_pool = {
19 | 'default_init': weights_init
20 | }
21 |
22 | scheduler_pool = {
23 | 'step': StepLR,
24 | 'cos': CosineAnnealingLR,
25 | 'lr_lambda': LambdaLR,
26 | 'multi_step': MultiStepLR
27 | }
28 |
29 |
30 | def build_model(cfg):
31 | """_summary_
32 | Function to build the model before training
33 | """
34 | if hasattr(cfg, 'model'):
35 | model_info = cfg.model
36 | weights_init = model_info.get('weights_init', None)
37 | background_text = model_info.get('background_text', 'none')
38 | device = model_info.get('device', torch.device('cuda'))
39 | model_name = model_info.type
40 | model_cls = model_pool[model_name]
41 | if model_name in ['detectiondiffusion']:
42 | betas = model_info.get('betas', [1e-4, 0.02])
43 | n_T = model_info.get('n_T', 1000)
44 | drop_prob = model_info.get('drop_prob', 0.1)
45 | model = model_cls(betas, n_T, device, background_text, drop_prob)
46 | else:
47 | raise ValueError("The model name does not exist!")
48 | if weights_init != None:
49 | init_fn = init_pool[weights_init]
50 | model.apply(init_fn)
51 | return model
52 | else:
53 | raise ValueError("Configuration does not have model config!")
54 |
55 |
56 | def build_dataset(cfg):
57 | """_summary_
58 | Function to build the dataset
59 | """
60 | if hasattr(cfg, 'data'):
61 | data_info = cfg.data
62 | data_path = data_info.data_path
63 | train_set = ThreeDAPDataset(data_path, mode='train')
64 | val_set = ThreeDAPDataset(data_path, mode='val')
65 | test_set = ThreeDAPDataset(data_path, mode='test')
66 | dataset_dict = dict(
67 | train_set=train_set,
68 | val_set=val_set,
69 | test_set=test_set
70 | )
71 | return dataset_dict
72 | else:
73 | raise ValueError("Configuration does not have data config!")
74 |
75 |
76 | def build_loader(cfg, dataset_dict):
77 | """_summary_
78 | Function to build the loader
79 | """
80 | train_set = dataset_dict["train_set"]
81 | train_loader = DataLoader(train_set, batch_size=cfg.training_cfg.batch_size,
82 | shuffle=True, drop_last=False, num_workers=8)
83 | loader_dict = dict(
84 | train_loader=train_loader,
85 | )
86 |
87 | return loader_dict
88 |
89 |
90 | def build_optimizer(cfg, model):
91 | """_summary_
92 | Function to build the optimizer
93 | """
94 | optimizer_info = cfg.optimizer
95 | optimizer_type = optimizer_info.type
96 | optimizer_info.pop('type')
97 | optimizer_cls = optimizer_pool[optimizer_type]
98 | optimizer = optimizer_cls(model.parameters(), **optimizer_info)
99 | scheduler_info = cfg.scheduler
100 | if scheduler_info:
101 | scheduler_name = scheduler_info.type
102 | scheduler_info.pop('type')
103 | scheduler_cls = scheduler_pool[scheduler_name]
104 | scheduler = scheduler_cls(optimizer, **scheduler_info)
105 | else:
106 | scheduler = None
107 | optim_dict = dict(
108 | scheduler=scheduler,
109 | optimizer=optimizer
110 | )
111 | return optim_dict
112 |
--------------------------------------------------------------------------------
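As a side note on how the pools above connect to the config: the commented-out `lr_lambda` block in `config/detectiondiffusion.py` maps to torch's `LambdaLR` via `scheduler_pool`, with `PN2_Scheduler` supplying the per-epoch decay factor. The hedged sketch below (not part of the repository) shows the objects `build_optimizer` would assemble in that case; note that importing the `utils` package also pulls in the model code, which instantiates a CLIP text encoder on a CUDA device at import time.

```
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from utils import PN2_Scheduler  # heavy import: also loads models/ and hence open_clip on CUDA

model = torch.nn.Linear(7, 7)  # stand-in for the real model
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
scheduler = LambdaLR(optimizer, lr_lambda=PN2_Scheduler(init_lr=0.001, step=20,
                                                        decay_rate=0.5, min_lr=1e-5))
for epoch in range(3):
    # ...one training epoch (optimizer.step() per batch) would run here...
    scheduler.step()  # the Trainer steps the scheduler once per epoch
```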
/models/main_nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .components import TextEncoder, PointNetPlusPlus, PoseNet
6 |
7 |
8 | text_encoder = TextEncoder(device=torch.device('cuda'))
9 |
10 |
11 | # Linear noise scheduler
12 | def linear_diffusion_schedule(betas, T):
13 | """_summary_
14 | Linear noise scheduling used for training and sampling.
15 | """
16 | beta_t = (betas[1] - betas[0]) * torch.arange(0, T + 1, dtype=torch.float32) / T + betas[0]
17 | sqrt_beta_t = torch.sqrt(beta_t)
18 | alpha_t = 1 - beta_t
19 | log_alpha_t = torch.log(alpha_t)
20 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
21 |
22 | sqrtab = torch.sqrt(alphabar_t)
23 | oneover_sqrta = 1 / torch.sqrt(alpha_t)
24 |
25 | sqrtmab = torch.sqrt(1 - alphabar_t)
26 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
27 |
28 | return {
29 | "alpha_t": alpha_t, # \alpha_t
30 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
31 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
32 | "alphabar_t": alphabar_t, # \bar{\alpha_t}
33 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}}
34 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}}
35 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
36 | }
37 |
38 |
39 | # Main network for affordance detection and pose generation
40 | class DetectionDiffusion(nn.Module):
41 | def __init__(self, betas, n_T, device, background_text, drop_prob=0.1):
42 | """_summary_
43 |
44 | Args:
45 | drop_prob: probability to drop the conditions
46 | """
47 | super(DetectionDiffusion, self).__init__()
48 | self.posenet = PoseNet()
49 | self.pointnetplusplus = PointNetPlusPlus()
50 |
51 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
52 |
53 | # Register_buffer allows accessing dictionary, e.g. can access self.sqrtab later
54 | for k, v in linear_diffusion_schedule(betas, n_T).items():
55 | self.register_buffer(k, v)
56 |
57 | self.n_T = n_T
58 | self.device = device
59 | self.background_text = background_text
60 | self.drop_prob = drop_prob
61 | self.loss_mse = nn.MSELoss()
62 |
63 | def forward(self, xyz, text, affordance_label, g):
64 | """_summary_
65 | This method is used in training, so samples _ts and noise randomly.
66 | """
67 | B = xyz.shape[0] # xyz's size [B, 2048, 3]
68 | point_features, c = self.pointnetplusplus(xyz) # point_features' size [B, 512, 2048], c's size [B, 1024]
69 | with torch.no_grad():
70 | foreground_text_features = text_encoder(text) # size [B, 512]
71 | background_text_features = text_encoder([self.background_text] * B)
72 | text_features = torch.cat((background_text_features.unsqueeze(1), \
73 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512]
74 |
75 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \
76 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \
77 | torch.norm(point_features, dim=1, keepdim=True))) # size [B, 2, 2048]
78 |
79 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1)
80 | affordance_loss = F.nll_loss(affordance_prediction, affordance_label)
81 |
82 | _ts = torch.randint(1, self.n_T + 1, (B,)).to(self.device)
83 | noise = torch.randn_like(g) # eps ~ N(0, 1), g size [B, 7]
84 | g_t = (
85 | self.sqrtab[_ts - 1, None] * g
86 | + self.sqrtmab[_ts - 1, None] * noise
87 | ) # This is the g_t, which is sqrt(alphabar) g_0 + sqrt(1-alphabar) * eps
88 |
89 | # dropout context with some probability
90 | context_mask = torch.bernoulli(torch.zeros(B, 1) + 1 - self.drop_prob).to(self.device)
91 |
92 | # Pose loss is the MSE between the added noise and our predicted noise
93 | pose_loss = self.loss_mse(noise, self.posenet(g_t, c, foreground_text_features, context_mask, _ts / self.n_T))
94 | return affordance_loss, pose_loss
95 |
96 | def detect_and_sample(self, xyz, text, n_sample, guide_w):
97 | """_summary_
98 | Detect affordance for one point cloud and sample [n_sample] poses that support the 'text' affordance task,
99 | following the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'.
100 | """
101 | g_i = torch.randn(n_sample, (7)).to(self.device) # start by sampling from Gaussian noise
102 | point_features, c = self.pointnetplusplus(xyz) # point_features size [1, 512, 2048], c size [1, 1024]
103 | foreground_text_features = text_encoder(text) # size [1, 512]
104 | background_text_features = text_encoder([self.background_text] * 1)
105 | text_features = torch.cat((background_text_features.unsqueeze(1), \
106 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512]
107 |
108 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \
109 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \
110 | torch.norm(point_features, dim=1, keepdim=True))) # size [1, 2, 2048]
111 |
112 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1) # .cpu().numpy()
113 | c_i = c.repeat(n_sample, 1)
114 | t_i = foreground_text_features.repeat(n_sample, 1)
115 | context_mask = torch.ones((n_sample, 1)).float().to(self.device)
116 |
117 | # Double the batch
118 | c_i = c_i.repeat(2, 1)
119 | t_i = t_i.repeat(2, 1)
120 | context_mask = context_mask.repeat(2, 1)
121 | context_mask[n_sample:] = 0. # make second half of the batch context-free
122 |
123 | for i in range(self.n_T, 0, -1):
124 | _t_is = torch.tensor([i / self.n_T]).repeat(n_sample).repeat(2).to(self.device)
125 | g_i = g_i.repeat(2, 1)
126 |
127 | z = torch.randn(n_sample, (7)) if i > 1 else torch.zeros((n_sample, 7))
128 | z = z.to(self.device)
129 | eps = self.posenet(g_i, c_i, t_i, context_mask, _t_is)
130 | eps1 = eps[:n_sample]
131 | eps2 = eps[n_sample:]
132 | eps = (1 + guide_w) * eps1 - guide_w * eps2
133 |
134 | g_i = g_i[:n_sample]
135 | g_i = self.oneover_sqrta[i] * (g_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
136 | return np.argmax(affordance_prediction.cpu().numpy(), axis=1), g_i.cpu().numpy()
--------------------------------------------------------------------------------
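As a quick check of what `linear_diffusion_schedule` returns, the hedged sketch below (not part of the repository) verifies that `sqrtab**2 + sqrtmab**2 == 1` at every step and reproduces the forward noising of a pose vector used in `DetectionDiffusion.forward`. It assumes the repository root is on `PYTHONPATH`; importing `models.main_nets` also builds the module-level CLIP text encoder on a CUDA device.

```
import torch
from models.main_nets import linear_diffusion_schedule  # module import instantiates TextEncoder on CUDA

sched = linear_diffusion_schedule(betas=[1e-4, 0.02], T=1000)
print({k: v.shape for k, v in sched.items()})   # each entry holds T + 1 = 1001 values

# alphabar_t is the cumulative product of alpha_t, so these two squared terms sum to 1
assert torch.allclose(sched["sqrtab"] ** 2 + sched["sqrtmab"] ** 2,
                      torch.ones_like(sched["sqrtab"]))

# Forward noising q(g_t | g_0) for a batch of pose vectors (quaternion + translation, [B, 7]):
# g_t = sqrt(alphabar_t) * g_0 + sqrt(1 - alphabar_t) * eps,  eps ~ N(0, I)
g0 = torch.randn(4, 7)
t = torch.randint(1, 1001, (4,))
eps = torch.randn_like(g0)
g_t = sched["sqrtab"][t - 1, None] * g0 + sched["sqrtmab"][t - 1, None] * eps
print(g_t.shape)  # torch.Size([4, 7])
```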
/models/components.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import open_clip
4 | import math
5 | import torch.nn.functional as F
6 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation
7 |
8 |
9 | class SinusoidalPositionEmbeddings(nn.Module):
10 | """
11 | Sinusoidal embedding for time step.
12 | """
13 | def __init__(self, dim, scale=1.0):
14 | super().__init__()
15 | self.dim = dim
16 | self.scale = scale
17 |
18 | def forward(self, time):
19 | time = time * self.scale
20 | device = time.device
21 | half_dim = self.dim // 2
22 | embeddings = math.log(10000) / (half_dim - 1 + 1e-5)
23 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
24 | embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0)
25 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
26 | return embeddings
27 |
28 | def __len__(self):
29 | return self.dim
30 |
31 |
32 | class TimeNet(nn.Module):
33 | """
34 | Time Embeddings
35 | """
36 | def __init__(self, dim):
37 | super().__init__()
38 | self.net = nn.Sequential(
39 | nn.Linear(1, dim),
40 | nn.GELU(),
41 | nn.Linear(dim, dim)
42 | )
43 | def forward(self, t):
44 | return self.net(t)
45 |
46 |
47 | class TextEncoder(nn.Module):
48 | """
49 | Text Encoder to encode the text prompt.
50 | """
51 | def __init__(self, device):
52 | super(TextEncoder, self).__init__()
53 | self.device = device
54 | self.clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k",
55 | device=self.device)
56 |
57 | def forward(self, texts):
58 | """
59 | texts can be a single string or a list of strings.
60 | """
61 | tokenizer = open_clip.get_tokenizer("ViT-B-32")
62 | tokens = tokenizer(texts).to(self.device)
63 | text_features = self.clip_model.encode_text(tokens).to(self.device)
64 | return text_features
65 |
66 |
67 | class PointNetPlusPlus(nn.Module):
68 | """_summary_
69 | PointNet++ class.
70 | """
71 | def __init__(self):
72 | super(PointNetPlusPlus, self).__init__()
73 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [
74 | 32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
75 | self.sa2 = PointNetSetAbstractionMsg(
76 | 128, [0.4, 0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
77 | self.sa3 = PointNetSetAbstraction(
78 | npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
79 |
80 | self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
81 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
82 | self.fp1 = PointNetFeaturePropagation(in_channel=134, mlp=[128, 128])
83 |
84 | self.conv1 = nn.Conv1d(128, 512, 1)
85 | self.bn1 = nn.BatchNorm1d(512)
86 |
87 | def forward(self, xyz):
88 | """_summary_
89 | Return point-wise features and point cloud representation.
90 | """
91 | # Set Abstraction layers
92 | xyz = xyz.contiguous().transpose(1, 2)
93 | l0_xyz = xyz
94 | l0_points = xyz
95 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
96 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
97 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
98 | c = l3_points.squeeze()
99 |
100 | # Feature Propagation layers
101 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
102 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
103 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat(
104 | [l0_xyz, l0_points], 1), l1_points)
105 | l0_points = self.bn1(self.conv1(l0_points))
106 | return l0_points, c
107 |
108 |
109 | class PoseNet(nn.Module):
110 | """_summary_
111 | PoseNet class. This class performs one denoising step of the diffusion.
112 | """
113 | def __init__(self):
114 | super(PoseNet, self).__init__()
115 | self.cloud_net0 = nn.Sequential(
116 | nn.Linear(1024, 512),
117 | nn.GroupNorm(8, 512),
118 | nn.GELU(),
119 | nn.Linear(512, 128),
120 | nn.GELU(),
121 | nn.Linear(128, 32)
122 | )
123 | self.cloud_net3 = nn.Sequential(
124 | nn.Linear(32, 16),
125 | nn.GroupNorm(4, 16),
126 | nn.GELU(),
127 | nn.Linear(16, 6)
128 | )
129 | self.cloud_net2 = nn.Sequential(
130 | nn.Linear(32, 16),
131 | nn.GroupNorm(4, 16),
132 | nn.GELU(),
133 | nn.Linear(16, 4)
134 | )
135 | self.cloud_net1 = nn.Sequential(
136 | nn.Linear(32, 16),
137 | nn.GroupNorm(4, 16),
138 | nn.GELU(),
139 | nn.Linear(16, 2)
140 | )
141 | self.cloud_influence_net3 = nn.Sequential(
142 | nn.Linear(6 + 6 + 7, 6),
143 | nn.GELU(),
144 | nn.Linear(6, 6)
145 | )
146 | self.cloud_influence_net2 = nn.Sequential(
147 | nn.Linear(4 + 4 + 7, 4),
148 | nn.GELU(),
149 | nn.Linear(4, 4)
150 | )
151 | self.cloud_influence_net1 = nn.Sequential(
152 | nn.Linear(2 + 2 + 7, 2),
153 | nn.GELU(),
154 | nn.Linear(2, 2)
155 | )
156 |
157 | self.text_net0 = nn.Sequential(
158 | nn.Linear(512, 256),
159 | nn.GroupNorm(8, 256),
160 | nn.GELU(),
161 | nn.Linear(256, 128),
162 | nn.GELU(),
163 | nn.Linear(128, 32)
164 | )
165 | self.text_net3 = nn.Sequential(
166 | nn.Linear(32, 16),
167 | nn.GroupNorm(4, 16),
168 | nn.GELU(),
169 | nn.Linear(16, 6)
170 | )
171 | self.text_net2 = nn.Sequential(
172 | nn.Linear(32, 16),
173 | nn.GroupNorm(4, 16),
174 | nn.GELU(),
175 | nn.Linear(16, 4)
176 | )
177 | self.text_net1 = nn.Sequential(
178 | nn.Linear(32, 16),
179 | nn.GroupNorm(4, 16),
180 | nn.GELU(),
181 | nn.Linear(16, 2)
182 | )
183 | self.text_influence_net3 = nn.Sequential(
184 | nn.Linear(6 + 6 + 7, 6),
185 | nn.GELU(),
186 | nn.Linear(6, 6)
187 | )
188 | self.text_influence_net2 = nn.Sequential(
189 | nn.Linear(4 + 4 + 7, 4),
190 | nn.GELU(),
191 | nn.Linear(4, 4)
192 | )
193 | self.text_influence_net1 = nn.Sequential(
194 | nn.Linear(2 + 2 + 7, 2),
195 | nn.GELU(),
196 | nn.Linear(2, 2)
197 | )
198 |
199 | # self.time_net3 = SinusoidalPositionEmbeddings(dim=6)
200 | # self.time_net2 = SinusoidalPositionEmbeddings(dim=4)
201 | # self.time_net1 = SinusoidalPositionEmbeddings(dim=2)
202 | self.time_net3 = TimeNet(dim=6)
203 | self.time_net2 = TimeNet(dim=4)
204 | self.time_net1 = TimeNet(dim=2)
205 |
206 | self.down1 = nn.Sequential(
207 | nn.Linear(7, 6),
208 | nn.GELU(),
209 | nn.Linear(6, 6)
210 | )
211 | self.down2 = nn.Sequential(
212 | nn.Linear(6, 4),
213 | nn.GELU(),
214 | nn.Linear(4, 4)
215 | )
216 | self.down3 = nn.Sequential(
217 | nn.Linear(4, 2),
218 | nn.GELU(),
219 | nn.Linear(2, 2)
220 | )
221 |
222 | self.up1 = nn.Sequential(
223 | nn.Linear(2 + 4, 4),
224 | nn.GELU(),
225 | nn.Linear(4, 4)
226 | )
227 | self.up2 = nn.Sequential(
228 | nn.Linear(4 + 6, 6),
229 | nn.GELU(),
230 | nn.Linear(6, 6)
231 | )
232 | self.up3 = nn.Sequential(
233 | nn.Linear(6 + 7, 7),
234 | nn.GELU(),
235 | nn.Linear(7, 7)
236 | )
237 |
238 | def forward(self, g, c, t, context_mask, _t):
239 | """_summary_
240 | Args:
241 | g: pose representations, size [B, 7]
242 | c: point cloud representations, size [B, 1024]
243 | t: affordance texts, size [B, 512]
244 | context_mask: masks {0, 1} for the contexts, size [B, 1]
245 | _t is for the timesteps, size [B,]
246 | """
247 | c = c * context_mask
248 | c0 = self.cloud_net0(c)
249 | c1 = self.cloud_net1(c0)
250 | c2 = self.cloud_net2(c0)
251 | c3 = self.cloud_net3(c0)
252 |
253 | t = t * context_mask
254 | t0 = self.text_net0(t)
255 | t1 = self.text_net1(t0)
256 | t2 = self.text_net2(t0)
257 | t3 = self.text_net3(t0)
258 |
259 | _t0 = _t.unsqueeze(1)
260 | _t1 = self.time_net1(_t0)
261 | _t2 = self.time_net2(_t0)
262 | _t3 = self.time_net3(_t0)
263 |
264 | g = g.float()
265 | g_down1 = self.down1(g) # 6
266 | g_down2 = self.down2(g_down1) # 4
267 | g_down3 = self.down3(g_down2) # 2
268 |
269 | c1_influence = self.cloud_influence_net1(torch.cat((c1, g, _t1), dim=1))
270 | t1_influence = self.text_influence_net1(torch.cat((t1, g, _t1), dim=1))
271 | influences1 = F.softmax(torch.cat((c1_influence.unsqueeze(1), t1_influence.unsqueeze(1)), dim=1), dim=1)
272 | ct1 = (c1 * influences1[:, 0, :] + t1 * influences1[:, 1, :])
273 | up1 = self.up1(torch.cat((g_down3 * ct1 + _t1, g_down2), dim=1))
274 |
275 | c2_influence = self.cloud_influence_net2(torch.cat((c2, g, _t2), dim=1))
276 | t2_influence = self.text_influence_net2(torch.cat((t2, g, _t2), dim=1))
277 | influences2 = F.softmax(torch.cat((c2_influence.unsqueeze(1), t2_influence.unsqueeze(1)), dim=1), dim=1)
278 | ct2 = (c2 * influences2[:, 0, :] + t2 * influences2[:, 1, :])
279 | up2 = self.up2(torch.cat((up1 * ct2 + _t2, g_down1), dim=1))
280 |
281 | c3_influence = self.cloud_influence_net3(torch.cat((c3, g, _t3), dim=1))
282 | t3_influence = self.text_influence_net3(torch.cat((t3, g, _t3), dim=1))
283 | influences3 = F.softmax(torch.cat((c3_influence.unsqueeze(1), t3_influence.unsqueeze(1)), dim=1), dim=1)
284 | ct3 = (c3 * influences3[:, 0, :] + t3 * influences3[:, 1, :])
285 | up3 = self.up3(torch.cat((up2 * ct3 + _t3, g), dim=1)) # size [B, 7]
286 |
287 | return up3
--------------------------------------------------------------------------------
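To make the shapes in `PoseNet.forward` concrete, the hedged sketch below (not part of the repository) runs one denoising call on dummy tensors; the real pipeline feeds PointNet++ features for `c` and CLIP text features for `t`. It assumes the repository root is on `PYTHONPATH` and `open_clip` is installed, since `models.components` imports it at module level.

```
import torch
from models.components import PoseNet  # assumes repo root on PYTHONPATH; open_clip must be installed

net = PoseNet()
B = 4
g = torch.randn(B, 7)            # noisy pose: quaternion (4) + translation (3)
c = torch.randn(B, 1024)         # point cloud representation from PointNet++
t = torch.randn(B, 512)          # affordance text feature from the CLIP text encoder
context_mask = torch.ones(B, 1)  # 1 keeps the conditions, 0 drops them (classifier-free guidance)
_t = torch.rand(B)               # timestep normalized to (0, 1]

eps_pred = net(g, c, t, context_mask, _t)
print(eps_pred.shape)            # torch.Size([4, 7]) -- predicted noise, one vector per pose
```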
/models/pointnet_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 |
8 | def timeit(tag, t):
9 | print("{}: {}s".format(tag, time() - t))
10 | return time()
11 |
12 |
13 | def pc_normalize(pc):
14 | l = pc.shape[0]
15 | centroid = np.mean(pc, axis=0)
16 | pc = pc - centroid
17 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
18 | pc = pc / m
19 | return pc
20 |
21 |
22 | def square_distance(src, dst):
23 | """_summary_
24 | Calculate the Euclidean distance between each pair of points.
25 |
26 | src^T * dst = xn * xm + yn * ym + zn * zm;
27 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
28 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
29 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
30 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
31 |
32 | Input:
33 | src: source points, [B, N, C]
34 | dst: target points, [B, M, C]
35 | Output:
36 | dist: per-point square distance, [B, N, M]
37 | """
38 | B, N, _ = src.shape
39 | _, M, _ = dst.shape
40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
41 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
43 | return dist
44 |
45 |
46 | def index_points(points, idx):
47 | """_summary_
48 | Input:
49 | points: input points data, [B, N, C]
50 | idx: sample index data, [B, S]
51 | Return:
52 | new_points:, indexed points data, [B, S, C]
53 | """
54 | device = points.device
55 | B = points.shape[0]
56 | view_shape = list(idx.shape)
57 | view_shape[1:] = [1] * (len(view_shape) - 1)
58 | repeat_shape = list(idx.shape)
59 | repeat_shape[0] = 1
60 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
61 | new_points = points[batch_indices, idx, :]
62 | return new_points
63 |
64 |
65 | def farthest_point_sample(xyz, npoint):
66 | """_summary_
67 | Input:
68 | xyz: pointcloud data, [B, N, 3]
69 | npoint: number of samples
70 | Return:
71 | centroids: sampled pointcloud index, [B, npoint]
72 | """
73 | device = xyz.device
74 | B, N, C = xyz.shape
75 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
76 | distance = torch.ones(B, N).to(device) * 1e10
77 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
78 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
79 | for i in range(npoint):
80 | centroids[:, i] = farthest
81 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
82 | dist = torch.sum((xyz - centroid) ** 2, -1)
83 | mask = dist < distance
84 | distance[mask] = dist[mask]
85 | farthest = torch.max(distance, -1)[1]
86 | return centroids
87 |
88 |
89 | def query_ball_point(radius, nsample, xyz, new_xyz):
90 | """_summary_
91 | Input:
92 | radius: local region radius
93 | nsample: max sample number in local region
94 | xyz: all points, [B, N, 3]
95 | new_xyz: query points, [B, S, 3]
96 | Return:
97 | group_idx: grouped points index, [B, S, nsample]
98 | """
99 | device = xyz.device
100 | B, N, C = xyz.shape
101 | _, S, _ = new_xyz.shape
102 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
103 | sqrdists = square_distance(new_xyz, xyz)
104 | group_idx[sqrdists > radius ** 2] = N
105 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
106 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
107 | mask = group_idx == N
108 | group_idx[mask] = group_first[mask]
109 | return group_idx
110 |
111 |
112 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
113 | """_summary_
114 | Input:
115 | npoint:
116 | radius:
117 | nsample:
118 | xyz: input points position data, [B, N, 3]
119 | points: input points data, [B, N, D]
120 | Return:
121 | new_xyz: sampled points position data, [B, npoint, nsample, 3]
122 | new_points: sampled points data, [B, npoint, nsample, 3+D]
123 | """
124 | B, N, C = xyz.shape
125 | S = npoint
126 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
127 | new_xyz = index_points(xyz, fps_idx)
128 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
129 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
130 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
131 |
132 | if points is not None:
133 | grouped_points = index_points(points, idx)
134 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
135 | else:
136 | new_points = grouped_xyz_norm
137 | if returnfps:
138 | return new_xyz, new_points, grouped_xyz, fps_idx
139 | else:
140 | return new_xyz, new_points
141 |
142 |
143 | def sample_and_group_all(xyz, points):
144 | """_summary_
145 | Input:
146 | xyz: input points position data, [B, N, 3]
147 | points: input points data, [B, N, D]
148 | Return:
149 | new_xyz: sampled points position data, [B, 1, 3]
150 | new_points: sampled points data, [B, 1, N, 3+D]
151 | """
152 | device = xyz.device
153 | B, N, C = xyz.shape
154 | new_xyz = torch.zeros(B, 1, C).to(device)
155 | grouped_xyz = xyz.view(B, 1, N, C)
156 | if points is not None:
157 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
158 | else:
159 | new_points = grouped_xyz
160 | return new_xyz, new_points
161 |
162 |
163 | class PointNetSetAbstraction(nn.Module):
164 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
165 | super(PointNetSetAbstraction, self).__init__()
166 | self.npoint = npoint
167 | self.radius = radius
168 | self.nsample = nsample
169 | self.mlp_convs = nn.ModuleList()
170 | self.mlp_bns = nn.ModuleList()
171 | last_channel = in_channel
172 | for out_channel in mlp:
173 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
174 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
175 | last_channel = out_channel
176 | self.group_all = group_all
177 |
178 | def forward(self, xyz, points):
179 | """_summary_
180 | Input:
181 | xyz: input points position data, [B, C, N]
182 | points: input points data, [B, D, N]
183 | Return:
184 | new_xyz: sampled points position data, [B, C, S]
185 | new_points_concat: sample points feature data, [B, D', S]
186 | """
187 | xyz = xyz.permute(0, 2, 1)
188 | if points is not None:
189 | points = points.permute(0, 2, 1)
190 |
191 | if self.group_all:
192 | new_xyz, new_points = sample_and_group_all(xyz, points)
193 | else:
194 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
195 | # new_xyz: sampled points position data, [B, npoint, C]
196 | # new_points: sampled points data, [B, npoint, nsample, C+D]
197 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
198 | for i, conv in enumerate(self.mlp_convs):
199 | bn = self.mlp_bns[i]
200 | new_points = F.relu(bn(conv(new_points)))
201 |
202 | new_points = torch.max(new_points, 2)[0]
203 | new_xyz = new_xyz.permute(0, 2, 1)
204 | return new_xyz, new_points
205 |
206 |
207 | class PointNetSetAbstractionMsg(nn.Module):
208 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
209 | super(PointNetSetAbstractionMsg, self).__init__()
210 | self.npoint = npoint
211 | self.radius_list = radius_list
212 | self.nsample_list = nsample_list
213 | self.conv_blocks = nn.ModuleList()
214 | self.bn_blocks = nn.ModuleList()
215 | for i in range(len(mlp_list)):
216 | convs = nn.ModuleList()
217 | bns = nn.ModuleList()
218 | last_channel = in_channel + 3
219 | for out_channel in mlp_list[i]:
220 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
221 | bns.append(nn.BatchNorm2d(out_channel))
222 | last_channel = out_channel
223 | self.conv_blocks.append(convs)
224 | self.bn_blocks.append(bns)
225 |
226 | def forward(self, xyz, points):
227 | """_summary_
228 | Input:
229 | xyz: input points position data, [B, C, N]
230 | points: input points data, [B, D, N]
231 | Return:
232 | new_xyz: sampled points position data, [B, C, S]
233 | new_points_concat: sample points feature data, [B, D', S]
234 | """
235 | xyz = xyz.permute(0, 2, 1)
236 | if points is not None:
237 | points = points.permute(0, 2, 1)
238 |
239 | B, N, C = xyz.shape
240 | S = self.npoint
241 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
242 | new_points_list = []
243 | for i, radius in enumerate(self.radius_list):
244 | K = self.nsample_list[i]
245 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
246 | grouped_xyz = index_points(xyz, group_idx)
247 | grouped_xyz -= new_xyz.view(B, S, 1, C)
248 | if points is not None:
249 | grouped_points = index_points(points, group_idx)
250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
251 | else:
252 | grouped_points = grouped_xyz
253 |
254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
255 | for j in range(len(self.conv_blocks[i])):
256 | conv = self.conv_blocks[i][j]
257 | bn = self.bn_blocks[i][j]
258 | grouped_points = F.relu(bn(conv(grouped_points)))
259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
260 | new_points_list.append(new_points)
261 |
262 | new_xyz = new_xyz.permute(0, 2, 1)
263 | new_points_concat = torch.cat(new_points_list, dim=1)
264 | return new_xyz, new_points_concat
265 |
266 |
267 | class PointNetFeaturePropagation(nn.Module):
268 | def __init__(self, in_channel, mlp):
269 | super(PointNetFeaturePropagation, self).__init__()
270 | self.mlp_convs = nn.ModuleList()
271 | self.mlp_bns = nn.ModuleList()
272 | last_channel = in_channel
273 | for out_channel in mlp:
274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
276 | last_channel = out_channel
277 |
278 | def forward(self, xyz1, xyz2, points1, points2):
279 | """_summary_
280 | Input:
281 | xyz1: input points position data, [B, C, N]
282 | xyz2: sampled input points position data, [B, C, S]
283 | points1: input points data, [B, D, N]
284 | points2: input points data, [B, D, S]
285 | Return:
286 | new_points: upsampled points data, [B, D', N]
287 | """
288 | xyz1 = xyz1.permute(0, 2, 1)
289 | xyz2 = xyz2.permute(0, 2, 1)
290 |
291 | points2 = points2.permute(0, 2, 1)
292 | B, N, C = xyz1.shape
293 | _, S, _ = xyz2.shape
294 |
295 | if S == 1:
296 | interpolated_points = points2.repeat(1, N, 1)
297 | else:
298 | dists = square_distance(xyz1, xyz2)
299 | dists, idx = dists.sort(dim=-1)
300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
301 |
302 | dist_recip = 1.0 / (dists + 1e-8)
303 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
304 | weight = dist_recip / norm
305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
306 |
307 | if points1 is not None:
308 | points1 = points1.permute(0, 2, 1)
309 | new_points = torch.cat([points1, interpolated_points], dim=-1)
310 | else:
311 | new_points = interpolated_points
312 |
313 | new_points = new_points.permute(0, 2, 1)
314 | for i, conv in enumerate(self.mlp_convs):
315 | bn = self.mlp_bns[i]
316 | new_points = F.relu(bn(conv(new_points)))
317 | return new_points
--------------------------------------------------------------------------------
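To make the documented tensor shapes of the sampling and grouping helpers concrete, here is a hedged sketch (not part of the repository) that runs them on random points; it assumes the repository root is on `PYTHONPATH`.

```
import torch
from models.pointnet_util import (square_distance, farthest_point_sample,
                                  index_points, query_ball_point)

B, N, S, K = 2, 2048, 512, 32
xyz = torch.rand(B, N, 3)                            # a batch of random point clouds

fps_idx = farthest_point_sample(xyz, S)              # [B, S] indices of FPS centroids
new_xyz = index_points(xyz, fps_idx)                 # [B, S, 3] centroid coordinates
group_idx = query_ball_point(0.2, K, xyz, new_xyz)   # [B, S, K] neighbour indices within radius 0.2
dists = square_distance(new_xyz, xyz)                # [B, S, N] pairwise squared distances
print(fps_idx.shape, new_xyz.shape, group_idx.shape, dists.shape)
```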