├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── intro.png ├── method.png └── visualization.png ├── config └── detectiondiffusion.py ├── dataset ├── ThreeDAPDataset.py └── __init__.py ├── detect.py ├── models ├── __init__.py ├── components.py ├── main_nets.py ├── pointnet_util.py └── weights_init.py ├── requirements.txt ├── test.py ├── train.py ├── utils ├── __init__.py ├── builder.py ├── eval.py ├── trainer.py ├── utils.py └── visualization.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | log/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorchse3"] 2 | path = pytorchse3 3 | url = https://github.com/eigenvivek/pytorchse3 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Toan Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Language-Conditioned Affordance-Pose Detection in 3D Point Clouds 4 | 5 | [![Conference](https://img.shields.io/badge/ICRA-2024-FF0B0B.svg)](https://2024.ieee-icra.org/) 6 | [![Paper](https://img.shields.io/badge/Paper-arxiv.2309.10911-0009F6.svg)](https://arxiv.org/abs/2309.10911) 7 | 8 | Official code for the ICRA 2024 paper "Language-Conditioned Affordance-Pose Detection in 3D Point Clouds". 9 | 10 | ![image](./assets/intro.png) 11 | 12 | We address the task of language-driven affordance-pose detection in 3D point clouds. Our method simultaneously detects open-vocabulary affordances and generates affordance-specific 6-DoF poses. 13 | 14 | ![image](./assets/method.png) 15 | 16 | We present 3DAPNet, a new method for affordance-pose joint learning. Given the captured 3D point cloud of an object and a set of affordance labels conveyed through natural language texts, our objective is to jointly produce both the relevant affordance regions and the appropriate pose configurations that facilitate the affordances. 17 | 18 |
19 | 20 | 21 | ## 1. Getting Started 22 | We strongly encourage you to create a separate conda environment. 23 | 24 | conda create -n affpose python=3.8 25 | conda activate affpose 26 | conda install pip 27 | pip install -r requirements.txt 28 | 29 | ## 2. Dataset 30 | Our 3DAP dataset is available at [this drive folder](https://drive.google.com/drive/folders/1vDGHs3QZmmF2rGluGlqBIyCp8sPR4Yws?usp=sharing). 31 | 32 | ## 3. Training 33 | The current framework supports training on a single GPU. The following steps train our method with the configuration file ```config/detectiondiffusion.py```. 34 | 35 | * In ```config/detectiondiffusion.py```, change the value of ```data_path``` to the path of your downloaded pickle file. 36 | * Change other hyperparameters if needed. 37 | * Run the following command to start training: 38 | 39 | python3 train.py --config ./config/detectiondiffusion.py 40 | 41 | ## 4. Testing 42 | Execute the following command to test your trained model: 43 | 44 | python3 detect.py --config <path_to_config_file> --checkpoint <path_to_checkpoint> --test_data <path_to_test_data> 45 | 46 | Note that we currently generate 2000 poses for each affordance-object pair. 47 | The guidance scale is currently set to 0.2. Feel free to change these hyperparameters according to your preference. 48 | 49 | The result will be saved to a ```result.pkl``` file (a short example of loading this file is given at the end of this README). 50 | 51 | ## 5. Visualization 52 | To visualize the results of affordance detection and pose estimation, execute the following script: 53 | 54 | python3 visualize.py --result <path_to_result_file> 55 | 56 | Example of training data visualization: 57 | 58 | ![image](./assets/visualization.png) 59 | 60 | ## 6. Citation 61 | 62 | If you find our work useful for your research, please cite: 63 | ``` 64 | @inproceedings{Nguyen2024language, 65 | title={Language-Conditioned Affordance-Pose Detection in 3D Point Clouds}, 66 | author={Nguyen, Toan and Vu, Minh Nhat and Huang, Baoru and Van Vo, Tuan and Truong, Vy and Le, Ngan and Vo, Thieu and Le, Bac and Nguyen, Anh}, 67 | booktitle = ICRA, 68 | year = {2024} 69 | } 70 | ``` 71 | Thank you very much. 72 | 73 | ## 7. Acknowledgement 74 | 75 | Our source code is built upon [3D AffordanceNet](https://github.com/Gorilla-Lab-SCUT/AffordanceNet). We express our sincere thanks to them.
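76 | 77 | ## 8. Result File Format 78 | 79 | For reference, the ```result.pkl``` written by ```detect.py``` is a list of shape dictionaries from the test split; for every affordance text of a shape, it additionally stores the predicted per-point affordance map and the sampled poses, each pose being a 7-D vector (a quaternion in scipy's (x, y, z, w) order followed by a 3-D translation). The snippet below is a minimal inspection sketch, not part of the pipeline; it assumes the default ```log_dir``` from ```config/detectiondiffusion.py```, so adjust the path if you changed the configuration. 80 | 81 | ```python 82 | import pickle 83 | 84 | # Assumed default location; change this to your own result file. 85 | with open("./log/detectiondiffusion/result.pkl", "rb") as f: 86 |     result = pickle.load(f) 87 | 88 | for shape in result[:3]:  # peek at the first few shapes 89 |     print(shape["semantic class"], shape["shape_id"]) 90 |     for affordance in shape["affordance"]: 91 |         affordance_map, poses = shape["result"][affordance] 92 |         # affordance_map: binary per-point prediction over the 2048 input points 93 |         # poses: [n_sample, 7] array of sampled poses (quaternion + translation) 94 |         print(f"  {affordance}: {int(affordance_map.sum())} predicted points, {poses.shape[0]} poses") 95 | ```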
-------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/intro.png -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/method.png -------------------------------------------------------------------------------- /assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/visualization.png -------------------------------------------------------------------------------- /config/detectiondiffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from os.path import join as opj 4 | from utils import PN2_BNMomentum 5 | 6 | exp_name = 'detectiondiffusion' 7 | seed = 1 8 | log_dir = opj("./log/", exp_name) 9 | try: 10 | os.makedirs(log_dir) 11 | except: 12 | print('Logging Dir is already existed!') 13 | 14 | # scheduler = dict( 15 | # type='lr_lambda', 16 | # lr_lambda=PN2_Scheduler(init_lr=0.001, step=20, 17 | # decay_rate=0.5, min_lr=1e-5) 18 | # ) 19 | 20 | scheduler = None 21 | 22 | optimizer = dict( 23 | type='adam', 24 | lr=1e-3, 25 | betas=(0.9, 0.999), 26 | eps=1e-08, 27 | weight_decay=1e-5, 28 | ) 29 | 30 | model = dict( 31 | type='detectiondiffusion', 32 | device=torch.device('cuda'), 33 | background_text='none', 34 | betas=[1e-4, 0.02], 35 | n_T=1000, 36 | drop_prob=0.1, 37 | weights_init='default_init', 38 | ) 39 | 40 | training_cfg = dict( 41 | model=model, 42 | batch_size=32, 43 | epoch=200, 44 | gpu='0', 45 | workflow=dict( 46 | train=1, 47 | ), 48 | bn_momentum=PN2_BNMomentum(origin_m=0.1, m_decay=0.5, step=20), 49 | ) 50 | 51 | data = dict( 52 | data_path="../full_shape_release.pkl", 53 | ) -------------------------------------------------------------------------------- /dataset/ThreeDAPDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.utils.data import Dataset 3 | import pickle as pkl 4 | from scipy.spatial.transform import Rotation as R 5 | 6 | 7 | class ThreeDAPDataset(Dataset): 8 | """_summary_ 9 | This class is for the data loading. 
10 | """ 11 | def __init__(self, data_path, mode): 12 | """_summary_ 13 | 14 | Args: 15 | data_path (str): path to the dataset 16 | """ 17 | super().__init__() 18 | self.data_path = data_path 19 | self.mode = mode 20 | if self.mode in ["train", "val", "test"]: 21 | self._load_data() 22 | else: 23 | raise ValueError("Mode must be train, val, or test!") 24 | 25 | def _load_data(self): 26 | self.all_data = [] 27 | 28 | with open(self.data_path, "rb") as f: 29 | data = pkl.load(f) 30 | random.shuffle(data) 31 | 32 | if self.mode == "train": data = data[:int(0.7 * len(data))] 33 | elif self.mode == "val": data = data[int(0.7 * len(data)):int(0.8 * len(data))] 34 | else: data = data[int(0.8 * len(data)):] 35 | 36 | for data_point in data: 37 | for affordance in data_point["affordance"]: 38 | for pose in data_point["pose"][affordance]: 39 | new_data_dict = { 40 | "shape_id": data_point["shape_id"], 41 | "semantic class": data_point["semantic class"], 42 | "point cloud": data_point["full_shape"]["coordinate"], 43 | "affordance": affordance, 44 | "affordance label": data_point["full_shape"]["label"][affordance], 45 | "rotation": R.from_matrix(pose[:3, :3]).as_quat(), 46 | "translation": pose[:3, 3] 47 | } 48 | self.all_data.append(new_data_dict) 49 | 50 | def __getitem__(self, index): 51 | """_summary_ 52 | 53 | Args: 54 | index (int): the element index 55 | 56 | Returns: 57 | shape id, semantic class, coordinate, affordance text, affordance label, rotation and translation 58 | """ 59 | data_dict = self.all_data[index] 60 | return data_dict['shape_id'], data_dict['semantic class'], data_dict['point cloud'], data_dict['affordance'], \ 61 | data_dict['affordance label'], data_dict['rotation'], data_dict['translation'] 62 | 63 | def __len__(self): 64 | return len(self.all_data) 65 | 66 | 67 | if __name__ == "__main__": 68 | random.seed(1) 69 | dataset = ThreeDAPDataset(data_path="../full_shape_release.pkl", mode="train") 70 | print(len(dataset)) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .ThreeDAPDataset import ThreeDAPDataset 2 | 3 | 4 | __all__ = ['ThreeDAPDataset'] -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from gorilla.config import Config 4 | from utils import * 5 | import argparse 6 | import pickle 7 | from tqdm import tqdm 8 | import random 9 | 10 | 11 | GUIDE_W = 0.2 12 | DEVICE = torch.device('cuda') 13 | 14 | 15 | # Argument Parser 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="Detect affordance and poses") 18 | parser.add_argument("--config", help="test config file path") 19 | parser.add_argument("--checkpoint", help="path to checkpoint model") 20 | parser.add_argument("--test_data", help="path to test_data") 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | if __name__ == "__main__": 26 | args = parse_args() 27 | cfg = Config.fromfile(args.config) 28 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu 29 | model = build_model(cfg).to(DEVICE) 30 | 31 | if args.checkpoint != None: 32 | print("Loading checkpoint....") 33 | _, exten = os.path.splitext(args.checkpoint) 34 | if exten == '.t7': 35 | model.load_state_dict(torch.load(args.checkpoint)) 36 | elif exten == '.pth': 37 | check = torch.load(args.checkpoint) 38 | 
model.load_state_dict(check['model_state_dict']) 39 | else: 40 | raise ValueError("Must specify a checkpoint path!") 41 | 42 | if cfg.get('seed') != None: 43 | set_random_seed(cfg.seed) 44 | 45 | with open(args.test_data, 'rb') as f: 46 | shape_data = pickle.load(f) 47 | random.shuffle(shape_data) 48 | shape_data = shape_data[int(0.8 * len(shape_data)):] 49 | 50 | print("Detecting") 51 | model.eval() 52 | with torch.no_grad(): 53 | for shape in tqdm(shape_data): 54 | xyz = torch.from_numpy(shape['full_shape']['coordinate']).unsqueeze(0).float().cuda() 55 | shape['result'] = {text: [*(model.detect_and_sample(xyz, text, 2000, guide_w=GUIDE_W))] for text in shape['affordance']} 56 | 57 | with open(f'{cfg.log_dir}/result.pkl', 'wb') as f: 58 | pickle.dump(shape_data, f) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .main_nets import DetectionDiffusion 2 | from .weights_init import weights_init 3 | 4 | 5 | __all__ = ['DetectionDiffusion', 'weights_init'] -------------------------------------------------------------------------------- /models/components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import open_clip 4 | import math 5 | import torch.nn.functional as F 6 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation 7 | 8 | 9 | class SinusoidalPositionEmbeddings(nn.Module): 10 | """ 11 | Sinusoidal embedding for time step. 12 | """ 13 | def __init__(self, dim, scale=1.0): 14 | super().__init__() 15 | self.dim = dim 16 | self.scale = scale 17 | 18 | def forward(self, time): 19 | time = time * self.scale 20 | device = time.device 21 | half_dim = self.dim // 2 22 | embeddings = math.log(10000) / (half_dim - 1 + 1e-5) 23 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 24 | embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0) 25 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 26 | return embeddings 27 | 28 | def __len__(self): 29 | return self.dim 30 | 31 | 32 | class TimeNet(nn.Module): 33 | """ 34 | Time Embeddings 35 | """ 36 | def __init__(self, dim): 37 | super().__init__() 38 | self.net = nn.Sequential( 39 | nn.Linear(1, dim), 40 | nn.GELU(), 41 | nn.Linear(dim, dim) 42 | ) 43 | def forward(self, t): 44 | return self.net(t) 45 | 46 | 47 | class TextEncoder(nn.Module): 48 | """ 49 | Text Encoder to encode the text prompt. 50 | """ 51 | def __init__(self, device): 52 | super(TextEncoder, self).__init__() 53 | self.device = device 54 | self.clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k", 55 | device=self.device) 56 | 57 | def forward(self, texts): 58 | """ 59 | texts can be a single string or a list of strings. 60 | """ 61 | tokenizer = open_clip.get_tokenizer("ViT-B-32") 62 | tokens = tokenizer(texts).to(self.device) 63 | text_features = self.clip_model.encode_text(tokens).to(self.device) 64 | return text_features 65 | 66 | 67 | class PointNetPlusPlus(nn.Module): 68 | """_summary_ 69 | PointNet++ class. 
70 | """ 71 | def __init__(self): 72 | super(PointNetPlusPlus, self).__init__() 73 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [ 74 | 32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 75 | self.sa2 = PointNetSetAbstractionMsg( 76 | 128, [0.4, 0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]]) 77 | self.sa3 = PointNetSetAbstraction( 78 | npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True) 79 | 80 | self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256]) 81 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128]) 82 | self.fp1 = PointNetFeaturePropagation(in_channel=134, mlp=[128, 128]) 83 | 84 | self.conv1 = nn.Conv1d(128, 512, 1) 85 | self.bn1 = nn.BatchNorm1d(512) 86 | 87 | def forward(self, xyz): 88 | """_summary_ 89 | Return point-wise features and point cloud representation. 90 | """ 91 | # Set Abstraction layers 92 | xyz = xyz.contiguous().transpose(1, 2) 93 | l0_xyz = xyz 94 | l0_points = xyz 95 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 96 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 97 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 98 | c = l3_points.squeeze() 99 | 100 | # Feature Propagation layers 101 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 102 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 103 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat( 104 | [l0_xyz, l0_points], 1), l1_points) 105 | l0_points = self.bn1(self.conv1(l0_points)) 106 | return l0_points, c 107 | 108 | 109 | class PoseNet(nn.Module): 110 | """_summary_ 111 | ContextPoseNet class. This class is for a denoising step in the diffusion. 112 | """ 113 | def __init__(self): 114 | super(PoseNet, self).__init__() 115 | self.cloud_net0 = nn.Sequential( 116 | nn.Linear(1024, 512), 117 | nn.GroupNorm(8, 512), 118 | nn.GELU(), 119 | nn.Linear(512, 128), 120 | nn.GELU(), 121 | nn.Linear(128, 32) 122 | ) 123 | self.cloud_net3 = nn.Sequential( 124 | nn.Linear(32, 16), 125 | nn.GroupNorm(4, 16), 126 | nn.GELU(), 127 | nn.Linear(16, 6) 128 | ) 129 | self.cloud_net2 = nn.Sequential( 130 | nn.Linear(32, 16), 131 | nn.GroupNorm(4, 16), 132 | nn.GELU(), 133 | nn.Linear(16, 4) 134 | ) 135 | self.cloud_net1 = nn.Sequential( 136 | nn.Linear(32, 16), 137 | nn.GroupNorm(4, 16), 138 | nn.GELU(), 139 | nn.Linear(16, 2) 140 | ) 141 | self.cloud_influence_net3 = nn.Sequential( 142 | nn.Linear(6 + 6 + 7, 6), 143 | nn.GELU(), 144 | nn.Linear(6, 6) 145 | ) 146 | self.cloud_influence_net2 = nn.Sequential( 147 | nn.Linear(4 + 4 + 7, 4), 148 | nn.GELU(), 149 | nn.Linear(4, 4) 150 | ) 151 | self.cloud_influence_net1 = nn.Sequential( 152 | nn.Linear(2 + 2 + 7, 2), 153 | nn.GELU(), 154 | nn.Linear(2, 2) 155 | ) 156 | 157 | self.text_net0 = nn.Sequential( 158 | nn.Linear(512, 256), 159 | nn.GroupNorm(8, 256), 160 | nn.GELU(), 161 | nn.Linear(256, 128), 162 | nn.GELU(), 163 | nn.Linear(128, 32) 164 | ) 165 | self.text_net3 = nn.Sequential( 166 | nn.Linear(32, 16), 167 | nn.GroupNorm(4, 16), 168 | nn.GELU(), 169 | nn.Linear(16, 6) 170 | ) 171 | self.text_net2 = nn.Sequential( 172 | nn.Linear(32, 16), 173 | nn.GroupNorm(4, 16), 174 | nn.GELU(), 175 | nn.Linear(16, 4) 176 | ) 177 | self.text_net1 = nn.Sequential( 178 | nn.Linear(32, 16), 179 | nn.GroupNorm(4, 16), 180 | nn.GELU(), 181 | nn.Linear(16, 2) 182 | ) 183 | self.text_influence_net3 = nn.Sequential( 184 | nn.Linear(6 + 6 + 7, 6), 185 | nn.GELU(), 186 | nn.Linear(6, 6) 187 | ) 188 | self.text_influence_net2 = nn.Sequential( 189 | 
nn.Linear(4 + 4 + 7, 4), 190 | nn.GELU(), 191 | nn.Linear(4, 4) 192 | ) 193 | self.text_influence_net1 = nn.Sequential( 194 | nn.Linear(2 + 2 + 7, 2), 195 | nn.GELU(), 196 | nn.Linear(2, 2) 197 | ) 198 | 199 | # self.time_net3 = SinusoidalPositionEmbeddings(dim=6) 200 | # self.time_net2 = SinusoidalPositionEmbeddings(dim=4) 201 | # self.time_net1 = SinusoidalPositionEmbeddings(dim=2) 202 | self.time_net3 = TimeNet(dim=6) 203 | self.time_net2 = TimeNet(dim=4) 204 | self.time_net1 = TimeNet(dim=2) 205 | 206 | self.down1 = nn.Sequential( 207 | nn.Linear(7, 6), 208 | nn.GELU(), 209 | nn.Linear(6, 6) 210 | ) 211 | self.down2 = nn.Sequential( 212 | nn.Linear(6, 4), 213 | nn.GELU(), 214 | nn.Linear(4, 4) 215 | ) 216 | self.down3 = nn.Sequential( 217 | nn.Linear(4, 2), 218 | nn.GELU(), 219 | nn.Linear(2, 2) 220 | ) 221 | 222 | self.up1 = nn.Sequential( 223 | nn.Linear(2 + 4, 4), 224 | nn.GELU(), 225 | nn.Linear(4, 4) 226 | ) 227 | self.up2 = nn.Sequential( 228 | nn.Linear(4 + 6, 6), 229 | nn.GELU(), 230 | nn.Linear(6, 6) 231 | ) 232 | self.up3 = nn.Sequential( 233 | nn.Linear(6 + 7, 7), 234 | nn.GELU(), 235 | nn.Linear(7, 7) 236 | ) 237 | 238 | def forward(self, g, c, t, context_mask, _t): 239 | """_summary_ 240 | Args: 241 | g: pose representations, size [B, 7] 242 | c: point cloud representations, size [B, 1024] 243 | t: affordance texts, size [B, 512] 244 | context_mask: masks {0, 1} for the contexts, size [B, 1] 245 | _t is for the timesteps, size [B,] 246 | """ 247 | c = c * context_mask 248 | c0 = self.cloud_net0(c) 249 | c1 = self.cloud_net1(c0) 250 | c2 = self.cloud_net2(c0) 251 | c3 = self.cloud_net3(c0) 252 | 253 | t = t * context_mask 254 | t0 = self.text_net0(t) 255 | t1 = self.text_net1(t0) 256 | t2 = self.text_net2(t0) 257 | t3 = self.text_net3(t0) 258 | 259 | _t0 = _t.unsqueeze(1) 260 | _t1 = self.time_net1(_t0) 261 | _t2 = self.time_net2(_t0) 262 | _t3 = self.time_net3(_t0) 263 | 264 | g = g.float() 265 | g_down1 = self.down1(g) # 6 266 | g_down2 = self.down2(g_down1) # 4 267 | g_down3 = self.down3(g_down2) # 2 268 | 269 | c1_influence = self.cloud_influence_net1(torch.cat((c1, g, _t1), dim=1)) 270 | t1_influence = self.text_influence_net1(torch.cat((t1, g, _t1), dim=1)) 271 | influences1 = F.softmax(torch.cat((c1_influence.unsqueeze(1), t1_influence.unsqueeze(1)), dim=1), dim=1) 272 | ct1 = (c1 * influences1[:, 0, :] + t1 * influences1[:, 1, :]) 273 | up1 = self.up1(torch.cat((g_down3 * ct1 + _t1, g_down2), dim=1)) 274 | 275 | c2_influence = self.cloud_influence_net2(torch.cat((c2, g, _t2), dim=1)) 276 | t2_influence = self.text_influence_net2(torch.cat((t2, g, _t2), dim=1)) 277 | influences2 = F.softmax(torch.cat((c2_influence.unsqueeze(1), t2_influence.unsqueeze(1)), dim=1), dim=1) 278 | ct2 = (c2 * influences2[:, 0, :] + t2 * influences2[:, 1, :]) 279 | up2 = self.up2(torch.cat((up1 * ct2 + _t2, g_down1), dim=1)) 280 | 281 | c3_influence = self.cloud_influence_net3(torch.cat((c3, g, _t3), dim=1)) 282 | t3_influence = self.text_influence_net3(torch.cat((t3, g, _t3), dim=1)) 283 | influences3 = F.softmax(torch.cat((c3_influence.unsqueeze(1), t3_influence.unsqueeze(1)), dim=1), dim=1) 284 | ct3 = (c3 * influences3[:, 0, :] + t3 * influences3[:, 1, :]) 285 | up3 = self.up3(torch.cat((up2 * ct3 + _t3, g), dim=1)) # size [B, 7] 286 | 287 | return up3 -------------------------------------------------------------------------------- /models/main_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | import numpy as np 5 | from .components import TextEncoder, PointNetPlusPlus, PoseNet 6 | 7 | 8 | text_encoder = TextEncoder(device=torch.device('cuda')) 9 | 10 | 11 | # Linear noise scheduler 12 | def linear_diffusion_schedule(betas, T): 13 | """_summary_ 14 | Linear cheduling for sampling in training. 15 | """ 16 | beta_t = (betas[1] - betas[0]) * torch.arange(0, T + 1, dtype=torch.float32) / T + betas[0] 17 | sqrt_beta_t = torch.sqrt(beta_t) 18 | alpha_t = 1 - beta_t 19 | log_alpha_t = torch.log(alpha_t) 20 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() 21 | 22 | sqrtab = torch.sqrt(alphabar_t) 23 | oneover_sqrta = 1 / torch.sqrt(alpha_t) 24 | 25 | sqrtmab = torch.sqrt(1 - alphabar_t) 26 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab 27 | 28 | return { 29 | "alpha_t": alpha_t, # \alpha_t 30 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t} 31 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t} 32 | "alphabar_t": alphabar_t, # \bar{\alpha_t} 33 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} 34 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} 35 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}} 36 | } 37 | 38 | 39 | # Main network for affordance detection and pose generation 40 | class DetectionDiffusion(nn.Module): 41 | def __init__(self, betas, n_T, device, background_text, drop_prob=0.1): 42 | """_summary_ 43 | 44 | Args: 45 | drop_prob: probability to drop the conditions 46 | """ 47 | super(DetectionDiffusion, self).__init__() 48 | self.posenet = PoseNet() 49 | self.pointnetplusplus = PointNetPlusPlus() 50 | 51 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 52 | 53 | # Register_buffer allows accessing dictionary, e.g. can access self.sqrtab later 54 | for k, v in linear_diffusion_schedule(betas, n_T).items(): 55 | self.register_buffer(k, v) 56 | 57 | self.n_T = n_T 58 | self.device = device 59 | self.background_text = background_text 60 | self.drop_prob = drop_prob 61 | self.loss_mse = nn.MSELoss() 62 | 63 | def forward(self, xyz, text, affordance_label, g): 64 | """_summary_ 65 | This method is used in training, so samples _ts and noise randomly. 
66 | """ 67 | B = xyz.shape[0] # xyz's size [B, 3, 2048] 68 | point_features, c = self.pointnetplusplus(xyz) # point_features' size [B, 512, 2048], c'size [B, 1024] 69 | with torch.no_grad(): 70 | foreground_text_features = text_encoder(text) # size [B, 512] 71 | background_text_features = text_encoder([self.background_text] * B) 72 | text_features = torch.cat((background_text_features.unsqueeze(1), \ 73 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512] 74 | 75 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \ 76 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \ 77 | torch.norm(point_features, dim=1, keepdim=True))) # size [B, 2, 2048] 78 | 79 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1) 80 | affordance_loss = F.nll_loss(affordance_prediction, affordance_label) 81 | 82 | _ts = torch.randint(1, self.n_T + 1, (B,)).to(self.device) 83 | noise = torch.randn_like(g) # eps ~ N(0, 1), g size [B, 7] 84 | g_t = ( 85 | self.sqrtab[_ts - 1, None] * g 86 | + self.sqrtmab[_ts - 1, None] * noise 87 | ) # This is the g_t, which is sqrt(alphabar) g_0 + sqrt(1-alphabar) * eps 88 | 89 | # dropout context with some probability 90 | context_mask = torch.bernoulli(torch.zeros(B, 1) + 1 - self.drop_prob).to(self.device) 91 | 92 | # Loss for poseing is MSE between added noise, and our predicted noise 93 | pose_loss = self.loss_mse(noise, self.posenet(g_t, c, foreground_text_features, context_mask, _ts / self.n_T)) 94 | return affordance_loss, pose_loss 95 | 96 | def detect_and_sample(self, xyz, text, n_sample, guide_w): 97 | """_summary_ 98 | Detect affordance for one point cloud and sample [n_sample] poses that support the 'text' affordance task, 99 | following the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'. 100 | """ 101 | g_i = torch.randn(n_sample, (7)).to(self.device) # start by sampling from Gaussian noise 102 | point_features, c = self.pointnetplusplus(xyz) # point_features size [1, 512, 2048], c size [1, 1024] 103 | foreground_text_features = text_encoder(text) # size [1, 512] 104 | background_text_features = text_encoder([self.background_text] * 1) 105 | text_features = torch.cat((background_text_features.unsqueeze(1), \ 106 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512] 107 | 108 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \ 109 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \ 110 | torch.norm(point_features, dim=1, keepdim=True))) # size [1, 2, 2048] 111 | 112 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1) # .cpu().numpy() 113 | c_i = c.repeat(n_sample, 1) 114 | t_i = foreground_text_features.repeat(n_sample, 1) 115 | context_mask = torch.ones((n_sample, 1)).float().to(self.device) 116 | 117 | # Double the batch 118 | c_i = c_i.repeat(2, 1) 119 | t_i = t_i.repeat(2, 1) 120 | context_mask = context_mask.repeat(2, 1) 121 | context_mask[n_sample:] = 0. 
# make second half of the back context-free 122 | 123 | for i in range(self.n_T, 0, -1): 124 | _t_is = torch.tensor([i / self.n_T]).repeat(n_sample).repeat(2).to(self.device) 125 | g_i = g_i.repeat(2, 1) 126 | 127 | z = torch.randn(n_sample, (7)) if i > 1 else torch.zeros((n_sample, 7)) 128 | z = z.to(self.device) 129 | eps = self.posenet(g_i, c_i, t_i, context_mask, _t_is) 130 | eps1 = eps[:n_sample] 131 | eps2 = eps[n_sample:] 132 | eps = (1 + guide_w) * eps1 - guide_w * eps2 133 | 134 | g_i = g_i[:n_sample] 135 | g_i = self.oneover_sqrta[i] * (g_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z 136 | return np.argmax(affordance_prediction.cpu().numpy(), axis=1), g_i.cpu().numpy() -------------------------------------------------------------------------------- /models/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | def timeit(tag, t): 9 | print("{}: {}s".format(tag, time() - t)) 10 | return time() 11 | 12 | 13 | def pc_normalize(pc): 14 | l = pc.shape[0] 15 | centroid = np.mean(pc, axis=0) 16 | pc = pc - centroid 17 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 18 | pc = pc / m 19 | return pc 20 | 21 | 22 | def square_distance(src, dst): 23 | """_summary_ 24 | Calculate Euclid distance between each two points. 25 | 26 | src^T * dst = xn * xm + yn * ym + zn * zm; 27 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 28 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 29 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 30 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 31 | 32 | Input: 33 | src: source points, [B, N, C] 34 | dst: target points, [B, M, C] 35 | Output: 36 | dist: per-point square distance, [B, N, M] 37 | """ 38 | B, N, _ = src.shape 39 | _, M, _ = dst.shape 40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 41 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 43 | return dist 44 | 45 | 46 | def index_points(points, idx): 47 | """_summary_ 48 | Input: 49 | points: input points data, [B, N, C] 50 | idx: sample index data, [B, S] 51 | Return: 52 | new_points:, indexed points data, [B, S, C] 53 | """ 54 | device = points.device 55 | B = points.shape[0] 56 | view_shape = list(idx.shape) 57 | view_shape[1:] = [1] * (len(view_shape) - 1) 58 | repeat_shape = list(idx.shape) 59 | repeat_shape[0] = 1 60 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 61 | new_points = points[batch_indices, idx, :] 62 | return new_points 63 | 64 | 65 | def farthest_point_sample(xyz, npoint): 66 | """_summary_ 67 | Input: 68 | xyz: pointcloud data, [B, N, 3] 69 | npoint: number of samples 70 | Return: 71 | centroids: sampled pointcloud index, [B, npoint] 72 | """ 73 | device = xyz.device 74 | B, N, C = xyz.shape 75 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 76 | distance = torch.ones(B, N).to(device) * 1e10 77 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 78 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 79 | for i in range(npoint): 80 | centroids[:, i] = farthest 81 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 82 | dist = torch.sum((xyz - centroid) ** 2, -1) 83 | mask = dist < distance 84 | distance[mask] = dist[mask] 85 | farthest = torch.max(distance, -1)[1] 86 | return centroids 87 | 88 | 89 | def query_ball_point(radius, 
nsample, xyz, new_xyz): 90 | """_summary_ 91 | Input: 92 | radius: local region radius 93 | nsample: max sample number in local region 94 | xyz: all points, [B, N, 3] 95 | new_xyz: query points, [B, S, 3] 96 | Return: 97 | group_idx: grouped points index, [B, S, nsample] 98 | """ 99 | device = xyz.device 100 | B, N, C = xyz.shape 101 | _, S, _ = new_xyz.shape 102 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 103 | sqrdists = square_distance(new_xyz, xyz) 104 | group_idx[sqrdists > radius ** 2] = N 105 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 106 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 107 | mask = group_idx == N 108 | group_idx[mask] = group_first[mask] 109 | return group_idx 110 | 111 | 112 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 113 | """_summary_ 114 | Input: 115 | npoint: 116 | radius: 117 | nsample: 118 | xyz: input points position data, [B, N, 3] 119 | points: input points data, [B, N, D] 120 | Return: 121 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 122 | new_points: sampled points data, [B, npoint, nsample, 3+D] 123 | """ 124 | B, N, C = xyz.shape 125 | S = npoint 126 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 127 | new_xyz = index_points(xyz, fps_idx) 128 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 129 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 130 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 131 | 132 | if points is not None: 133 | grouped_points = index_points(points, idx) 134 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 135 | else: 136 | new_points = grouped_xyz_norm 137 | if returnfps: 138 | return new_xyz, new_points, grouped_xyz, fps_idx 139 | else: 140 | return new_xyz, new_points 141 | 142 | 143 | def sample_and_group_all(xyz, points): 144 | """_summary_ 145 | Input: 146 | xyz: input points position data, [B, N, 3] 147 | points: input points data, [B, N, D] 148 | Return: 149 | new_xyz: sampled points position data, [B, 1, 3] 150 | new_points: sampled points data, [B, 1, N, 3+D] 151 | """ 152 | device = xyz.device 153 | B, N, C = xyz.shape 154 | new_xyz = torch.zeros(B, 1, C).to(device) 155 | grouped_xyz = xyz.view(B, 1, N, C) 156 | if points is not None: 157 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 158 | else: 159 | new_points = grouped_xyz 160 | return new_xyz, new_points 161 | 162 | 163 | class PointNetSetAbstraction(nn.Module): 164 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 165 | super(PointNetSetAbstraction, self).__init__() 166 | self.npoint = npoint 167 | self.radius = radius 168 | self.nsample = nsample 169 | self.mlp_convs = nn.ModuleList() 170 | self.mlp_bns = nn.ModuleList() 171 | last_channel = in_channel 172 | for out_channel in mlp: 173 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 174 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 175 | last_channel = out_channel 176 | self.group_all = group_all 177 | 178 | def forward(self, xyz, points): 179 | """_summary_ 180 | Input: 181 | xyz: input points position data, [B, C, N] 182 | points: input points data, [B, D, N] 183 | Return: 184 | new_xyz: sampled points position data, [B, C, S] 185 | new_points_concat: sample points feature data, [B, D', S] 186 | """ 187 | xyz = xyz.permute(0, 2, 1) 188 | if points is not None: 189 | points = points.permute(0, 2, 1) 190 
| 191 | if self.group_all: 192 | new_xyz, new_points = sample_and_group_all(xyz, points) 193 | else: 194 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 195 | # new_xyz: sampled points position data, [B, npoint, C] 196 | # new_points: sampled points data, [B, npoint, nsample, C+D] 197 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 198 | for i, conv in enumerate(self.mlp_convs): 199 | bn = self.mlp_bns[i] 200 | new_points = F.relu(bn(conv(new_points))) 201 | 202 | new_points = torch.max(new_points, 2)[0] 203 | new_xyz = new_xyz.permute(0, 2, 1) 204 | return new_xyz, new_points 205 | 206 | 207 | class PointNetSetAbstractionMsg(nn.Module): 208 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 209 | super(PointNetSetAbstractionMsg, self).__init__() 210 | self.npoint = npoint 211 | self.radius_list = radius_list 212 | self.nsample_list = nsample_list 213 | self.conv_blocks = nn.ModuleList() 214 | self.bn_blocks = nn.ModuleList() 215 | for i in range(len(mlp_list)): 216 | convs = nn.ModuleList() 217 | bns = nn.ModuleList() 218 | last_channel = in_channel + 3 219 | for out_channel in mlp_list[i]: 220 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 221 | bns.append(nn.BatchNorm2d(out_channel)) 222 | last_channel = out_channel 223 | self.conv_blocks.append(convs) 224 | self.bn_blocks.append(bns) 225 | 226 | def forward(self, xyz, points): 227 | """_summary_ 228 | Input: 229 | xyz: input points position data, [B, C, N] 230 | points: input points data, [B, D, N] 231 | Return: 232 | new_xyz: sampled points position data, [B, C, S] 233 | new_points_concat: sample points feature data, [B, D', S] 234 | """ 235 | xyz = xyz.permute(0, 2, 1) 236 | if points is not None: 237 | points = points.permute(0, 2, 1) 238 | 239 | B, N, C = xyz.shape 240 | S = self.npoint 241 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 242 | new_points_list = [] 243 | for i, radius in enumerate(self.radius_list): 244 | K = self.nsample_list[i] 245 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 246 | grouped_xyz = index_points(xyz, group_idx) 247 | grouped_xyz -= new_xyz.view(B, S, 1, C) 248 | if points is not None: 249 | grouped_points = index_points(points, group_idx) 250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 251 | else: 252 | grouped_points = grouped_xyz 253 | 254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 255 | for j in range(len(self.conv_blocks[i])): 256 | conv = self.conv_blocks[i][j] 257 | bn = self.bn_blocks[i][j] 258 | grouped_points = F.relu(bn(conv(grouped_points))) 259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 260 | new_points_list.append(new_points) 261 | 262 | new_xyz = new_xyz.permute(0, 2, 1) 263 | new_points_concat = torch.cat(new_points_list, dim=1) 264 | return new_xyz, new_points_concat 265 | 266 | 267 | class PointNetFeaturePropagation(nn.Module): 268 | def __init__(self, in_channel, mlp): 269 | super(PointNetFeaturePropagation, self).__init__() 270 | self.mlp_convs = nn.ModuleList() 271 | self.mlp_bns = nn.ModuleList() 272 | last_channel = in_channel 273 | for out_channel in mlp: 274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 276 | last_channel = out_channel 277 | 278 | def forward(self, xyz1, xyz2, points1, points2): 279 | """_summary_ 280 | Input: 281 | xyz1: input points position data, [B, C, N] 282 | xyz2: sampled input points 
position data, [B, C, S] 283 | points1: input points data, [B, D, N] 284 | points2: input points data, [B, D, S] 285 | Return: 286 | new_points: upsampled points data, [B, D', N] 287 | """ 288 | xyz1 = xyz1.permute(0, 2, 1) 289 | xyz2 = xyz2.permute(0, 2, 1) 290 | 291 | points2 = points2.permute(0, 2, 1) 292 | B, N, C = xyz1.shape 293 | _, S, _ = xyz2.shape 294 | 295 | if S == 1: 296 | interpolated_points = points2.repeat(1, N, 1) 297 | else: 298 | dists = square_distance(xyz1, xyz2) 299 | dists, idx = dists.sort(dim=-1) 300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 301 | 302 | dist_recip = 1.0 / (dists + 1e-8) 303 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 304 | weight = dist_recip / norm 305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 306 | 307 | if points1 is not None: 308 | points1 = points1.permute(0, 2, 1) 309 | new_points = torch.cat([points1, interpolated_points], dim=-1) 310 | else: 311 | new_points = interpolated_points 312 | 313 | new_points = new_points.permute(0, 2, 1) 314 | for i, conv in enumerate(self.mlp_convs): 315 | bn = self.mlp_bns[i] 316 | new_points = F.relu(bn(conv(new_points))) 317 | return new_points -------------------------------------------------------------------------------- /models/weights_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def weights_init(m): 4 | """_summary_ 5 | Weights initialization 6 | """ 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv2d') != -1: 9 | torch.nn.init.xavier_normal_(m.weight.data) 10 | if m.state_dict().get('bias') != None: 11 | torch.nn.init.constant_(m.bias.data, 0.0) 12 | elif classname.find('Linear') != -1: 13 | torch.nn.init.xavier_normal_(m.weight.data) 14 | if m.state_dict().get('bias') != None: 15 | torch.nn.init.constant_(m.bias.data, 0.0) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | h5py 4 | scikit_learn==1.3.0 5 | gorilla-core==0.2.7.8 6 | torch==2.0.1 7 | scipy==1.11.1 8 | trimesh==4.0.7 9 | open_clip_torch -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.eval import affordance_eval, pose_eval 2 | import argparse 3 | import pickle 4 | 5 | 6 | AFFORDANCE_LIST = ['grasp to pour', 'grasp to stab', 'stab', 'pourable', 'lift', 'wrap_grasp', 'listen', 'contain', 'displaY', 'grasp to cut', 'cut', 'wear', 'openable', 'grasp'] 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Test a model") 11 | parser.add_argument("--result", help="result file") 12 | args = parser.parse_args() 13 | return args 14 | 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | with open(args.result, 'rb') as f: 19 | result = pickle.load(f) 20 | mIoU, Acc, mAcc = affordance_eval(AFFORDANCE_LIST, result) 21 | print(f'mIoU: {mIoU}, Acc: {Acc}, mAcc: {mAcc}') 22 | 23 | mESM, mCR = pose_eval(result) 24 | print(f'mESM: {mESM}, mCR: {mCR}') 25 | 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as opj 3 | from gorilla.config import Config 4 | from utils import * 5 | import argparse 6 | import 
torch 7 | 8 | 9 | # Argument Parser 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Train a model") 12 | parser.add_argument("--config", help="train config file path") 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | cfg = Config.fromfile(args.config) 20 | 21 | logger = IOStream(opj(cfg.log_dir, 'run.log')) 22 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu 23 | num_gpu = len(cfg.training_cfg.gpu.split(',')) # number of GPUs to use 24 | logger.cprint('Use %d GPUs: %s' % (num_gpu, cfg.training_cfg.gpu)) 25 | if cfg.get('seed') != None: # set random seed 26 | set_random_seed(cfg.seed) 27 | logger.cprint('Set seed to %d' % cfg.seed) 28 | model = build_model(cfg).cuda() # build the model from configuration 29 | 30 | print("Training from scratch!") 31 | 32 | dataset_dict = build_dataset(cfg) # build the dataset 33 | loader_dict = build_loader(cfg, dataset_dict) # build the loader 34 | optim_dict = build_optimizer(cfg, model) # build the optimizer 35 | 36 | # construct the training process 37 | training = dict( 38 | model=model, 39 | dataset_dict=dataset_dict, 40 | loader_dict=loader_dict, 41 | optim_dict=optim_dict, 42 | logger=logger 43 | ) 44 | 45 | task_trainer = Trainer(cfg, training) 46 | task_trainer.run() 47 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_optimizer, build_dataset, build_loader, build_model 2 | from .trainer import Trainer 3 | from .utils import set_random_seed, IOStream, PN2_BNMomentum, PN2_Scheduler 4 | 5 | __all__ = ['build_optimizer', 'build_dataset', 'build_loader', 'build_model', 6 | 'Trainer', 'set_random_seed', 'IOStream', 'PN2_BNMomentum', 'PN2_Scheduler'] 7 | -------------------------------------------------------------------------------- /utils/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LambdaLR, MultiStepLR 3 | from dataset import * 4 | from models import * 5 | from torch.utils.data import DataLoader 6 | from torch.optim import SGD, Adam 7 | 8 | # Pools of models, optimizers, weights initialization methods, schedulers 9 | model_pool = { 10 | 'detectiondiffusion': DetectionDiffusion, 11 | } 12 | 13 | optimizer_pool = { 14 | 'sgd': SGD, 15 | 'adam': Adam 16 | } 17 | 18 | init_pool = { 19 | 'default_init': weights_init 20 | } 21 | 22 | scheduler_pool = { 23 | 'step': StepLR, 24 | 'cos': CosineAnnealingLR, 25 | 'lr_lambda': LambdaLR, 26 | 'multi_step': MultiStepLR 27 | } 28 | 29 | 30 | def build_model(cfg): 31 | """_summary_ 32 | Function to build the model before training 33 | """ 34 | if hasattr(cfg, 'model'): 35 | model_info = cfg.model 36 | weights_init = model_info.get('weights_init', None) 37 | background_text = model_info.get('background_text', 'none') 38 | device = model_info.get('device', torch.device('cuda')) 39 | model_name = model_info.type 40 | model_cls = model_pool[model_name] 41 | if model_name in ['detectiondiffusion']: 42 | betas = model_info.get('betas', [1e-4, 0.02]) 43 | n_T = model_info.get('n_T', 1000) 44 | drop_prob = model_info.get('drop_prob', 0.1) 45 | model = model_cls(betas, n_T, device, background_text, drop_prob) 46 | else: 47 | raise ValueError("The model name does not exist!") 48 | if weights_init != None: 49 | init_fn = 
init_pool[weights_init] 50 | model.apply(init_fn) 51 | return model 52 | else: 53 | raise ValueError("Configuration does not have model config!") 54 | 55 | 56 | def build_dataset(cfg): 57 | """_summary_ 58 | Function to build the dataset 59 | """ 60 | if hasattr(cfg, 'data'): 61 | data_info = cfg.data 62 | data_path = data_info.data_path 63 | train_set = ThreeDAPDataset(data_path, mode='train') 64 | val_set = ThreeDAPDataset(data_path, mode='val') 65 | test_set = ThreeDAPDataset(data_path, mode='test') 66 | dataset_dict = dict( 67 | train_set=train_set, 68 | val_set=val_set, 69 | test_set=test_set 70 | ) 71 | return dataset_dict 72 | else: 73 | raise ValueError("Configuration does not have data config!") 74 | 75 | 76 | def build_loader(cfg, dataset_dict): 77 | """_summary_ 78 | Function to build the loader 79 | """ 80 | train_set = dataset_dict["train_set"] 81 | train_loader = DataLoader(train_set, batch_size=cfg.training_cfg.batch_size, 82 | shuffle=True, drop_last=False, num_workers=8) 83 | loader_dict = dict( 84 | train_loader=train_loader, 85 | ) 86 | 87 | return loader_dict 88 | 89 | 90 | def build_optimizer(cfg, model): 91 | """_summary_ 92 | Function to build the optimizer 93 | """ 94 | optimizer_info = cfg.optimizer 95 | optimizer_type = optimizer_info.type 96 | optimizer_info.pop('type') 97 | optimizer_cls = optimizer_pool[optimizer_type] 98 | optimizer = optimizer_cls(model.parameters(), **optimizer_info) 99 | scheduler_info = cfg.scheduler 100 | if scheduler_info: 101 | scheduler_name = scheduler_info.type 102 | scheduler_info.pop('type') 103 | scheduler_cls = scheduler_pool[scheduler_name] 104 | scheduler = scheduler_cls(optimizer, **scheduler_info) 105 | else: 106 | scheduler = None 107 | optim_dict = dict( 108 | scheduler=scheduler, 109 | optimizer=optimizer 110 | ) 111 | return optim_dict 112 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | from scipy.spatial.transform import Rotation as R 4 | 5 | 6 | def affordance_eval(affordance_list, result): 7 | """_summary_ 8 | This fuction evaluates the affordance detection capability. 9 | `result` is loaded from result.pkl file produced by detect.py. 10 | """ 11 | num_correct = 0 12 | num_all = 0 13 | num_points = {aff: 0 for aff in affordance_list} 14 | num_label_points = {aff: 0 for aff in affordance_list} 15 | num_correct_fg_points = {aff: 0 for aff in affordance_list} 16 | num_correct_bg_points = {aff: 0 for aff in affordance_list} 17 | num_union_points = {aff: 0 for aff in affordance_list} 18 | num_appearances = {aff: 0 for aff in affordance_list} 19 | 20 | for shape in result: 21 | for affordance in shape['affordance']: 22 | label = np.transpose(shape['full_shape']['label'][affordance]) 23 | prediction = shape['result'][affordance][0] 24 | 25 | num_correct += np.sum(label == prediction) 26 | num_all += 2048 27 | num_points[affordance] += 2048 28 | num_label_points[affordance] += np.sum(label == 1.) 29 | num_correct_fg_points[affordance] += np.sum((label == 1.) & (prediction == 1.)) 30 | num_correct_bg_points[affordance] += np.sum((label == 0.) & (prediction == 0.)) 31 | num_union_points[affordance] += np.sum((label == 1.) 
| (prediction == 1.)) 32 | mIoU = np.average(np.array(list(num_correct_fg_points.values())) / np.array(list(num_union_points.values())), 33 | weights=np.array(list(num_points.values())))  # num_points is proportional to the number of appearances of each affordance (num_appearances was never updated, which made the weights all zero) 34 | Acc = num_correct / num_all 35 | mAcc = np.mean((np.array(list(num_correct_fg_points.values())) + np.array(list(num_correct_bg_points.values()))) / \ 36 | np.array(list(num_points.values()))) 37 | 38 | return mIoU, Acc, mAcc 39 | 40 | 41 | def pose_eval(result): 42 | """_summary_ 43 | This function evaluates the pose detection capability. 44 | `result` is loaded from the result.pkl file produced by detect.py. 45 | """ 46 | all_min_dist = [] 47 | all_rate = [] 48 | for object in result: 49 | for affordance in object['affordance']: 50 | gt_poses = np.array([np.concatenate((R.from_matrix(p[:3, :3]).as_quat(), p[:3, 3]), axis=0) for p in object['pose'][affordance]]) 51 | distances = cdist(gt_poses, object['result'][affordance][1]) 52 | rate = np.sum(np.any(distances <= 0.2, axis=1)) / len(object['pose'][affordance]) 53 | all_rate.append(rate) 54 | 55 | g = gt_poses[:, np.newaxis, :] 56 | g_pred = object['result'][affordance][1] 57 | l2_distances = np.sqrt(np.sum((g-g_pred)**2, axis=2)) 58 | min_distance = np.min(l2_distances) 59 | 60 | # discard cases where the sets of ground-truth and detected poses are too far from each other, to get a stable result 61 | if min_distance <= 1.0: 62 | all_min_dist.append(min_distance) 63 | return (np.mean(np.array(all_min_dist)), np.mean(np.array(all_rate))) -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from os.path import join as opj 4 | from utils import * 5 | 6 | 7 | DEVICE = torch.device('cuda') 8 | 9 | 10 | class Trainer(object): 11 | def __init__(self, cfg, running): 12 | super().__init__() 13 | self.cfg = cfg 14 | self.logger = running['logger'] 15 | self.model = running["model"] 16 | self.dataset_dict = running["dataset_dict"] 17 | self.loader_dict = running["loader_dict"] 18 | self.train_loader = self.loader_dict.get("train_loader", None) 19 | self.optimizer_dict = running["optim_dict"] 20 | self.optimizer = self.optimizer_dict.get("optimizer", None) 21 | self.scheduler = self.optimizer_dict.get("scheduler", None) 22 | self.epoch = 0 23 | self.bn_momentum = self.cfg.training_cfg.get('bn_momentum', None) 24 | 25 | def train(self): 26 | self.model.train() 27 | self.logger.cprint("Epoch(%d) begin training........" 
% self.epoch) 28 | pbar = tqdm(self.train_loader) 29 | for _, _, xyz, text, affordance_label, rotation, translation in pbar: 30 | self.optimizer.zero_grad() 31 | xyz = xyz.float() 32 | rotation = rotation.float() 33 | translation = translation.float() 34 | affordance_label = affordance_label.squeeze().long() 35 | 36 | g = torch.cat((rotation, translation), dim=1) 37 | xyz = xyz.to(DEVICE) 38 | affordance_label = affordance_label.to(DEVICE) 39 | g = g.to(DEVICE) 40 | 41 | affordance_loss, pose_loss = self.model(xyz, text, affordance_label, g) 42 | loss = affordance_loss + pose_loss 43 | loss.backward() 44 | 45 | affordance_l = affordance_loss.item() 46 | pose_l = pose_loss.item() 47 | pbar.set_description(f'Affordance loss: {affordance_l:.5f}, Pose loss: {pose_l:.5f}') 48 | self.optimizer.step() 49 | 50 | if self.scheduler != None: 51 | self.scheduler.step() 52 | if self.bn_momentum != None: 53 | self.model.apply(lambda x: self.bn_momentum(x, self.epoch)) 54 | 55 | outstr = f"\nEpoch {self.epoch}, Last Affordance loss: {affordance_l:.5f}, Last Pose loss: {pose_l:.5f}" 56 | self.logger.cprint(outstr) 57 | print('Saving checkpoint') 58 | torch.save(self.model.state_dict(), opj(self.cfg.log_dir, 'current_model.t7')) 59 | self.epoch += 1 60 | 61 | def val(self): 62 | raise NotImplementedError 63 | 64 | def test(self): 65 | raise NotImplementedError 66 | 67 | def run(self): 68 | EPOCH = self.cfg.training_cfg.epoch 69 | workflow = self.cfg.training_cfg.workflow 70 | 71 | while self.epoch < EPOCH: 72 | for key, running_epoch in workflow.items(): 73 | epoch_runner = getattr(self, key) 74 | for _ in range(running_epoch): 75 | epoch_runner() 76 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | 6 | 7 | class IOStream(): 8 | def __init__(self, path): 9 | self.f = open(path, 'a') 10 | 11 | def cprint(self, text): 12 | print(text) 13 | self.f.write(text+'\n') 14 | self.f.flush() 15 | 16 | def close(self): 17 | self.f.close() 18 | 19 | 20 | class PN2_Scheduler(object): 21 | def __init__(self, init_lr, step, decay_rate, min_lr): 22 | super().__init__() 23 | self.init_lr = init_lr 24 | self.step = step 25 | self.decay_rate = decay_rate 26 | self.min_lr = min_lr 27 | return 28 | 29 | def __call__(self, epoch): 30 | factor = self.decay_rate**(epoch//self.step) 31 | if self.init_lr*factor < self.min_lr: 32 | factor = self.min_lr / self.init_lr 33 | return factor 34 | 35 | 36 | class PN2_BNMomentum(object): 37 | def __init__(self, origin_m, m_decay, step): 38 | super().__init__() 39 | self.origin_m = origin_m 40 | self.m_decay = m_decay 41 | self.step = step 42 | return 43 | 44 | def __call__(self, m, epoch): 45 | momentum = self.origin_m * (self.m_decay**(epoch//self.step)) 46 | if momentum < 0.01: 47 | momentum = 0.01 48 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 49 | m.momentum = momentum 50 | return 51 | 52 | 53 | def set_random_seed(seed): 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | 3 | 4 | def create_gripper_marker(color=[0, 255, 0], tube_radius=0.002, sections=6): 5 | 
"""Create a 3D mesh visualizing a parallel yaw gripper. It consists of four cylinders. 6 | 7 | Args: 8 | color (list, optional): RGB values of marker. Defaults to [0, 0, 255]. 9 | tube_radius (float, optional): Radius of cylinders. Defaults to 0.001. 10 | sections (int, optional): Number of sections of each cylinder. Defaults to 6. 11 | 12 | Returns: 13 | trimesh.Trimesh: A mesh that represents a simple parallel yaw gripper. 14 | """ 15 | cfl = trimesh.creation.cylinder( 16 | radius=tube_radius, 17 | sections=sections, 18 | segment=[ 19 | [4.10000000e-02, -7.27595772e-12, 6.59999996e-02], 20 | [4.10000000e-02, -7.27595772e-12, 1.12169998e-01], 21 | ], 22 | ) 23 | cfr = trimesh.creation.cylinder( 24 | radius=tube_radius, 25 | sections=sections, 26 | segment=[ 27 | [-4.100000e-02, -7.27595772e-12, 6.59999996e-02], 28 | [-4.100000e-02, -7.27595772e-12, 1.12169998e-01], 29 | ], 30 | ) 31 | cb1 = trimesh.creation.cylinder( 32 | radius=tube_radius, sections=sections, segment=[[0, 0, 0], [0, 0, 6.59999996e-02]] 33 | ) 34 | cb2 = trimesh.creation.cylinder( 35 | radius=tube_radius, 36 | sections=sections, 37 | segment=[[-4.100000e-02, 0, 6.59999996e-02], [4.100000e-02, 0, 6.59999996e-02]], 38 | ) 39 | 40 | tmp = trimesh.util.concatenate([cb1, cb2, cfr, cfl]) 41 | tmp.visual.face_colors = color 42 | 43 | return tmp -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | import pickle 4 | from scipy.spatial.transform import Rotation as R 5 | import argparse 6 | from utils.visualization import create_gripper_marker 7 | 8 | color_code_1 = np.array([0, 0, 255]) # color code for affordance region 9 | color_code_2 = np.array([0, 255, 0]) # color code for gripper pose 10 | num_pose = 100 # number of poses to visualize per each object-affordance pair 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Visualize") 15 | parser.add_argument("--result", help="result file") 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | if __name__ == "__main__": 21 | args = parse_args() 22 | result_file = args.result_file 23 | with open(result_file, 'rb') as f: 24 | result = pickle.load(f) 25 | 26 | for i in range(len(result)): 27 | if result[i]['semantic class'] == 'Bottle': 28 | shape_index = i 29 | shape = result[shape_index] 30 | 31 | for affordance in shape['affordance']: 32 | colors = np.transpose(shape['result'][affordance][0]) * color_code_1 33 | point_cloud = trimesh.points.PointCloud(shape['full_shape']['coordinate'], colors=colors) 34 | print(f"Affordance: {affordance}") 35 | T = shape['result'][affordance][1][:num_pose] 36 | rotation = np.concatenate((R.from_quat(T[:, :4]).as_matrix(), np.zeros((num_pose, 1, 3), dtype=np.float32)), axis=1) 37 | translation = np.expand_dims(np.concatenate((T[:, 4:], np.ones((num_pose, 1), dtype=np.float32)), axis=1), axis=2) 38 | T = np.concatenate((rotation, translation), axis=2) 39 | poses = [create_gripper_marker(color=color_code_2).apply_transform(t) for t in T 40 | if np.min(np.linalg.norm(point_cloud - (t @ np.array([0., 0., 6.59999996e-02, 1.]))[:3], axis=1)) <= 0.03] # this line is used to get reliable poses only 41 | scene = trimesh.Scene([point_cloud, poses]) 42 | scene.show(line_settings={'point size': 10}) --------------------------------------------------------------------------------