├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
├── intro.png
├── method.png
└── visualization.png
├── config
└── detectiondiffusion.py
├── dataset
├── ThreeDAPDataset.py
└── __init__.py
├── detect.py
├── models
├── __init__.py
├── components.py
├── main_nets.py
├── pointnet_util.py
└── weights_init.py
├── requirements.txt
├── test.py
├── train.py
├── utils
├── __init__.py
├── builder.py
├── eval.py
├── trainer.py
├── utils.py
└── visualization.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | log/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pytorchse3"]
2 | path = pytorchse3
3 | url = https://github.com/eigenvivek/pytorchse3
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Toan Nguyen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Language-Conditioned Affordance-Pose Detection in 3D Point Clouds
4 |
5 | [ICRA 2024](https://2024.ieee-icra.org/)
6 | [arXiv:2309.10911](https://arxiv.org/abs/2309.10911)
7 |
8 | Official code for the ICRA 2024 paper "Language-Conditioned Affordance-Pose Detection in 3D Point Clouds".
9 |
10 | ![intro](./assets/intro.png)
11 |
12 | We address the task of language-driven affordance-pose detection in 3D point clouds. Our method simultaneously detects open-vocabulary affordances and generates affordance-specific 6-DoF poses.
13 |
14 | ![method](./assets/method.png)
15 |
16 | We present 3DAPNet, a new method for affordance-pose joint learning. Given the captured 3D point cloud of an object and a set of affordance labels conveyed through natural language texts, our objective is to jointly produce both the relevant affordance regions and the appropriate pose configurations that facilitate the affordances.
17 |
18 |
19 |
20 |
21 | ## 1. Getting Started
22 | We strongly encourage you to create a separate conda environment.
23 |
24 | conda create -n affpose python=3.8
25 | conda activate affpose
26 | conda install pip
27 | pip install -r requirements.txt
28 |
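Note that ```.gitmodules``` declares the [pytorchse3](https://github.com/eigenvivek/pytorchse3) submodule; if you did not clone with ```--recurse-submodules```, you can likely fetch it with ```git submodule update --init```.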
29 | ## 2. Dataset
30 | Our 3DAP dataset is available at [this drive folder](https://drive.google.com/drive/folders/1vDGHs3QZmmF2rGluGlqBIyCp8sPR4Yws?usp=sharing).
31 |
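As a quick sanity check, the downloaded pickle file can be inspected with a few lines of Python. This is only a minimal sketch: the field names are the ones consumed by ```dataset/ThreeDAPDataset.py``` and ```detect.py```, and the file name is assumed to match the ```data_path``` entry in the config.

```python
import pickle

# Path assumed to match the data_path entry in config/detectiondiffusion.py.
with open("full_shape_release.pkl", "rb") as f:
    data = pickle.load(f)                               # list of per-object dictionaries

sample = data[0]
print(sample["shape_id"], sample["semantic class"])     # object id and category
print(sample["full_shape"]["coordinate"].shape)         # point cloud coordinates, e.g. (2048, 3)
print(sample["affordance"])                             # affordance texts available for this object
aff = sample["affordance"][0]
print(sample["full_shape"]["label"][aff].shape)         # per-point labels for this affordance
pose = sample["pose"][aff][0]                           # one gripper pose: rotation in [:3, :3], translation in [:3, 3]
print(pose)
```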
32 | ## 3. Training
33 | The current framework supports training on a single GPU. The following are the steps to train our method with the configuration file ```config/detectiondiffusion.py```.
34 |
35 | * In ```config/detectiondiffusion.py```, change the value of ```data_path``` to the path of your downloaded pickle file.
36 | * Change other hyperparameters if needed.
37 | * Run the following command to start training:
38 |
39 | python3 train.py --config ./config/detectiondiffusion.py
40 |
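Checkpoints and logs are written to ```./log/<exp_name>/``` (for the default config, ```./log/detectiondiffusion/```); the trainer saves the console log to ```run.log``` and the latest weights to ```current_model.t7```.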
41 | ## 4. Testing
42 | Execute the following command to test your trained model:
43 |
44 | python3 detect.py --config <path_to_config_file> --checkpoint <path_to_checkpoint> --test_data <path_to_test_data>
45 |
46 | Note that we currently generate 2000 poses for each affordance-object pair.
47 | The guidance scale is currently set to 0.2. Feel free to change these hyperparameters according to your preference.
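Both values are set at the top of ```detect.py``` (```GUIDE_W``` and the number of samples passed to ```detect_and_sample```). Internally, the guidance weight mixes the conditional and context-free noise predictions in ```models/main_nets.py``` as:

```python
# from DetectionDiffusion.detect_and_sample in models/main_nets.py
eps = (1 + guide_w) * eps1 - guide_w * eps2   # eps1: conditional, eps2: context-free prediction
```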
48 |
49 | The result will be saved to a ```result.pkl``` file.
50 |
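Each entry of ```result.pkl``` is a test object whose ```result``` field maps every affordance text to a pair of (per-point affordance predictions, sampled poses), where each pose is a 7-D vector (quaternion + translation). The detection and pose metrics (mIoU, Acc, mAcc, mESM, mCR) can then be computed with ```test.py```:

python3 test.py --result <path_to_result_file>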
51 | ## 5. Visualization
52 | To visualize the results of affordance detection and pose estimation, execute the following script:
53 |
54 | python3 visualize.py --result <path_to_result_file>
55 |
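By default, the script visualizes objects of the ```Bottle``` class, renders up to ```num_pose = 100``` poses per object-affordance pair, and keeps only poses whose gripper lies close to the point cloud; these settings can be changed directly in ```visualize.py```.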
56 | Example of training data visualization:
57 |
58 | ![visualization](./assets/visualization.png)
59 |
60 | ## 6. Citation
61 |
62 | If you find our work useful for your research, please cite:
63 | ```
64 | @inproceedings{Nguyen2024language,
65 | title={Language-Conditioned Affordance-Pose Detection in 3D Point Clouds},
66 | author={Nguyen, Toan and Vu, Minh Nhat and Huang, Baoru and Van Vo, Tuan and Truong, Vy and Le, Ngan and Vo, Thieu and Le, Bac and Nguyen, Anh},
67 | booktitle = {ICRA},
68 | year = {2024}
69 | }
70 | ```
71 | Thank you very much.
72 |
73 | ## 7. Acknowledgement
74 |
75 | Our source code is built upon [3D AffordanceNet](https://github.com/Gorilla-Lab-SCUT/AffordanceNet). We express our sincere thanks to the authors.
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/intro.png
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/method.png
--------------------------------------------------------------------------------
/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fsoft-AIC/Language-Conditioned-Affordance-Pose-Detection-in-3D-Point-Clouds/1ec2917f53ea0925ab214fb560c4056751c84bf7/assets/visualization.png
--------------------------------------------------------------------------------
/config/detectiondiffusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from os.path import join as opj
4 | from utils import PN2_BNMomentum
5 |
6 | exp_name = 'detectiondiffusion'
7 | seed = 1
8 | log_dir = opj("./log/", exp_name)
9 | try:
10 | os.makedirs(log_dir)
11 | except FileExistsError:
12 |     print('Logging directory already exists!')
13 |
14 | # scheduler = dict(
15 | # type='lr_lambda',
16 | # lr_lambda=PN2_Scheduler(init_lr=0.001, step=20,
17 | # decay_rate=0.5, min_lr=1e-5)
18 | # )
19 |
20 | scheduler = None
21 |
22 | optimizer = dict(
23 | type='adam',
24 | lr=1e-3,
25 | betas=(0.9, 0.999),
26 | eps=1e-08,
27 | weight_decay=1e-5,
28 | )
29 |
30 | model = dict(
31 | type='detectiondiffusion',
32 | device=torch.device('cuda'),
33 | background_text='none',
34 | betas=[1e-4, 0.02],
35 | n_T=1000,
36 | drop_prob=0.1,
37 | weights_init='default_init',
38 | )
39 |
40 | training_cfg = dict(
41 | model=model,
42 | batch_size=32,
43 | epoch=200,
44 | gpu='0',
45 | workflow=dict(
46 | train=1,
47 | ),
48 | bn_momentum=PN2_BNMomentum(origin_m=0.1, m_decay=0.5, step=20),
49 | )
50 |
51 | data = dict(
52 | data_path="../full_shape_release.pkl",
53 | )
--------------------------------------------------------------------------------
/dataset/ThreeDAPDataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from torch.utils.data import Dataset
3 | import pickle as pkl
4 | from scipy.spatial.transform import Rotation as R
5 |
6 |
7 | class ThreeDAPDataset(Dataset):
8 | """_summary_
9 | This class is for the data loading.
10 | """
11 | def __init__(self, data_path, mode):
12 | """_summary_
13 |
14 | Args:
15 | data_path (str): path to the dataset
16 | """
17 | super().__init__()
18 | self.data_path = data_path
19 | self.mode = mode
20 | if self.mode in ["train", "val", "test"]:
21 | self._load_data()
22 | else:
23 | raise ValueError("Mode must be train, val, or test!")
24 |
25 | def _load_data(self):
26 | self.all_data = []
27 |
28 | with open(self.data_path, "rb") as f:
29 | data = pkl.load(f)
30 | random.shuffle(data)
31 |
32 | if self.mode == "train": data = data[:int(0.7 * len(data))]
33 | elif self.mode == "val": data = data[int(0.7 * len(data)):int(0.8 * len(data))]
34 | else: data = data[int(0.8 * len(data)):]
35 |
36 | for data_point in data:
37 | for affordance in data_point["affordance"]:
38 | for pose in data_point["pose"][affordance]:
39 | new_data_dict = {
40 | "shape_id": data_point["shape_id"],
41 | "semantic class": data_point["semantic class"],
42 | "point cloud": data_point["full_shape"]["coordinate"],
43 | "affordance": affordance,
44 | "affordance label": data_point["full_shape"]["label"][affordance],
45 | "rotation": R.from_matrix(pose[:3, :3]).as_quat(),
46 | "translation": pose[:3, 3]
47 | }
48 | self.all_data.append(new_data_dict)
49 |
50 | def __getitem__(self, index):
51 | """_summary_
52 |
53 | Args:
54 | index (int): the element index
55 |
56 | Returns:
57 | shape id, semantic class, coordinate, affordance text, affordance label, rotation and translation
58 | """
59 | data_dict = self.all_data[index]
60 | return data_dict['shape_id'], data_dict['semantic class'], data_dict['point cloud'], data_dict['affordance'], \
61 | data_dict['affordance label'], data_dict['rotation'], data_dict['translation']
62 |
63 | def __len__(self):
64 | return len(self.all_data)
65 |
66 |
67 | if __name__ == "__main__":
68 | random.seed(1)
69 | dataset = ThreeDAPDataset(data_path="../full_shape_release.pkl", mode="train")
70 | print(len(dataset))
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .ThreeDAPDataset import ThreeDAPDataset
2 |
3 |
4 | __all__ = ['ThreeDAPDataset']
--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from gorilla.config import Config
4 | from utils import *
5 | import argparse
6 | import pickle
7 | from tqdm import tqdm
8 | import random
9 |
10 |
11 | GUIDE_W = 0.2
12 | DEVICE = torch.device('cuda')
13 |
14 |
15 | # Argument Parser
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description="Detect affordance and poses")
18 | parser.add_argument("--config", help="test config file path")
19 | parser.add_argument("--checkpoint", help="path to checkpoint model")
20 | parser.add_argument("--test_data", help="path to test_data")
21 | args = parser.parse_args()
22 | return args
23 |
24 |
25 | if __name__ == "__main__":
26 | args = parse_args()
27 | cfg = Config.fromfile(args.config)
28 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu
29 | model = build_model(cfg).to(DEVICE)
30 |
31 | if args.checkpoint != None:
32 | print("Loading checkpoint....")
33 | _, exten = os.path.splitext(args.checkpoint)
34 | if exten == '.t7':
35 | model.load_state_dict(torch.load(args.checkpoint))
36 | elif exten == '.pth':
37 | check = torch.load(args.checkpoint)
38 | model.load_state_dict(check['model_state_dict'])
39 | else:
40 | raise ValueError("Must specify a checkpoint path!")
41 |
42 | if cfg.get('seed') != None:
43 | set_random_seed(cfg.seed)
44 |
45 | with open(args.test_data, 'rb') as f:
46 | shape_data = pickle.load(f)
47 | random.shuffle(shape_data)
48 | shape_data = shape_data[int(0.8 * len(shape_data)):]
49 |
50 | print("Detecting")
51 | model.eval()
52 | with torch.no_grad():
53 | for shape in tqdm(shape_data):
54 | xyz = torch.from_numpy(shape['full_shape']['coordinate']).unsqueeze(0).float().cuda()
55 | shape['result'] = {text: [*(model.detect_and_sample(xyz, text, 2000, guide_w=GUIDE_W))] for text in shape['affordance']}
56 |
57 | with open(f'{cfg.log_dir}/result.pkl', 'wb') as f:
58 | pickle.dump(shape_data, f)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .main_nets import DetectionDiffusion
2 | from .weights_init import weights_init
3 |
4 |
5 | __all__ = ['DetectionDiffusion', 'weights_init']
--------------------------------------------------------------------------------
/models/components.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import open_clip
4 | import math
5 | import torch.nn.functional as F
6 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation
7 |
8 |
9 | class SinusoidalPositionEmbeddings(nn.Module):
10 | """
11 | Sinusoidal embedding for time step.
12 | """
13 | def __init__(self, dim, scale=1.0):
14 | super().__init__()
15 | self.dim = dim
16 | self.scale = scale
17 |
18 | def forward(self, time):
19 | time = time * self.scale
20 | device = time.device
21 | half_dim = self.dim // 2
22 | embeddings = math.log(10000) / (half_dim - 1 + 1e-5)
23 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
24 | embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0)
25 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
26 | return embeddings
27 |
28 | def __len__(self):
29 | return self.dim
30 |
31 |
32 | class TimeNet(nn.Module):
33 | """
34 | Time Embeddings
35 | """
36 | def __init__(self, dim):
37 | super().__init__()
38 | self.net = nn.Sequential(
39 | nn.Linear(1, dim),
40 | nn.GELU(),
41 | nn.Linear(dim, dim)
42 | )
43 | def forward(self, t):
44 | return self.net(t)
45 |
46 |
47 | class TextEncoder(nn.Module):
48 | """
49 | Text Encoder to encode the text prompt.
50 | """
51 | def __init__(self, device):
52 | super(TextEncoder, self).__init__()
53 | self.device = device
54 | self.clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k",
55 | device=self.device)
56 |
57 | def forward(self, texts):
58 | """
59 | texts can be a single string or a list of strings.
60 | """
61 | tokenizer = open_clip.get_tokenizer("ViT-B-32")
62 | tokens = tokenizer(texts).to(self.device)
63 | text_features = self.clip_model.encode_text(tokens).to(self.device)
64 | return text_features
65 |
66 |
67 | class PointNetPlusPlus(nn.Module):
68 | """_summary_
69 | PointNet++ class.
70 | """
71 | def __init__(self):
72 | super(PointNetPlusPlus, self).__init__()
73 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [
74 | 32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
75 | self.sa2 = PointNetSetAbstractionMsg(
76 | 128, [0.4, 0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
77 | self.sa3 = PointNetSetAbstraction(
78 | npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
79 |
80 | self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
81 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
82 | self.fp1 = PointNetFeaturePropagation(in_channel=134, mlp=[128, 128])
83 |
84 | self.conv1 = nn.Conv1d(128, 512, 1)
85 | self.bn1 = nn.BatchNorm1d(512)
86 |
87 | def forward(self, xyz):
88 | """_summary_
89 | Return point-wise features and point cloud representation.
90 | """
91 | # Set Abstraction layers
92 | xyz = xyz.contiguous().transpose(1, 2)
93 | l0_xyz = xyz
94 | l0_points = xyz
95 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
96 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
97 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
98 | c = l3_points.squeeze()
99 |
100 | # Feature Propagation layers
101 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
102 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
103 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat(
104 | [l0_xyz, l0_points], 1), l1_points)
105 | l0_points = self.bn1(self.conv1(l0_points))
106 | return l0_points, c
107 |
108 |
109 | class PoseNet(nn.Module):
110 | """_summary_
111 |     PoseNet class. This module performs one denoising step of the diffusion process.
112 | """
113 | def __init__(self):
114 | super(PoseNet, self).__init__()
115 | self.cloud_net0 = nn.Sequential(
116 | nn.Linear(1024, 512),
117 | nn.GroupNorm(8, 512),
118 | nn.GELU(),
119 | nn.Linear(512, 128),
120 | nn.GELU(),
121 | nn.Linear(128, 32)
122 | )
123 | self.cloud_net3 = nn.Sequential(
124 | nn.Linear(32, 16),
125 | nn.GroupNorm(4, 16),
126 | nn.GELU(),
127 | nn.Linear(16, 6)
128 | )
129 | self.cloud_net2 = nn.Sequential(
130 | nn.Linear(32, 16),
131 | nn.GroupNorm(4, 16),
132 | nn.GELU(),
133 | nn.Linear(16, 4)
134 | )
135 | self.cloud_net1 = nn.Sequential(
136 | nn.Linear(32, 16),
137 | nn.GroupNorm(4, 16),
138 | nn.GELU(),
139 | nn.Linear(16, 2)
140 | )
141 | self.cloud_influence_net3 = nn.Sequential(
142 | nn.Linear(6 + 6 + 7, 6),
143 | nn.GELU(),
144 | nn.Linear(6, 6)
145 | )
146 | self.cloud_influence_net2 = nn.Sequential(
147 | nn.Linear(4 + 4 + 7, 4),
148 | nn.GELU(),
149 | nn.Linear(4, 4)
150 | )
151 | self.cloud_influence_net1 = nn.Sequential(
152 | nn.Linear(2 + 2 + 7, 2),
153 | nn.GELU(),
154 | nn.Linear(2, 2)
155 | )
156 |
157 | self.text_net0 = nn.Sequential(
158 | nn.Linear(512, 256),
159 | nn.GroupNorm(8, 256),
160 | nn.GELU(),
161 | nn.Linear(256, 128),
162 | nn.GELU(),
163 | nn.Linear(128, 32)
164 | )
165 | self.text_net3 = nn.Sequential(
166 | nn.Linear(32, 16),
167 | nn.GroupNorm(4, 16),
168 | nn.GELU(),
169 | nn.Linear(16, 6)
170 | )
171 | self.text_net2 = nn.Sequential(
172 | nn.Linear(32, 16),
173 | nn.GroupNorm(4, 16),
174 | nn.GELU(),
175 | nn.Linear(16, 4)
176 | )
177 | self.text_net1 = nn.Sequential(
178 | nn.Linear(32, 16),
179 | nn.GroupNorm(4, 16),
180 | nn.GELU(),
181 | nn.Linear(16, 2)
182 | )
183 | self.text_influence_net3 = nn.Sequential(
184 | nn.Linear(6 + 6 + 7, 6),
185 | nn.GELU(),
186 | nn.Linear(6, 6)
187 | )
188 | self.text_influence_net2 = nn.Sequential(
189 | nn.Linear(4 + 4 + 7, 4),
190 | nn.GELU(),
191 | nn.Linear(4, 4)
192 | )
193 | self.text_influence_net1 = nn.Sequential(
194 | nn.Linear(2 + 2 + 7, 2),
195 | nn.GELU(),
196 | nn.Linear(2, 2)
197 | )
198 |
199 | # self.time_net3 = SinusoidalPositionEmbeddings(dim=6)
200 | # self.time_net2 = SinusoidalPositionEmbeddings(dim=4)
201 | # self.time_net1 = SinusoidalPositionEmbeddings(dim=2)
202 | self.time_net3 = TimeNet(dim=6)
203 | self.time_net2 = TimeNet(dim=4)
204 | self.time_net1 = TimeNet(dim=2)
205 |
206 | self.down1 = nn.Sequential(
207 | nn.Linear(7, 6),
208 | nn.GELU(),
209 | nn.Linear(6, 6)
210 | )
211 | self.down2 = nn.Sequential(
212 | nn.Linear(6, 4),
213 | nn.GELU(),
214 | nn.Linear(4, 4)
215 | )
216 | self.down3 = nn.Sequential(
217 | nn.Linear(4, 2),
218 | nn.GELU(),
219 | nn.Linear(2, 2)
220 | )
221 |
222 | self.up1 = nn.Sequential(
223 | nn.Linear(2 + 4, 4),
224 | nn.GELU(),
225 | nn.Linear(4, 4)
226 | )
227 | self.up2 = nn.Sequential(
228 | nn.Linear(4 + 6, 6),
229 | nn.GELU(),
230 | nn.Linear(6, 6)
231 | )
232 | self.up3 = nn.Sequential(
233 | nn.Linear(6 + 7, 7),
234 | nn.GELU(),
235 | nn.Linear(7, 7)
236 | )
237 |
238 | def forward(self, g, c, t, context_mask, _t):
239 | """_summary_
240 | Args:
241 | g: pose representations, size [B, 7]
242 | c: point cloud representations, size [B, 1024]
243 | t: affordance texts, size [B, 512]
244 | context_mask: masks {0, 1} for the contexts, size [B, 1]
245 | _t is for the timesteps, size [B,]
246 | """
247 | c = c * context_mask
248 | c0 = self.cloud_net0(c)
249 | c1 = self.cloud_net1(c0)
250 | c2 = self.cloud_net2(c0)
251 | c3 = self.cloud_net3(c0)
252 |
253 | t = t * context_mask
254 | t0 = self.text_net0(t)
255 | t1 = self.text_net1(t0)
256 | t2 = self.text_net2(t0)
257 | t3 = self.text_net3(t0)
258 |
259 | _t0 = _t.unsqueeze(1)
260 | _t1 = self.time_net1(_t0)
261 | _t2 = self.time_net2(_t0)
262 | _t3 = self.time_net3(_t0)
263 |
264 | g = g.float()
265 | g_down1 = self.down1(g) # 6
266 | g_down2 = self.down2(g_down1) # 4
267 | g_down3 = self.down3(g_down2) # 2
268 |
269 | c1_influence = self.cloud_influence_net1(torch.cat((c1, g, _t1), dim=1))
270 | t1_influence = self.text_influence_net1(torch.cat((t1, g, _t1), dim=1))
271 | influences1 = F.softmax(torch.cat((c1_influence.unsqueeze(1), t1_influence.unsqueeze(1)), dim=1), dim=1)
272 | ct1 = (c1 * influences1[:, 0, :] + t1 * influences1[:, 1, :])
273 | up1 = self.up1(torch.cat((g_down3 * ct1 + _t1, g_down2), dim=1))
274 |
275 | c2_influence = self.cloud_influence_net2(torch.cat((c2, g, _t2), dim=1))
276 | t2_influence = self.text_influence_net2(torch.cat((t2, g, _t2), dim=1))
277 | influences2 = F.softmax(torch.cat((c2_influence.unsqueeze(1), t2_influence.unsqueeze(1)), dim=1), dim=1)
278 | ct2 = (c2 * influences2[:, 0, :] + t2 * influences2[:, 1, :])
279 | up2 = self.up2(torch.cat((up1 * ct2 + _t2, g_down1), dim=1))
280 |
281 | c3_influence = self.cloud_influence_net3(torch.cat((c3, g, _t3), dim=1))
282 | t3_influence = self.text_influence_net3(torch.cat((t3, g, _t3), dim=1))
283 | influences3 = F.softmax(torch.cat((c3_influence.unsqueeze(1), t3_influence.unsqueeze(1)), dim=1), dim=1)
284 | ct3 = (c3 * influences3[:, 0, :] + t3 * influences3[:, 1, :])
285 | up3 = self.up3(torch.cat((up2 * ct3 + _t3, g), dim=1)) # size [B, 7]
286 |
287 | return up3
--------------------------------------------------------------------------------
/models/main_nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .components import TextEncoder, PointNetPlusPlus, PoseNet
6 |
7 |
8 | text_encoder = TextEncoder(device=torch.device('cuda'))
9 |
10 |
11 | # Linear noise scheduler
12 | def linear_diffusion_schedule(betas, T):
13 | """_summary_
14 |     Linear noise schedule used for the diffusion process in training and sampling.
15 | """
16 | beta_t = (betas[1] - betas[0]) * torch.arange(0, T + 1, dtype=torch.float32) / T + betas[0]
17 | sqrt_beta_t = torch.sqrt(beta_t)
18 | alpha_t = 1 - beta_t
19 | log_alpha_t = torch.log(alpha_t)
20 | alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
21 |
22 | sqrtab = torch.sqrt(alphabar_t)
23 | oneover_sqrta = 1 / torch.sqrt(alpha_t)
24 |
25 | sqrtmab = torch.sqrt(1 - alphabar_t)
26 | mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
27 |
28 | return {
29 | "alpha_t": alpha_t, # \alpha_t
30 | "oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
31 | "sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
32 | "alphabar_t": alphabar_t, # \bar{\alpha_t}
33 | "sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}}
34 | "sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}}
35 | "mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
36 | }
37 |
38 |
39 | # Main network for affordance detection and pose generation
40 | class DetectionDiffusion(nn.Module):
41 | def __init__(self, betas, n_T, device, background_text, drop_prob=0.1):
42 | """_summary_
43 |
44 | Args:
45 | drop_prob: probability to drop the conditions
46 | """
47 | super(DetectionDiffusion, self).__init__()
48 | self.posenet = PoseNet()
49 | self.pointnetplusplus = PointNetPlusPlus()
50 |
51 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
52 |
53 | # Register_buffer allows accessing dictionary, e.g. can access self.sqrtab later
54 | for k, v in linear_diffusion_schedule(betas, n_T).items():
55 | self.register_buffer(k, v)
56 |
57 | self.n_T = n_T
58 | self.device = device
59 | self.background_text = background_text
60 | self.drop_prob = drop_prob
61 | self.loss_mse = nn.MSELoss()
62 |
63 | def forward(self, xyz, text, affordance_label, g):
64 | """_summary_
65 | This method is used in training, so samples _ts and noise randomly.
66 | """
67 | B = xyz.shape[0] # xyz's size [B, 3, 2048]
68 | point_features, c = self.pointnetplusplus(xyz) # point_features' size [B, 512, 2048], c'size [B, 1024]
69 | with torch.no_grad():
70 | foreground_text_features = text_encoder(text) # size [B, 512]
71 | background_text_features = text_encoder([self.background_text] * B)
72 | text_features = torch.cat((background_text_features.unsqueeze(1), \
73 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512]
74 |
75 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \
76 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \
77 | torch.norm(point_features, dim=1, keepdim=True))) # size [B, 2, 2048]
78 |
79 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1)
80 | affordance_loss = F.nll_loss(affordance_prediction, affordance_label)
81 |
82 | _ts = torch.randint(1, self.n_T + 1, (B,)).to(self.device)
83 | noise = torch.randn_like(g) # eps ~ N(0, 1), g size [B, 7]
84 | g_t = (
85 | self.sqrtab[_ts - 1, None] * g
86 | + self.sqrtmab[_ts - 1, None] * noise
87 | ) # This is the g_t, which is sqrt(alphabar) g_0 + sqrt(1-alphabar) * eps
88 |
89 | # dropout context with some probability
90 | context_mask = torch.bernoulli(torch.zeros(B, 1) + 1 - self.drop_prob).to(self.device)
91 |
92 |         # Loss for pose generation is the MSE between the added noise and our predicted noise
93 | pose_loss = self.loss_mse(noise, self.posenet(g_t, c, foreground_text_features, context_mask, _ts / self.n_T))
94 | return affordance_loss, pose_loss
95 |
96 | def detect_and_sample(self, xyz, text, n_sample, guide_w):
97 | """_summary_
98 | Detect affordance for one point cloud and sample [n_sample] poses that support the 'text' affordance task,
99 | following the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'.
100 | """
101 | g_i = torch.randn(n_sample, (7)).to(self.device) # start by sampling from Gaussian noise
102 | point_features, c = self.pointnetplusplus(xyz) # point_features size [1, 512, 2048], c size [1, 1024]
103 | foreground_text_features = text_encoder(text) # size [1, 512]
104 | background_text_features = text_encoder([self.background_text] * 1)
105 | text_features = torch.cat((background_text_features.unsqueeze(1), \
106 | foreground_text_features.unsqueeze(1)), dim=1) # size [B, 2, 512]
107 |
108 | affordance_prediction = self.logit_scale * torch.einsum('bij,bjk->bik', text_features, point_features) \
109 | / (torch.einsum('bij,bjk->bik', torch.norm(text_features, dim=2, keepdim=True), \
110 | torch.norm(point_features, dim=1, keepdim=True))) # size [1, 2, 2048]
111 |
112 | affordance_prediction = F.log_softmax(affordance_prediction, dim=1) # .cpu().numpy()
113 | c_i = c.repeat(n_sample, 1)
114 | t_i = foreground_text_features.repeat(n_sample, 1)
115 | context_mask = torch.ones((n_sample, 1)).float().to(self.device)
116 |
117 | # Double the batch
118 | c_i = c_i.repeat(2, 1)
119 | t_i = t_i.repeat(2, 1)
120 | context_mask = context_mask.repeat(2, 1)
121 |         context_mask[n_sample:] = 0. # make the second half of the batch context-free
122 |
123 | for i in range(self.n_T, 0, -1):
124 | _t_is = torch.tensor([i / self.n_T]).repeat(n_sample).repeat(2).to(self.device)
125 | g_i = g_i.repeat(2, 1)
126 |
127 | z = torch.randn(n_sample, (7)) if i > 1 else torch.zeros((n_sample, 7))
128 | z = z.to(self.device)
129 | eps = self.posenet(g_i, c_i, t_i, context_mask, _t_is)
130 | eps1 = eps[:n_sample]
131 | eps2 = eps[n_sample:]
132 | eps = (1 + guide_w) * eps1 - guide_w * eps2
133 |
134 | g_i = g_i[:n_sample]
135 | g_i = self.oneover_sqrta[i] * (g_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
136 | return np.argmax(affordance_prediction.cpu().numpy(), axis=1), g_i.cpu().numpy()
--------------------------------------------------------------------------------
/models/pointnet_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 |
8 | def timeit(tag, t):
9 | print("{}: {}s".format(tag, time() - t))
10 | return time()
11 |
12 |
13 | def pc_normalize(pc):
14 | l = pc.shape[0]
15 | centroid = np.mean(pc, axis=0)
16 | pc = pc - centroid
17 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
18 | pc = pc / m
19 | return pc
20 |
21 |
22 | def square_distance(src, dst):
23 | """_summary_
24 | Calculate Euclid distance between each two points.
25 |
26 | src^T * dst = xn * xm + yn * ym + zn * zm;
27 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
28 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
29 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
30 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
31 |
32 | Input:
33 | src: source points, [B, N, C]
34 | dst: target points, [B, M, C]
35 | Output:
36 | dist: per-point square distance, [B, N, M]
37 | """
38 | B, N, _ = src.shape
39 | _, M, _ = dst.shape
40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
41 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
43 | return dist
44 |
45 |
46 | def index_points(points, idx):
47 | """_summary_
48 | Input:
49 | points: input points data, [B, N, C]
50 | idx: sample index data, [B, S]
51 | Return:
52 | new_points:, indexed points data, [B, S, C]
53 | """
54 | device = points.device
55 | B = points.shape[0]
56 | view_shape = list(idx.shape)
57 | view_shape[1:] = [1] * (len(view_shape) - 1)
58 | repeat_shape = list(idx.shape)
59 | repeat_shape[0] = 1
60 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
61 | new_points = points[batch_indices, idx, :]
62 | return new_points
63 |
64 |
65 | def farthest_point_sample(xyz, npoint):
66 | """_summary_
67 | Input:
68 | xyz: pointcloud data, [B, N, 3]
69 | npoint: number of samples
70 | Return:
71 | centroids: sampled pointcloud index, [B, npoint]
72 | """
73 | device = xyz.device
74 | B, N, C = xyz.shape
75 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
76 | distance = torch.ones(B, N).to(device) * 1e10
77 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
78 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
79 | for i in range(npoint):
80 | centroids[:, i] = farthest
81 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
82 | dist = torch.sum((xyz - centroid) ** 2, -1)
83 | mask = dist < distance
84 | distance[mask] = dist[mask]
85 | farthest = torch.max(distance, -1)[1]
86 | return centroids
87 |
88 |
89 | def query_ball_point(radius, nsample, xyz, new_xyz):
90 | """_summary_
91 | Input:
92 | radius: local region radius
93 | nsample: max sample number in local region
94 | xyz: all points, [B, N, 3]
95 | new_xyz: query points, [B, S, 3]
96 | Return:
97 | group_idx: grouped points index, [B, S, nsample]
98 | """
99 | device = xyz.device
100 | B, N, C = xyz.shape
101 | _, S, _ = new_xyz.shape
102 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
103 | sqrdists = square_distance(new_xyz, xyz)
104 | group_idx[sqrdists > radius ** 2] = N
105 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
106 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
107 | mask = group_idx == N
108 | group_idx[mask] = group_first[mask]
109 | return group_idx
110 |
111 |
112 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
113 | """_summary_
114 | Input:
115 | npoint:
116 | radius:
117 | nsample:
118 | xyz: input points position data, [B, N, 3]
119 | points: input points data, [B, N, D]
120 | Return:
121 | new_xyz: sampled points position data, [B, npoint, nsample, 3]
122 | new_points: sampled points data, [B, npoint, nsample, 3+D]
123 | """
124 | B, N, C = xyz.shape
125 | S = npoint
126 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
127 | new_xyz = index_points(xyz, fps_idx)
128 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
129 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
130 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
131 |
132 | if points is not None:
133 | grouped_points = index_points(points, idx)
134 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
135 | else:
136 | new_points = grouped_xyz_norm
137 | if returnfps:
138 | return new_xyz, new_points, grouped_xyz, fps_idx
139 | else:
140 | return new_xyz, new_points
141 |
142 |
143 | def sample_and_group_all(xyz, points):
144 | """_summary_
145 | Input:
146 | xyz: input points position data, [B, N, 3]
147 | points: input points data, [B, N, D]
148 | Return:
149 | new_xyz: sampled points position data, [B, 1, 3]
150 | new_points: sampled points data, [B, 1, N, 3+D]
151 | """
152 | device = xyz.device
153 | B, N, C = xyz.shape
154 | new_xyz = torch.zeros(B, 1, C).to(device)
155 | grouped_xyz = xyz.view(B, 1, N, C)
156 | if points is not None:
157 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
158 | else:
159 | new_points = grouped_xyz
160 | return new_xyz, new_points
161 |
162 |
163 | class PointNetSetAbstraction(nn.Module):
164 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
165 | super(PointNetSetAbstraction, self).__init__()
166 | self.npoint = npoint
167 | self.radius = radius
168 | self.nsample = nsample
169 | self.mlp_convs = nn.ModuleList()
170 | self.mlp_bns = nn.ModuleList()
171 | last_channel = in_channel
172 | for out_channel in mlp:
173 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
174 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
175 | last_channel = out_channel
176 | self.group_all = group_all
177 |
178 | def forward(self, xyz, points):
179 | """_summary_
180 | Input:
181 | xyz: input points position data, [B, C, N]
182 | points: input points data, [B, D, N]
183 | Return:
184 | new_xyz: sampled points position data, [B, C, S]
185 | new_points_concat: sample points feature data, [B, D', S]
186 | """
187 | xyz = xyz.permute(0, 2, 1)
188 | if points is not None:
189 | points = points.permute(0, 2, 1)
190 |
191 | if self.group_all:
192 | new_xyz, new_points = sample_and_group_all(xyz, points)
193 | else:
194 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
195 | # new_xyz: sampled points position data, [B, npoint, C]
196 | # new_points: sampled points data, [B, npoint, nsample, C+D]
197 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
198 | for i, conv in enumerate(self.mlp_convs):
199 | bn = self.mlp_bns[i]
200 | new_points = F.relu(bn(conv(new_points)))
201 |
202 | new_points = torch.max(new_points, 2)[0]
203 | new_xyz = new_xyz.permute(0, 2, 1)
204 | return new_xyz, new_points
205 |
206 |
207 | class PointNetSetAbstractionMsg(nn.Module):
208 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
209 | super(PointNetSetAbstractionMsg, self).__init__()
210 | self.npoint = npoint
211 | self.radius_list = radius_list
212 | self.nsample_list = nsample_list
213 | self.conv_blocks = nn.ModuleList()
214 | self.bn_blocks = nn.ModuleList()
215 | for i in range(len(mlp_list)):
216 | convs = nn.ModuleList()
217 | bns = nn.ModuleList()
218 | last_channel = in_channel + 3
219 | for out_channel in mlp_list[i]:
220 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
221 | bns.append(nn.BatchNorm2d(out_channel))
222 | last_channel = out_channel
223 | self.conv_blocks.append(convs)
224 | self.bn_blocks.append(bns)
225 |
226 | def forward(self, xyz, points):
227 | """_summary_
228 | Input:
229 | xyz: input points position data, [B, C, N]
230 | points: input points data, [B, D, N]
231 | Return:
232 | new_xyz: sampled points position data, [B, C, S]
233 | new_points_concat: sample points feature data, [B, D', S]
234 | """
235 | xyz = xyz.permute(0, 2, 1)
236 | if points is not None:
237 | points = points.permute(0, 2, 1)
238 |
239 | B, N, C = xyz.shape
240 | S = self.npoint
241 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
242 | new_points_list = []
243 | for i, radius in enumerate(self.radius_list):
244 | K = self.nsample_list[i]
245 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
246 | grouped_xyz = index_points(xyz, group_idx)
247 | grouped_xyz -= new_xyz.view(B, S, 1, C)
248 | if points is not None:
249 | grouped_points = index_points(points, group_idx)
250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
251 | else:
252 | grouped_points = grouped_xyz
253 |
254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
255 | for j in range(len(self.conv_blocks[i])):
256 | conv = self.conv_blocks[i][j]
257 | bn = self.bn_blocks[i][j]
258 | grouped_points = F.relu(bn(conv(grouped_points)))
259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
260 | new_points_list.append(new_points)
261 |
262 | new_xyz = new_xyz.permute(0, 2, 1)
263 | new_points_concat = torch.cat(new_points_list, dim=1)
264 | return new_xyz, new_points_concat
265 |
266 |
267 | class PointNetFeaturePropagation(nn.Module):
268 | def __init__(self, in_channel, mlp):
269 | super(PointNetFeaturePropagation, self).__init__()
270 | self.mlp_convs = nn.ModuleList()
271 | self.mlp_bns = nn.ModuleList()
272 | last_channel = in_channel
273 | for out_channel in mlp:
274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
276 | last_channel = out_channel
277 |
278 | def forward(self, xyz1, xyz2, points1, points2):
279 | """_summary_
280 | Input:
281 | xyz1: input points position data, [B, C, N]
282 | xyz2: sampled input points position data, [B, C, S]
283 | points1: input points data, [B, D, N]
284 | points2: input points data, [B, D, S]
285 | Return:
286 | new_points: upsampled points data, [B, D', N]
287 | """
288 | xyz1 = xyz1.permute(0, 2, 1)
289 | xyz2 = xyz2.permute(0, 2, 1)
290 |
291 | points2 = points2.permute(0, 2, 1)
292 | B, N, C = xyz1.shape
293 | _, S, _ = xyz2.shape
294 |
295 | if S == 1:
296 | interpolated_points = points2.repeat(1, N, 1)
297 | else:
298 | dists = square_distance(xyz1, xyz2)
299 | dists, idx = dists.sort(dim=-1)
300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
301 |
302 | dist_recip = 1.0 / (dists + 1e-8)
303 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
304 | weight = dist_recip / norm
305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
306 |
307 | if points1 is not None:
308 | points1 = points1.permute(0, 2, 1)
309 | new_points = torch.cat([points1, interpolated_points], dim=-1)
310 | else:
311 | new_points = interpolated_points
312 |
313 | new_points = new_points.permute(0, 2, 1)
314 | for i, conv in enumerate(self.mlp_convs):
315 | bn = self.mlp_bns[i]
316 | new_points = F.relu(bn(conv(new_points)))
317 | return new_points
--------------------------------------------------------------------------------
/models/weights_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def weights_init(m):
4 | """_summary_
5 | Weights initialization
6 | """
7 | classname = m.__class__.__name__
8 | if classname.find('Conv2d') != -1:
9 | torch.nn.init.xavier_normal_(m.weight.data)
10 |         if m.bias is not None:
11 | torch.nn.init.constant_(m.bias.data, 0.0)
12 | elif classname.find('Linear') != -1:
13 | torch.nn.init.xavier_normal_(m.weight.data)
14 |         if m.bias is not None:
15 | torch.nn.init.constant_(m.bias.data, 0.0)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | h5py
4 | scikit_learn==1.3.0
5 | gorilla-core==0.2.7.8
6 | torch==2.0.1
7 | scipy==1.11.1
8 | trimesh==4.0.7
9 | open_clip_torch
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from utils.eval import affordance_eval, pose_eval
2 | import argparse
3 | import pickle
4 |
5 |
6 | AFFORDANCE_LIST = ['grasp to pour', 'grasp to stab', 'stab', 'pourable', 'lift', 'wrap_grasp', 'listen', 'contain', 'displaY', 'grasp to cut', 'cut', 'wear', 'openable', 'grasp']
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="Test a model")
11 | parser.add_argument("--result", help="result file")
12 | args = parser.parse_args()
13 | return args
14 |
15 |
16 | if __name__ == "__main__":
17 | args = parse_args()
18 | with open(args.result, 'rb') as f:
19 | result = pickle.load(f)
20 | mIoU, Acc, mAcc = affordance_eval(AFFORDANCE_LIST, result)
21 | print(f'mIoU: {mIoU}, Acc: {Acc}, mAcc: {mAcc}')
22 |
23 | mESM, mCR = pose_eval(result)
24 | print(f'mESM: {mESM}, mCR: {mCR}')
25 |
26 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join as opj
3 | from gorilla.config import Config
4 | from utils import *
5 | import argparse
6 | import torch
7 |
8 |
9 | # Argument Parser
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description="Train a model")
12 | parser.add_argument("--config", help="train config file path")
13 | args = parser.parse_args()
14 | return args
15 |
16 |
17 | if __name__ == "__main__":
18 | args = parse_args()
19 | cfg = Config.fromfile(args.config)
20 |
21 | logger = IOStream(opj(cfg.log_dir, 'run.log'))
22 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.training_cfg.gpu
23 | num_gpu = len(cfg.training_cfg.gpu.split(',')) # number of GPUs to use
24 | logger.cprint('Use %d GPUs: %s' % (num_gpu, cfg.training_cfg.gpu))
25 | if cfg.get('seed') != None: # set random seed
26 | set_random_seed(cfg.seed)
27 | logger.cprint('Set seed to %d' % cfg.seed)
28 | model = build_model(cfg).cuda() # build the model from configuration
29 |
30 | print("Training from scratch!")
31 |
32 | dataset_dict = build_dataset(cfg) # build the dataset
33 | loader_dict = build_loader(cfg, dataset_dict) # build the loader
34 | optim_dict = build_optimizer(cfg, model) # build the optimizer
35 |
36 | # construct the training process
37 | training = dict(
38 | model=model,
39 | dataset_dict=dataset_dict,
40 | loader_dict=loader_dict,
41 | optim_dict=optim_dict,
42 | logger=logger
43 | )
44 |
45 | task_trainer = Trainer(cfg, training)
46 | task_trainer.run()
47 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_optimizer, build_dataset, build_loader, build_model
2 | from .trainer import Trainer
3 | from .utils import set_random_seed, IOStream, PN2_BNMomentum, PN2_Scheduler
4 |
5 | __all__ = ['build_optimizer', 'build_dataset', 'build_loader', 'build_model',
6 | 'Trainer', 'set_random_seed', 'IOStream', 'PN2_BNMomentum', 'PN2_Scheduler']
7 |
--------------------------------------------------------------------------------
/utils/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LambdaLR, MultiStepLR
3 | from dataset import *
4 | from models import *
5 | from torch.utils.data import DataLoader
6 | from torch.optim import SGD, Adam
7 |
8 | # Pools of models, optimizers, weights initialization methods, schedulers
9 | model_pool = {
10 | 'detectiondiffusion': DetectionDiffusion,
11 | }
12 |
13 | optimizer_pool = {
14 | 'sgd': SGD,
15 | 'adam': Adam
16 | }
17 |
18 | init_pool = {
19 | 'default_init': weights_init
20 | }
21 |
22 | scheduler_pool = {
23 | 'step': StepLR,
24 | 'cos': CosineAnnealingLR,
25 | 'lr_lambda': LambdaLR,
26 | 'multi_step': MultiStepLR
27 | }
28 |
29 |
30 | def build_model(cfg):
31 | """_summary_
32 | Function to build the model before training
33 | """
34 | if hasattr(cfg, 'model'):
35 | model_info = cfg.model
36 | weights_init = model_info.get('weights_init', None)
37 | background_text = model_info.get('background_text', 'none')
38 | device = model_info.get('device', torch.device('cuda'))
39 | model_name = model_info.type
40 | model_cls = model_pool[model_name]
41 | if model_name in ['detectiondiffusion']:
42 | betas = model_info.get('betas', [1e-4, 0.02])
43 | n_T = model_info.get('n_T', 1000)
44 | drop_prob = model_info.get('drop_prob', 0.1)
45 | model = model_cls(betas, n_T, device, background_text, drop_prob)
46 | else:
47 | raise ValueError("The model name does not exist!")
48 | if weights_init != None:
49 | init_fn = init_pool[weights_init]
50 | model.apply(init_fn)
51 | return model
52 | else:
53 | raise ValueError("Configuration does not have model config!")
54 |
55 |
56 | def build_dataset(cfg):
57 | """_summary_
58 | Function to build the dataset
59 | """
60 | if hasattr(cfg, 'data'):
61 | data_info = cfg.data
62 | data_path = data_info.data_path
63 | train_set = ThreeDAPDataset(data_path, mode='train')
64 | val_set = ThreeDAPDataset(data_path, mode='val')
65 | test_set = ThreeDAPDataset(data_path, mode='test')
66 | dataset_dict = dict(
67 | train_set=train_set,
68 | val_set=val_set,
69 | test_set=test_set
70 | )
71 | return dataset_dict
72 | else:
73 | raise ValueError("Configuration does not have data config!")
74 |
75 |
76 | def build_loader(cfg, dataset_dict):
77 | """_summary_
78 | Function to build the loader
79 | """
80 | train_set = dataset_dict["train_set"]
81 | train_loader = DataLoader(train_set, batch_size=cfg.training_cfg.batch_size,
82 | shuffle=True, drop_last=False, num_workers=8)
83 | loader_dict = dict(
84 | train_loader=train_loader,
85 | )
86 |
87 | return loader_dict
88 |
89 |
90 | def build_optimizer(cfg, model):
91 | """_summary_
92 | Function to build the optimizer
93 | """
94 | optimizer_info = cfg.optimizer
95 | optimizer_type = optimizer_info.type
96 | optimizer_info.pop('type')
97 | optimizer_cls = optimizer_pool[optimizer_type]
98 | optimizer = optimizer_cls(model.parameters(), **optimizer_info)
99 | scheduler_info = cfg.scheduler
100 | if scheduler_info:
101 | scheduler_name = scheduler_info.type
102 | scheduler_info.pop('type')
103 | scheduler_cls = scheduler_pool[scheduler_name]
104 | scheduler = scheduler_cls(optimizer, **scheduler_info)
105 | else:
106 | scheduler = None
107 | optim_dict = dict(
108 | scheduler=scheduler,
109 | optimizer=optimizer
110 | )
111 | return optim_dict
112 |
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 | from scipy.spatial.transform import Rotation as R
4 |
5 |
6 | def affordance_eval(affordance_list, result):
7 | """_summary_
8 |     This function evaluates the affordance detection capability.
9 | `result` is loaded from result.pkl file produced by detect.py.
10 | """
11 | num_correct = 0
12 | num_all = 0
13 | num_points = {aff: 0 for aff in affordance_list}
14 | num_label_points = {aff: 0 for aff in affordance_list}
15 | num_correct_fg_points = {aff: 0 for aff in affordance_list}
16 | num_correct_bg_points = {aff: 0 for aff in affordance_list}
17 | num_union_points = {aff: 0 for aff in affordance_list}
18 | num_appearances = {aff: 0 for aff in affordance_list}
19 |
20 | for shape in result:
21 | for affordance in shape['affordance']:
22 | label = np.transpose(shape['full_shape']['label'][affordance])
23 | prediction = shape['result'][affordance][0]
24 |
25 | num_correct += np.sum(label == prediction)
26 | num_all += 2048
27 | num_points[affordance] += 2048
28 | num_label_points[affordance] += np.sum(label == 1.)
29 | num_correct_fg_points[affordance] += np.sum((label == 1.) & (prediction == 1.))
30 | num_correct_bg_points[affordance] += np.sum((label == 0.) & (prediction == 0.))
31 |             num_union_points[affordance] += np.sum((label == 1.) | (prediction == 1.))
   |             num_appearances[affordance] += 1  # count occurrences of each affordance, used as weights for mIoU below
32 |     mIoU = np.average(np.array(list(num_correct_fg_points.values())) / np.array(list(num_union_points.values())),
33 |                       weights=np.array(list(num_appearances.values())))
34 | Acc = num_correct / num_all
35 | mAcc = np.mean((np.array(list(num_correct_fg_points.values())) + np.array(list(num_correct_bg_points.values()))) / \
36 | np.array(list(num_points.values())))
37 |
38 | return mIoU, Acc, mAcc
39 |
40 |
41 | def pose_eval(result):
42 | """_summary_
43 | This function evaluates the pose detection capability.
44 | `result` is loaded from result.pkl file produced by detect.py.
45 | """
46 | all_min_dist = []
47 | all_rate = []
48 | for object in result:
49 | for affordance in object['affordance']:
50 | gt_poses = np.array([np.concatenate((R.from_matrix(p[:3, :3]).as_quat(), p[:3, 3]), axis=0) for p in object['pose'][affordance]])
51 | distances = cdist(gt_poses, object['result'][affordance][1])
52 | rate = np.sum(np.any(distances <= 0.2, axis=1)) / len(object['pose'][affordance])
53 | all_rate.append(rate)
54 |
55 | g = gt_poses[:, np.newaxis, :]
56 | g_pred = object['result'][affordance][1]
57 | l2_distances = np.sqrt(np.sum((g-g_pred)**2, axis=2))
58 | min_distance = np.min(l2_distances)
59 |
60 |             # discard cases where the set of gt poses and the set of detected poses are too far apart, to get a stable result
61 | if min_distance <= 1.0:
62 | all_min_dist.append(min_distance)
63 | return (np.mean(np.array(all_min_dist)), np.mean(np.array(all_rate)))
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from os.path import join as opj
4 | from utils import *
5 |
6 |
7 | DEVICE = torch.device('cuda')
8 |
9 |
10 | class Trainer(object):
11 | def __init__(self, cfg, running):
12 | super().__init__()
13 | self.cfg = cfg
14 | self.logger = running['logger']
15 | self.model = running["model"]
16 | self.dataset_dict = running["dataset_dict"]
17 | self.loader_dict = running["loader_dict"]
18 | self.train_loader = self.loader_dict.get("train_loader", None)
19 | self.optimizer_dict = running["optim_dict"]
20 | self.optimizer = self.optimizer_dict.get("optimizer", None)
21 | self.scheduler = self.optimizer_dict.get("scheduler", None)
22 | self.epoch = 0
23 | self.bn_momentum = self.cfg.training_cfg.get('bn_momentum', None)
24 |
25 | def train(self):
26 | self.model.train()
27 | self.logger.cprint("Epoch(%d) begin training........" % self.epoch)
28 | pbar = tqdm(self.train_loader)
29 | for _, _, xyz, text, affordance_label, rotation, translation in pbar:
30 | self.optimizer.zero_grad()
31 | xyz = xyz.float()
32 | rotation = rotation.float()
33 | translation = translation.float()
34 | affordance_label = affordance_label.squeeze().long()
35 |
36 | g = torch.cat((rotation, translation), dim=1)
37 | xyz = xyz.to(DEVICE)
38 | affordance_label = affordance_label.to(DEVICE)
39 | g = g.to(DEVICE)
40 |
41 | affordance_loss, pose_loss = self.model(xyz, text, affordance_label, g)
42 | loss = affordance_loss + pose_loss
43 | loss.backward()
44 |
45 | affordance_l = affordance_loss.item()
46 | pose_l = pose_loss.item()
47 | pbar.set_description(f'Affordance loss: {affordance_l:.5f}, Pose loss: {pose_l:.5f}')
48 | self.optimizer.step()
49 |
50 | if self.scheduler != None:
51 | self.scheduler.step()
52 | if self.bn_momentum != None:
53 | self.model.apply(lambda x: self.bn_momentum(x, self.epoch))
54 |
55 | outstr = f"\nEpoch {self.epoch}, Last Affordance loss: {affordance_l:.5f}, Last Pose loss: {pose_l:.5f}"
56 | self.logger.cprint(outstr)
57 | print('Saving checkpoint')
58 | torch.save(self.model.state_dict(), opj(self.cfg.log_dir, 'current_model.t7'))
59 | self.epoch += 1
60 |
61 | def val(self):
62 | raise NotImplementedError
63 |
64 | def test(self):
65 | raise NotImplementedError
66 |
67 | def run(self):
68 | EPOCH = self.cfg.training_cfg.epoch
69 | workflow = self.cfg.training_cfg.workflow
70 |
71 | while self.epoch < EPOCH:
72 | for key, running_epoch in workflow.items():
73 | epoch_runner = getattr(self, key)
74 | for _ in range(running_epoch):
75 | epoch_runner()
76 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import random
5 |
6 |
7 | class IOStream():
8 | def __init__(self, path):
9 | self.f = open(path, 'a')
10 |
11 | def cprint(self, text):
12 | print(text)
13 | self.f.write(text+'\n')
14 | self.f.flush()
15 |
16 | def close(self):
17 | self.f.close()
18 |
19 |
20 | class PN2_Scheduler(object):
21 | def __init__(self, init_lr, step, decay_rate, min_lr):
22 | super().__init__()
23 | self.init_lr = init_lr
24 | self.step = step
25 | self.decay_rate = decay_rate
26 | self.min_lr = min_lr
27 | return
28 |
29 | def __call__(self, epoch):
30 | factor = self.decay_rate**(epoch//self.step)
31 | if self.init_lr*factor < self.min_lr:
32 | factor = self.min_lr / self.init_lr
33 | return factor
34 |
35 |
36 | class PN2_BNMomentum(object):
37 | def __init__(self, origin_m, m_decay, step):
38 | super().__init__()
39 | self.origin_m = origin_m
40 | self.m_decay = m_decay
41 | self.step = step
42 | return
43 |
44 | def __call__(self, m, epoch):
45 | momentum = self.origin_m * (self.m_decay**(epoch//self.step))
46 | if momentum < 0.01:
47 | momentum = 0.01
48 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
49 | m.momentum = momentum
50 | return
51 |
52 |
53 | def set_random_seed(seed):
54 | random.seed(seed)
55 | np.random.seed(seed)
56 | torch.manual_seed(seed)
57 | torch.cuda.manual_seed(seed)
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 |
3 |
4 | def create_gripper_marker(color=[0, 255, 0], tube_radius=0.002, sections=6):
5 | """Create a 3D mesh visualizing a parallel yaw gripper. It consists of four cylinders.
6 |
7 | Args:
8 | color (list, optional): RGB values of marker. Defaults to [0, 0, 255].
9 | tube_radius (float, optional): Radius of cylinders. Defaults to 0.001.
10 | sections (int, optional): Number of sections of each cylinder. Defaults to 6.
11 |
12 | Returns:
13 | trimesh.Trimesh: A mesh that represents a simple parallel yaw gripper.
14 | """
15 | cfl = trimesh.creation.cylinder(
16 | radius=tube_radius,
17 | sections=sections,
18 | segment=[
19 | [4.10000000e-02, -7.27595772e-12, 6.59999996e-02],
20 | [4.10000000e-02, -7.27595772e-12, 1.12169998e-01],
21 | ],
22 | )
23 | cfr = trimesh.creation.cylinder(
24 | radius=tube_radius,
25 | sections=sections,
26 | segment=[
27 | [-4.100000e-02, -7.27595772e-12, 6.59999996e-02],
28 | [-4.100000e-02, -7.27595772e-12, 1.12169998e-01],
29 | ],
30 | )
31 | cb1 = trimesh.creation.cylinder(
32 | radius=tube_radius, sections=sections, segment=[[0, 0, 0], [0, 0, 6.59999996e-02]]
33 | )
34 | cb2 = trimesh.creation.cylinder(
35 | radius=tube_radius,
36 | sections=sections,
37 | segment=[[-4.100000e-02, 0, 6.59999996e-02], [4.100000e-02, 0, 6.59999996e-02]],
38 | )
39 |
40 | tmp = trimesh.util.concatenate([cb1, cb2, cfr, cfl])
41 | tmp.visual.face_colors = color
42 |
43 | return tmp
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import numpy as np
3 | import pickle
4 | from scipy.spatial.transform import Rotation as R
5 | import argparse
6 | from utils.visualization import create_gripper_marker
7 |
8 | color_code_1 = np.array([0, 0, 255]) # color code for affordance region
9 | color_code_2 = np.array([0, 255, 0]) # color code for gripper pose
10 | num_pose = 100 # number of poses to visualize per each object-affordance pair
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="Visualize")
15 | parser.add_argument("--result", help="result file")
16 | args = parser.parse_args()
17 | return args
18 |
19 |
20 | if __name__ == "__main__":
21 | args = parse_args()
22 |     result_file = args.result
23 | with open(result_file, 'rb') as f:
24 | result = pickle.load(f)
25 |
26 | for i in range(len(result)):
27 | if result[i]['semantic class'] == 'Bottle':
28 | shape_index = i
29 | shape = result[shape_index]
30 |
31 | for affordance in shape['affordance']:
32 | colors = np.transpose(shape['result'][affordance][0]) * color_code_1
33 | point_cloud = trimesh.points.PointCloud(shape['full_shape']['coordinate'], colors=colors)
34 | print(f"Affordance: {affordance}")
35 | T = shape['result'][affordance][1][:num_pose]
36 | rotation = np.concatenate((R.from_quat(T[:, :4]).as_matrix(), np.zeros((num_pose, 1, 3), dtype=np.float32)), axis=1)
37 | translation = np.expand_dims(np.concatenate((T[:, 4:], np.ones((num_pose, 1), dtype=np.float32)), axis=1), axis=2)
38 | T = np.concatenate((rotation, translation), axis=2)
39 | poses = [create_gripper_marker(color=color_code_2).apply_transform(t) for t in T
40 | if np.min(np.linalg.norm(point_cloud - (t @ np.array([0., 0., 6.59999996e-02, 1.]))[:3], axis=1)) <= 0.03] # this line is used to get reliable poses only
41 | scene = trimesh.Scene([point_cloud, poses])
42 | scene.show(line_settings={'point size': 10})
--------------------------------------------------------------------------------