├── LICENSE
├── README.md
├── data
│   └── S3DIS
│       ├── S3DISDataLoader.py
│       └── s3dis_names.txt
├── flowchart.jpg
├── model
│   ├── gacnet.py
│   └── gacnet_utils.py
├── pointnet2_ops_lib
│   ├── MANIFEST.in
│   ├── pointnet2_ops
│   │   ├── __init__.py
│   │   ├── _ext-src
│   │   │   ├── include
│   │   │   │   ├── ball_query.h
│   │   │   │   ├── cuda_utils.h
│   │   │   │   ├── group_points.h
│   │   │   │   ├── interpolate.h
│   │   │   │   ├── sampling.h
│   │   │   │   └── utils.h
│   │   │   └── src
│   │   │       ├── ball_query.cpp
│   │   │       ├── ball_query_gpu.cu
│   │   │       ├── bindings.cpp
│   │   │       ├── group_points.cpp
│   │   │       ├── group_points_gpu.cu
│   │   │       ├── interpolate.cpp
│   │   │       ├── interpolate_gpu.cu
│   │   │       ├── sampling.cpp
│   │   │       └── sampling_gpu.cu
│   │   ├── _version.py
│   │   ├── pointnet2_modules.py
│   │   └── pointnet2_utils.py
│   └── setup.py
├── tool
│   ├── .DS_Store
│   ├── test.py
│   └── train.py
└── util
    ├── dataset.py
    ├── ply.py
    ├── transform.py
    └── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Yongqiang Mao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GACNet_PyTorch
2 | This is the PyTorch implementation of GACNet on S3DIS.
3 |
4 |
5 |
6 |
7 | If you find this repository useful, please consider giving it a star :star:.
8 | ## Dependencies
9 | - Python 3.6
10 | - PyTorch 1.7
11 | - CUDA 11
12 |
13 |
14 | ## Install pointnet2-ops
15 |
16 | ```
17 | cd pointnet2_ops_lib
18 | python setup.py install
19 | ```
20 |
21 | ## Train Model
22 | ```
23 | cd tool
24 | python train.py
25 | ```
26 |
27 | ## Test Model
28 | ```
29 | cd tool
30 | python test.py
31 | ```
32 |
33 | ## References
34 | This repo is built based on the Tensorflow implementation of [GACNet](https://github.com/wleigithub/GACNet). Thanks for their great work!
35 |
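36 | ## Example Usage
37 | A minimal sketch of a forward pass on random points, mirroring the `__main__` check in `model/gacnet.py` (run from the repository root; 13 is the S3DIS class count):
38 | ```
39 | import torch
40 | from model.gacnet import GACNet, build_graph_pyramid, graph_inf, net_inf
41 |
42 | points = torch.randn(8, 4096, 6).cuda()  # (batch, num_point, xyz+rgb)
43 | graph_prd, coarse_map = build_graph_pyramid(points[:, :, :3], graph_inf)
44 | model = GACNet(13, graph_inf, net_inf).cuda()
45 | logits = model(points, graph_prd, coarse_map)  # (8, 4096, 13) log-probabilities
46 | ```
47 |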
--------------------------------------------------------------------------------
/data/S3DIS/S3DISDataLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter']
5 | class2label = {cls: i for i,cls in enumerate(classes)}
6 | def pc_normalize(data, center):
7 | mindata = np.min(data[:, :3], axis=0)
8 | maxdata = np.max(data[:, :3], axis=0)
9 | center = (center - mindata) / (maxdata - mindata)
10 | data_xyz_norm = (data[:, :3] - mindata) / (maxdata - mindata)
11 | return data_xyz_norm, center
12 |
13 | class S3DISDataset(Dataset):
14 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096,
15 | test_area=5, block_size=1.0, sample_rate=1.0, transform=None):
16 | super().__init__()
17 | self.num_point = num_point
18 | self.block_size = block_size
19 | self.transform = transform
20 | rooms = sorted(os.listdir(data_root))
21 | rooms = [room for room in rooms if 'Area_' in room]
22 | if split == 'train':
23 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room]
24 | else:
25 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room]
26 | self.room_points, self.room_labels = [], []
27 | self.room_coord_min, self.room_coord_max = [], []
28 | num_point_all = []
29 | labelweights = np.zeros(13)
30 | for room_name in rooms_split:
31 | room_path = os.path.join(data_root, room_name)
32 | room_data = np.load(room_path) # xyzrgbl, N*7
33 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N
34 | tmp, _ = np.histogram(labels, range(14))
35 | labelweights += tmp
36 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
37 |
38 | self.room_points.append(points), self.room_labels.append(labels)
39 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
40 | num_point_all.append(labels.size)
41 | labelweights = labelweights.astype(np.float32)
42 | labelweights = labelweights / np.sum(labelweights)
43 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
44 | print(self.labelweights)
45 |
46 |
47 |
48 |
49 | sample_prob = num_point_all / np.sum(num_point_all)
50 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
51 | room_idxs = []
52 | for index in range(len(rooms_split)):
53 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
54 | self.room_idxs = np.array(room_idxs)
55 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split))
56 |
57 | def __getitem__(self, idx):
58 | room_idx = self.room_idxs[idx]
59 | points = self.room_points[room_idx] # N * 6
60 | labels = self.room_labels[room_idx] # N
61 | N_points = points.shape[0]
62 |
63 | while (True):
64 | center = points[np.random.choice(N_points)][:3]
65 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
66 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
67 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0]
68 | if point_idxs.size > 1024:
69 | break
70 |
71 | if point_idxs.size >= self.num_point:
72 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False)
73 | else:
74 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True)
75 |
76 | # normalize
77 | selected_points = points[selected_point_idxs, :] # num_point * 6
78 | current_points = np.zeros((self.num_point, 9)) # num_point * 9
79 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0]
80 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1]
81 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2]
82 | selected_points[:, 0] = selected_points[:, 0] - center[0]
83 | selected_points[:, 1] = selected_points[:, 1] - center[1]
84 | selected_points[:, 3:6] /= 255.0
85 |
86 | current_points[:, 0:6] = selected_points
87 | current_labels = labels[selected_point_idxs]
88 | if self.transform is not None:
89 | current_points, current_labels = self.transform(current_points, current_labels)
90 | return current_points, current_labels
91 |
92 | def __len__(self):
93 | return len(self.room_idxs)
94 |
95 | class S3DISDatasetWholeScene():
96 |     # prepares per-point predictions over whole scenes
97 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001):
98 | self.block_points = block_points
99 | self.block_size = block_size
100 | self.padding = padding
101 | self.root = root
102 | self.split = split
103 | self.stride = stride
104 | self.scene_points_num = []
105 | assert split in ['train', 'test']
106 | if self.split == 'train':
107 |             self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) == -1]
108 |         else:
109 |             self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) != -1]
110 | self.scene_points_list = []
111 | self.semantic_labels_list = []
112 | self.room_coord_min, self.room_coord_max = [], []
113 | for file in self.file_list:
114 | data = np.load(root + file)
115 | points = data[:, :3]
116 | self.scene_points_list.append(data[:, :6])
117 | self.semantic_labels_list.append(data[:, 6])
118 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
119 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
120 | assert len(self.scene_points_list) == len(self.semantic_labels_list)
121 |
122 | labelweights = np.zeros(13)
123 | for seg in self.semantic_labels_list:
124 | tmp, _ = np.histogram(seg, range(14))
125 | self.scene_points_num.append(seg.shape[0])
126 | labelweights += tmp
127 | labelweights = labelweights.astype(np.float32)
128 | labelweights = labelweights / np.sum(labelweights)
129 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
130 |
131 | def __getitem__(self, index):
132 | point_set_ini = self.scene_points_list[index]
133 | points = point_set_ini[:,:6]
134 | labels = self.semantic_labels_list[index]
135 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
136 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1)
137 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1)
138 |
139 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([])
140 | for index_y in range(0, grid_y):
141 | for index_x in range(0, grid_x):
142 | s_x = coord_min[0] + index_x * self.stride
143 | e_x = min(s_x + self.block_size, coord_max[0])
144 | s_x = e_x - self.block_size
145 |
146 | s_y = coord_min[1] + index_y * self.stride
147 | e_y = min(s_y + self.block_size, coord_max[1])
148 | s_y = e_y - self.block_size
149 |
150 | point_idxs = np.where(
151 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & (
152 | points[:, 1] <= e_y + self.padding))[0]
153 | if point_idxs.size == 0:
154 | continue
155 | num_batch = int(np.ceil(point_idxs.size / self.block_points))
156 | point_size = int(num_batch * self.block_points)
157 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True
158 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace)
159 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat))
160 | np.random.shuffle(point_idxs)
161 |
162 | data_batch = points[point_idxs, :]
163 | normlized_xyz = np.zeros((point_size, 3))
164 |
165 | normlized_xyz[:, 0] = data_batch[:, 0] / coord_max[0]
166 | normlized_xyz[:, 1] = data_batch[:, 1] / coord_max[1]
167 | normlized_xyz[:, 2] = data_batch[:, 2] / coord_max[2]
168 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0)
169 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0)
170 | data_batch[:, 2] = data_batch[:, 2]
171 | data_batch[:, 3:6] /= 255.0
172 | data_batch = np.concatenate((data_batch, normlized_xyz), axis=1)
173 |
174 | label_batch = labels[point_idxs].astype(int)
175 | batch_weight = self.labelweights[label_batch]
176 |
177 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch
178 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch
179 |                 sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight
180 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs
181 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1]))
182 | label_room = label_room.reshape((-1, self.block_points))
183 | sample_weight = sample_weight.reshape((-1, self.block_points))
184 | index_room = index_room.reshape((-1, self.block_points))
185 | return data_room, label_room, sample_weight, index_room
186 |
187 | def __len__(self):
188 | return len(self.scene_points_list)
189 |
190 | if __name__ == '__main__':
191 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/'
192 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01
193 |
194 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None)
195 | print('point data size:', point_data.__len__())
196 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape)
197 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape)
198 | import torch, time, random
199 | manual_seed = 123
200 | random.seed(manual_seed)
201 | np.random.seed(manual_seed)
202 | torch.manual_seed(manual_seed)
203 | torch.cuda.manual_seed_all(manual_seed)
204 | def worker_init_fn(worker_id):
205 | random.seed(manual_seed + worker_id)
206 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn)
207 | for idx in range(4):
208 | end = time.time()
209 | for i, (input, target) in enumerate(train_loader):
210 | print('time: {}/{}--{}'.format(i+1, len(train_loader), time.time() - end))
211 | end = time.time()
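212 |
213 | # Note: `self.labelweights` above implements inverse-frequency class weighting
214 | # with cube-root smoothing, w_c = (max_f / f_c) ** (1 / 3), where f_c is the
215 | # fraction of points labelled c; e.g. f = [0.5, 0.125] gives w = [1.0, 1.587].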
--------------------------------------------------------------------------------
/data/S3DIS/s3dis_names.txt:
--------------------------------------------------------------------------------
1 | ceiling
2 | floor
3 | wall
4 | beam
5 | column
6 | window
7 | door
8 | chair
9 | table
10 | bookcase
11 | sofa
12 | board
13 | clutter
14 |
--------------------------------------------------------------------------------
/flowchart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WingkeungM/GACNet_PyTorch/c5b3701e7029cf3c6a6ff2622255fc762de6690d/flowchart.jpg
--------------------------------------------------------------------------------
/model/gacnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import pointnet2_ops_lib.pointnet2_ops.pointnet2_utils as pointnet2_utils
4 |
5 | from model.gacnet_utils import *
6 |
7 |
8 | graph_inf = {'stride_list': [1024, 256, 64, 32], # number of points kept at each downsampled level
9 |              'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # radius for neighbor-point searching
10 |              'maxsample_list': [12, 21, 21, 21, 12] # max number of neighbor points at each layer
11 |              }
12 |
13 | # number of units for each mlp layer
14 | forward_parm = [
15 | [ [32,32,64], [64] ],
16 | [ [64,64,128], [128] ],
17 | [ [128,128,256], [256] ],
18 | [ [256,256,512], [512] ],
19 | [ [256,256], [256] ]
20 | ]
21 |
22 | # for feature interpolation stage
23 | upsample_parm = [
24 | [128, 128],
25 | [128, 128],
26 | [256, 256],
27 | [256, 256]
28 | ]
29 |
30 | # parameters for the fully connected layer
31 | fullconect_parm = 128
32 |
33 | net_inf = {'forward_parm': forward_parm,
34 | 'upsample_parm': upsample_parm,
35 | 'fullconect_parm': fullconect_parm
36 | }
37 |
38 |
39 |
40 |
41 | class GACNet(nn.Module):
42 | def __init__(self, num_classes, graph_inf, net_inf):
43 | super(GACNet, self).__init__()
44 | self.num_classes = num_classes
45 | self.forward_parm, self.upsample_parm, self.fullconect_parm = \
46 | net_inf['forward_parm'], net_inf['upsample_parm'], net_inf['fullconect_parm']
47 |
48 | self.stride_inf = graph_inf['stride_list']
49 |
50 | self.graph_attention_layer1 = GraphAttentionConvLayer(4, self.forward_parm[0][0], self.forward_parm[0][1])
51 | self.graph_attention_layer2 = GraphAttentionConvLayer(64, self.forward_parm[1][0], self.forward_parm[1][1])
52 | self.graph_attention_layer3 = GraphAttentionConvLayer(128, self.forward_parm[2][0], self.forward_parm[2][1])
53 | self.graph_attention_layer4 = GraphAttentionConvLayer(256, self.forward_parm[3][0], self.forward_parm[3][1])
54 |
55 | self.gragh_pooling_layer = GraphPoolingLayer()
56 | self.mid_graph_attention_layers = GraphAttentionConvLayer(512, self.forward_parm[-1][0], self.forward_parm[-1][1])
57 |
58 | self.point_upsample_layer1 = PointUpsampleLayer(512 + 256, self.upsample_parm[3])
59 | self.point_upsample_layer2 = PointUpsampleLayer(256 + 256, self.upsample_parm[2])
60 | self.point_upsample_layer3 = PointUpsampleLayer(128 + 256, self.upsample_parm[1])
61 | self.point_upsample_layer4 = PointUpsampleLayer(64 + 128, self.upsample_parm[0])
62 |
63 | self.graph_attention_layer_for_featurerefine = GraphAttentionConvLayerforFeatureRefine(self.num_classes)
64 |
65 | self.conv1 = nn.Conv1d(128, 128, 1)
66 | self.bn1 = nn.BatchNorm1d(128)
67 | self.drop = nn.Dropout(0.5)
68 | self.conv2 = nn.Conv1d(128, self.num_classes, 1)
69 |
70 | def forward(self, features, graph_prd, coarse_map):
71 | inif = features[:, :, 0:6] # (x,y,z,r,g,b)
72 | features = features[:, :, 2:] # (z, r, g, b, and (initial geofeatures if possible))
73 |
74 | feature_prd = []
75 |
76 | features = self.graph_attention_layer1(graph_prd[0], features)
77 | feature_prd.append(features)
78 | features = self.gragh_pooling_layer(features, coarse_map[0])
79 |
80 | features = self.graph_attention_layer2(graph_prd[1], features)
81 | feature_prd.append(features)
82 | features = self.gragh_pooling_layer(features, coarse_map[1])
83 |
84 | features = self.graph_attention_layer3(graph_prd[2], features)
85 | feature_prd.append(features)
86 | features = self.gragh_pooling_layer(features, coarse_map[2])
87 |
88 | features = self.graph_attention_layer4(graph_prd[3], features)
89 | feature_prd.append(features)
90 | features = self.gragh_pooling_layer(features, coarse_map[3])
91 |
92 | features = self.mid_graph_attention_layers(graph_prd[-1], features)
93 |
94 | features = self.point_upsample_layer1(graph_prd[3]['vertex'], graph_prd[3 + 1]['vertex'], feature_prd[3], features)
95 | features = self.point_upsample_layer2(graph_prd[2]['vertex'], graph_prd[2 + 1]['vertex'], feature_prd[2], features)
96 | features = self.point_upsample_layer3(graph_prd[1]['vertex'], graph_prd[1 + 1]['vertex'], feature_prd[1], features)
97 | features = self.point_upsample_layer4(graph_prd[0]['vertex'], graph_prd[0 + 1]['vertex'], feature_prd[0], features)
98 |
99 | features = features.permute(0, 2, 1)
100 | features = F.relu(self.bn1(self.conv1(features)))
101 | features = self.drop(features)
102 | features = self.conv2(features)
103 | features = features.permute(0, 2, 1)
104 |
105 | features = self.graph_attention_layer_for_featurerefine(inif, features, graph_prd[0]['adjids'])
106 | features = F.log_softmax(features, dim=2)
107 | return features
108 |
109 |
110 | def build_graph_pyramid(xyz, graph_inf):
111 |     """ Builds a pyramid of graphs and pooling operations corresponding to a progressively coarsened point cloud.
112 |     Inputs:
113 |         xyz: (batchsize, num_point, nfeature)
114 |         graph_inf: parameters for graph building (see graph_inf at the top of this file)
115 |     Outputs:
116 |         graph_prd: graph pyramid containing the vertices and their edges at each layer
117 |         coarse_map: records the correspondence between adjacent graph layers (for graph coarsening/pooling)
118 |     """
119 | stride_list, radius_list, maxsample_list = graph_inf['stride_list'], graph_inf['radius_list'], graph_inf['maxsample_list']
120 |
121 | graph_prd = []
122 | graph = {}
123 | coarse_map = []
124 |
125 | xyz = xyz.contiguous()
126 | ids = pointnet2_utils.ball_query(radius_list[0], maxsample_list[0], xyz, xyz)
127 | graph['vertex'], graph['adjids'] = xyz, ids
128 | graph_prd.append(graph.copy())
129 |
130 | for stride, radius, maxsample in zip(stride_list, radius_list[1:], maxsample_list[1:]):
131 | xyz, coarse_map_ids = graph_coarse(xyz, ids, stride)
132 | coarse_map.append(coarse_map_ids.int())
133 | ids = pointnet2_utils.ball_query(radius, maxsample, xyz, xyz)
134 | graph['vertex'], graph['adjids'] = xyz, ids
135 | graph_prd.append(graph.copy())
136 |
137 | return graph_prd, coarse_map
138 |
139 | if __name__ == '__main__':
140 | import os
141 | import torch
142 | os.environ["CUDA_VISIBLE_DEVICES"] = '1'
143 | input = torch.randn((8,6,4096)).permute(0, 2, 1).cuda()
144 | graph_prd, coarse_map = build_graph_pyramid(input[:, :, :3], graph_inf)
145 | model = GACNet(50, graph_inf, net_inf).cuda()
146 | logits = model(input, graph_prd, coarse_map)
147 |
148 |
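149 | # Note on the configuration above: the level-0 graph is built at full input
150 | # resolution with radius_list[0] / maxsample_list[0]; each coarser level i then
151 | # keeps stride_list[i-1] points (1024, 256, 64, 32) and searches neighbors with
152 | # radius_list[i] / maxsample_list[i].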
--------------------------------------------------------------------------------
/model/gacnet_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import pointnet2_ops_lib.pointnet2_ops.pointnet2_utils as pointnet2_utils
7 |
8 |
9 | gac_par = [
10 | [32, 16], #MLP for xyz
11 | [16, 16], #MLP for feature
12 |     [64] # hidden units of the MLP for merging
13 | ]
14 |
15 |
16 | class MLP1D(nn.Module):
17 | def __init__(self, inchannel, mlp_1d):
18 | super(MLP1D, self).__init__()
19 | self.mlp_conv1ds = nn.ModuleList()
20 | self.mlp_bn1ds = nn.ModuleList()
21 | self.mlp_1d = mlp_1d
22 | last_channel = inchannel
23 | for i, outchannel in enumerate(self.mlp_1d):
24 | self.mlp_conv1ds.append(nn.Conv1d(last_channel, outchannel, 1))
25 | self.mlp_bn1ds.append(nn.BatchNorm1d(outchannel))
26 | last_channel = outchannel
27 |
28 | def forward(self, features):
29 | features = features.permute(0, 2, 1)
30 | for i, conv in enumerate(self.mlp_conv1ds):
31 | bn = self.mlp_bn1ds[i]
32 | features = F.relu(bn(conv(features)))
33 | features = features.permute(0, 2, 1)
34 | return features
35 |
36 | class MLP2D(nn.Module):
37 | def __init__(self, inchannel, mlp_2d):
38 | super(MLP2D, self).__init__()
39 | self.mlp_conv2ds = nn.ModuleList()
40 | self.mlp_bn2ds = nn.ModuleList()
41 | self.mlp_2d = mlp_2d
42 | last_channel = inchannel
43 | for i, outchannel in enumerate(self.mlp_2d):
44 | self.mlp_conv2ds.append(nn.Conv2d(last_channel, outchannel, 1))
45 | self.mlp_bn2ds.append(nn.BatchNorm2d(outchannel))
46 | last_channel = outchannel
47 |
48 | def forward(self, features):
49 | features = features.permute(0, 3, 2, 1)
50 | for i, conv in enumerate(self.mlp_conv2ds):
51 | bn = self.mlp_bn2ds[i]
52 | features = F.relu(bn(conv(features)))
53 | features = features.permute(0, 3, 2, 1)
54 | return features
55 |
56 |
57 |
58 | class CoeffGeneration(nn.Module):
59 | def __init__(self, inchannel):
60 | super(CoeffGeneration, self).__init__()
61 | self.inchannel = inchannel
62 | self.MlP2D = MLP2D(self.inchannel, gac_par[1])
63 | self.MlP2D_2 = MLP2D(32, gac_par[2])
64 | self.conv1 = nn.Conv2d(64, 16 + self.inchannel, 1)
65 | self.bn1 = nn.BatchNorm2d(16 + self.inchannel)
66 |
67 | def forward(self, grouped_features, features, grouped_xyz, mode='with_feature'):
68 | if mode == 'with_feature':
69 | coeff = grouped_features - features.unsqueeze(dim=2)
70 | coeff = self.MlP2D(coeff)
71 | coeff = torch.cat((grouped_xyz, coeff), dim=-1)
72 | if mode == 'edge_only':
73 | coeff = grouped_xyz
74 | if mode == 'feature_only':
75 | coeff = grouped_features - features.unsqueeze(dim=2)
76 | coeff = self.MlP2D(coeff)
77 |
78 | grouped_features = torch.cat((grouped_xyz, grouped_features), dim=-1)
79 | coeff = self.MlP2D_2(coeff)
80 | coeff = coeff.permute(0, 3, 2, 1)
81 | coeff = self.bn1(self.conv1(coeff))
82 | coeff = coeff.permute(0, 3, 2, 1)
83 | coeff = F.softmax(coeff, dim=2)
84 |
85 | grouped_features = coeff * grouped_features
86 | grouped_features = torch.sum(grouped_features, dim=2)
87 | return grouped_features
88 |
89 | class GraphAttentionConvLayer(nn.Module):
90 | def __init__(self, feature_inchannel, mlp1, mlp2):
91 | super(GraphAttentionConvLayer, self).__init__()
92 | self.mlp1 = mlp1
93 | self.mlp2 = mlp2
94 | self.feature_inchannel = feature_inchannel
95 | self.edge_mapping = MLP2D(3, gac_par[0])
96 | self.MLP1D = MLP1D(self.feature_inchannel, self.mlp1)
97 | self.MLP1D_2 = MLP1D(2 * self.mlp1[-1] + 16, self.mlp2)
98 | self.coeff_generation = CoeffGeneration(self.mlp1[-1])
99 |
100 | def forward(self, graph, features):
101 | xyz, ids = graph['vertex'], graph['adjids']
102 | grouped_xyz = pointnet2_utils.grouping_operation(xyz.permute(0, 2, 1).contiguous(), ids).permute(0, 2, 3, 1).contiguous()
103 | grouped_xyz -= xyz.unsqueeze(dim=2)
104 | grouped_xyz = self.edge_mapping(grouped_xyz)
105 |
106 | features = self.MLP1D(features)
107 | grouped_features = pointnet2_utils.grouping_operation(features.permute(0, 2, 1).contiguous(), ids).permute(0, 2, 3, 1).contiguous()
108 |
109 | new_features = self.coeff_generation(grouped_features, features, grouped_xyz)
110 | if self.mlp2 is not None and features is not None:
111 | new_features = torch.cat((features, new_features), dim=-1)
112 | new_features = self.MLP1D_2(new_features)
113 | return new_features
114 |
115 |
116 | class GraphPoolingLayer(nn.Module):
117 | def __init__(self, pooling='max'):
118 | super(GraphPoolingLayer, self).__init__()
119 | self.pooling = pooling
120 |
121 | def forward(self, features, coarse_map):
122 | grouped_features = pointnet2_utils.grouping_operation(features.permute(0, 2, 1).contiguous(), coarse_map).permute(0, 2, 3, 1).contiguous()
123 | if self.pooling == 'max':
124 | new_features = torch.max(grouped_features, dim=2)[0]
125 | return new_features
126 |
127 | class PointUpsampleLayer(nn.Module):
128 | def __init__(self, inchannel, upsample_parm):
129 | super(PointUpsampleLayer, self).__init__()
130 | self.inchannel = inchannel
131 | self.upsample_list = upsample_parm
132 | self.MLP1D = MLP1D(self.inchannel, self.upsample_list)
133 |
134 |
135 | def forward(self, xyz1, xyz2, features1, features2):
136 | B, N, C = xyz1.shape
137 | _, S, _ = xyz2.shape
138 |
139 | features1 = features1.permute(0, 2, 1)
140 | features2 = features2.permute(0, 2, 1)
141 | assert xyz1.is_contiguous()
142 | assert xyz2.is_contiguous()
143 | dist, idx = pointnet2_utils.three_nn(xyz1, xyz2)
144 | dist_recip = 1.0 / (dist + 1e-8)
145 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
146 | weight = dist_recip / norm
147 | assert features2.is_contiguous()
148 | interpolated_features = pointnet2_utils.three_interpolate(features2, idx, weight)
149 |
150 | new_features = torch.cat((interpolated_features, features1), dim=1).permute(0, 2, 1)
151 |
152 | if self.upsample_list is not None:
153 | new_features = self.MLP1D(new_features)
154 | return new_features
155 |
156 | class GraphAttentionConvLayerforFeatureRefine(nn.Module):
157 | def __init__(self, inchannel):
158 | super(GraphAttentionConvLayerforFeatureRefine, self).__init__()
159 | self.inchannel = inchannel
160 | self.edge_mapping = MLP2D(6, gac_par[0])
161 | self.coeff_generation = CoeffGeneration(self.inchannel)
162 | self.conv1 = nn.Conv1d(16 + 2 * self.inchannel, self.inchannel, 1)
163 |
164 | def forward(self, initf, features, ids):
165 | initf = initf.permute(0, 2, 1).contiguous()
166 | grouped_initf = pointnet2_utils.grouping_operation(initf, ids).permute(0, 2, 3, 1).contiguous()
167 | grouped_initf -= initf.permute(0, 2, 1).unsqueeze(dim=2)
168 | grouped_initf = self.edge_mapping(grouped_initf)
169 | features = features.permute(0, 2, 1).contiguous()
170 | grouped_features = pointnet2_utils.grouping_operation(features, ids).permute(0, 2, 3, 1).contiguous()
171 | features = features.permute(0, 2, 1)
172 |
173 | new_features = self.coeff_generation(grouped_features, features, grouped_initf)
174 |
175 | new_features = torch.cat((features, new_features), dim=-1)
176 | new_features = new_features.permute(0, 2, 1)
177 | new_features = self.conv1(new_features)
178 | new_features = new_features.permute(0, 2, 1)
179 |
180 | return new_features
181 |
182 |
183 |
184 | def graph_coarse(xyz_org, ids_full, stride):
185 |     """ Coarsen the graph by downsampling, and find each new vertex's correspondences at the previous (parent) level. """
186 | if stride > 1:
187 | sub_pts_ids = pointnet2_utils.furthest_point_sample(xyz_org, stride)
188 | sub_xyz = pointnet2_utils.gather_operation(xyz_org.permute(0, 2, 1).contiguous(), sub_pts_ids).permute(0, 2, 1).contiguous()
189 |
190 | ids = pointnet2_utils.grouping_operation(ids_full.permute(0, 2, 1).float().contiguous(),
191 | sub_pts_ids.unsqueeze(dim=-1).contiguous()).long().squeeze(-1).permute(0, 2, 1).contiguous() # (batchsize, num_point, maxsample)
192 |
193 | return sub_xyz, ids
194 | else:
195 | return xyz_org, ids_full
196 |
197 |
198 |
199 |
200 |
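201 | if __name__ == '__main__':
202 |     # Shape-level sketch (illustrative, not part of the network) of the
203 |     # aggregation inside CoeffGeneration: coefficients are softmax-normalized
204 |     # over the neighbor axis (dim=2), then used for a weighted sum.
205 |     B, N, K, C = 2, 16, 8, 32  # assumed batch, points, neighbors, channels
206 |     grouped = torch.randn(B, N, K, C)                  # neighbor features
207 |     coeff = F.softmax(torch.randn(B, N, K, C), dim=2)  # per-neighbor weights
208 |     print((coeff * grouped).sum(dim=2).shape)          # torch.Size([2, 16, 32])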
--------------------------------------------------------------------------------
/pointnet2_ops_lib/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft pointnet2_ops/_ext-src
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/__init__.py:
--------------------------------------------------------------------------------
1 | import pointnet2_ops.pointnet2_modules
2 | import pointnet2_ops.pointnet2_utils
3 | from pointnet2_ops._version import __version__
4 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5 | const int nsample);
6 | at::Tensor cube_select_sift(at::Tensor xyz, const float radius);
7 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cmath>
7 |
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 |
11 | #include <vector>
12 |
13 | #define TOTAL_THREADS 512
14 |
15 | inline int opt_n_threads(int work_size) {
16 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17 |
18 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
19 | }
20 |
21 | inline dim3 opt_block_config(int x, int y) {
22 | const int x_threads = opt_n_threads(x);
23 | const int y_threads =
24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25 | dim3 block_config(x_threads, y_threads, 1);
26 |
27 | return block_config;
28 | }
29 |
30 | #define CUDA_CHECK_ERRORS() \
31 | do { \
32 | cudaError_t err = cudaGetLastError(); \
33 | if (cudaSuccess != err) { \
34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36 | __FILE__); \
37 | exit(-1); \
38 | } \
39 | } while (0)
40 |
41 | #endif
42 |
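43 | // Example: opt_n_threads(300) == 256, i.e. the largest power of two that is
44 | // <= work_size, clamped to [1, TOTAL_THREADS]; opt_block_config(x, y) builds
45 | // a 2-D block whose x * y thread count stays within TOTAL_THREADS.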
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7 | std::vector<at::Tensor> dist_nn(at::Tensor unknowns, at::Tensor knows);
8 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
9 | at::Tensor weight);
10 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
11 | at::Tensor weight, const int m);
12 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
7 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <torch/extension.h>
4 |
5 | #define CHECK_CUDA(x) \
6 | do { \
7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
8 | } while (0)
9 |
10 | #define CHECK_CONTIGUOUS(x) \
11 | do { \
12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_IS_INT(x) \
16 | do { \
17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
18 | #x " must be an int tensor"); \
19 | } while (0)
20 |
21 | #define CHECK_IS_FLOAT(x) \
22 | do { \
23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
24 | #x " must be a float tensor"); \
25 | } while (0)
26 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "utils.h"
3 |
4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5 | int nsample, const float *new_xyz,
6 | const float *xyz, int *idx);
7 | void cube_select_sift_kernel_wrapper(int b, int n, float radius,
8 | const float *xyz, int *idx_out);
9 |
10 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
11 | const int nsample) {
12 | CHECK_CONTIGUOUS(new_xyz);
13 | CHECK_CONTIGUOUS(xyz);
14 | CHECK_IS_FLOAT(new_xyz);
15 | CHECK_IS_FLOAT(xyz);
16 |
17 | if (new_xyz.is_cuda()) {
18 | CHECK_CUDA(xyz);
19 | }
20 |
21 | at::Tensor idx =
22 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
23 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
24 |
25 | if (new_xyz.is_cuda()) {
26 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
27 |                                     radius, nsample, new_xyz.data_ptr<float>(),
28 |                                     xyz.data_ptr<float>(), idx.data_ptr<int>());
29 | } else {
30 | AT_ASSERT(false, "CPU not supported");
31 | }
32 |
33 | return idx;
34 | }
35 |
36 |
37 | at::Tensor cube_select_sift(at::Tensor xyz, const float radius) {
38 | CHECK_CONTIGUOUS(xyz);
39 | CHECK_IS_FLOAT(xyz);
40 |
41 | if (xyz.is_cuda()) {
42 | CHECK_CUDA(xyz);
43 | }
44 |
45 | at::Tensor idx_out =
46 | torch::zeros({xyz.size(0), xyz.size(1), 8},
47 | at::device(xyz.device()).dtype(at::ScalarType::Int));
48 |
49 | if (xyz.is_cuda()) {
50 | cube_select_sift_kernel_wrapper(xyz.size(0), xyz.size(1), radius,
51 |                                     xyz.data_ptr<float>(), idx_out.data_ptr<int>());
52 | } else {
53 | AT_ASSERT(false, "CPU not supported");
54 | }
55 |
56 | return idx_out;
57 | }
58 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
8 | // output: idx(b, m, nsample)
9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10 | int nsample,
11 | const float *__restrict__ new_xyz,
12 | const float *__restrict__ xyz,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | xyz += batch_index * n * 3;
16 | new_xyz += batch_index * m * 3;
17 | idx += m * nsample * batch_index;
18 |
19 | int index = threadIdx.x;
20 | int stride = blockDim.x;
21 |
22 | float radius2 = radius * radius;
23 | for (int j = index; j < m; j += stride) {
24 | float new_x = new_xyz[j * 3 + 0];
25 | float new_y = new_xyz[j * 3 + 1];
26 | float new_z = new_xyz[j * 3 + 2];
27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28 | float x = xyz[k * 3 + 0];
29 | float y = xyz[k * 3 + 1];
30 | float z = xyz[k * 3 + 2];
31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32 | (new_z - z) * (new_z - z);
33 | if (d2 < radius2) {
34 |         if (cnt == 0) {  // the first neighbor found pre-fills all slots as padding
35 | for (int l = 0; l < nsample; ++l) {
36 | idx[j * nsample + l] = k;
37 | }
38 | }
39 | idx[j * nsample + cnt] = k;
40 | ++cnt;
41 | }
42 | }
43 | }
44 | }
45 |
46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47 | int nsample, const float *new_xyz,
48 | const float *xyz, int *idx) {
49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51 | b, n, m, radius, nsample, new_xyz, xyz, idx);
52 |
53 | CUDA_CHECK_ERRORS();
54 | }
55 |
56 |
57 |
58 | __global__ void cube_select_sift_kernel(int b, int n,float radius, const float*__restrict__ xyz, int*__restrict__ idx_out) {
59 | int batch_idx = blockIdx.x;
60 | xyz += batch_idx * n * 3;
61 | idx_out += batch_idx * n * 8;
62 | float temp_dist[8];
63 | float judge_dist = radius * radius;
64 | for(int i = threadIdx.x; i < n;i += blockDim.x) {
65 | float x = xyz[i * 3];
66 | float y = xyz[i * 3 + 1];
67 | float z = xyz[i * 3 + 2];
68 | for(int j = 0;j < 8;j ++) {
69 | temp_dist[j] = 1e8;
70 | idx_out[i * 8 + j] = i; // if not found, just return itself..
71 | }
72 | for(int j = 0;j < n;j ++) {
73 | if(i == j) continue;
74 | float tx = xyz[j * 3];
75 | float ty = xyz[j * 3 + 1];
76 | float tz = xyz[j * 3 + 2];
77 | float dist = (x - tx) * (x - tx) + (y - ty) * (y - ty) + (z - tz) * (z - tz);
78 | if(dist > judge_dist) continue;
79 | int _x = (tx > x);
80 | int _y = (ty > y);
81 | int _z = (tz > z);
82 | int temp_idx = _x * 4 + _y * 2 + _z;
83 | if(dist < temp_dist[temp_idx]) {
84 | idx_out[i * 8 + temp_idx] = j;
85 | temp_dist[temp_idx] = dist;
86 | }
87 | }
88 | }
89 | }
90 |
91 | void cube_select_sift_kernel_wrapper(int b, int n, float radius,
92 | const float *xyz, int *idx_out) {
93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
94 |   cube_select_sift_kernel<<<b, opt_n_threads(n), 0, stream>>>(
95 | b, n, radius, xyz, idx_out);
96 |
97 | CUDA_CHECK_ERRORS();
98 | }
99 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "group_points.h"
3 | #include "interpolate.h"
4 | #include "sampling.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("gather_points", &gather_points);
8 | m.def("gather_points_grad", &gather_points_grad);
9 | m.def("furthest_point_sampling", &furthest_point_sampling);
10 |
11 | m.def("three_nn", &three_nn);
12 | m.def("dist_nn", &dist_nn);
13 | m.def("three_interpolate", &three_interpolate);
14 | m.def("three_interpolate_grad", &three_interpolate_grad);
15 |
16 | m.def("ball_query", &ball_query);
17 | m.def("cube_select_sift", &cube_select_sift);
18 |
19 | m.def("group_points", &group_points);
20 | m.def("group_points_grad", &group_points_grad);
21 | }
22 |
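23 | // Note: these raw bindings are consumed through the autograd wrappers in
24 | // pointnet2_ops/pointnet2_utils.py; e.g. the Python-side ball_query takes
25 | // (radius, nsample, xyz, new_xyz), which is how model/gacnet.py calls it.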
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include "group_points.h"
2 | #include "utils.h"
3 |
4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5 | const float *points, const int *idx,
6 | float *out);
7 |
8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9 | int nsample, const float *grad_out,
10 | const int *idx, float *grad_points);
11 |
12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13 | CHECK_CONTIGUOUS(points);
14 | CHECK_CONTIGUOUS(idx);
15 | CHECK_IS_FLOAT(points);
16 | CHECK_IS_INT(idx);
17 |
18 | if (points.is_cuda()) {
19 | CHECK_CUDA(idx);
20 | }
21 |
22 | at::Tensor output =
23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24 | at::device(points.device()).dtype(at::ScalarType::Float));
25 |
26 | if (points.is_cuda()) {
27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28 | idx.size(1), idx.size(2),
29 |                                 points.data_ptr<float>(), idx.data_ptr<int>(),
30 |                                 output.data_ptr<float>());
31 | } else {
32 | AT_ASSERT(false, "CPU not supported");
33 | }
34 |
35 | return output;
36 | }
37 |
38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
39 | CHECK_CONTIGUOUS(grad_out);
40 | CHECK_CONTIGUOUS(idx);
41 | CHECK_IS_FLOAT(grad_out);
42 | CHECK_IS_INT(idx);
43 |
44 | if (grad_out.is_cuda()) {
45 | CHECK_CUDA(idx);
46 | }
47 |
48 | at::Tensor output =
49 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
50 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
51 |
52 | if (grad_out.is_cuda()) {
53 | group_points_grad_kernel_wrapper(
54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
55 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
56 |         output.data_ptr<float>());
57 | } else {
58 | AT_ASSERT(false, "CPU not supported");
59 | }
60 |
61 | return output;
62 | }
63 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 | int nsample,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | int batch_index = blockIdx.x;
14 | points += batch_index * n * c;
15 | idx += batch_index * npoints * nsample;
16 | out += batch_index * npoints * nsample * c;
17 |
18 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 | const int stride = blockDim.y * blockDim.x;
20 | for (int i = index; i < c * npoints; i += stride) {
21 | const int l = i / npoints;
22 | const int j = i % npoints;
23 | for (int k = 0; k < nsample; ++k) {
24 | int ii = idx[j * nsample + k];
25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 | }
27 | }
28 | }
29 |
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 | const float *points, const int *idx,
32 | float *out) {
33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 |
35 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 | b, c, n, npoints, nsample, points, idx, out);
37 |
38 | CUDA_CHECK_ERRORS();
39 | }
40 |
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 | int nsample,
45 | const float *__restrict__ grad_out,
46 | const int *__restrict__ idx,
47 | float *__restrict__ grad_points) {
48 | int batch_index = blockIdx.x;
49 | grad_out += batch_index * npoints * nsample * c;
50 | idx += batch_index * npoints * nsample;
51 | grad_points += batch_index * n * c;
52 |
53 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 | const int stride = blockDim.y * blockDim.x;
55 | for (int i = index; i < c * npoints; i += stride) {
56 | const int l = i / npoints;
57 | const int j = i % npoints;
58 | for (int k = 0; k < nsample; ++k) {
59 | int ii = idx[j * nsample + k];
60 | atomicAdd(grad_points + l * n + ii,
61 | grad_out[(l * npoints + j) * nsample + k]);
62 | }
63 | }
64 | }
65 |
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 | int nsample, const float *grad_out,
68 | const int *idx, float *grad_points) {
69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 |
71 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 |
74 | CUDA_CHECK_ERRORS();
75 | }
76 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 |
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 | const float *known, float *dist2, int *idx);
6 | void dist_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
7 | const float *known, float *dist2);
8 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
9 | const float *points, const int *idx,
10 | const float *weight, float *out);
11 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
12 | const float *grad_out,
13 | const int *idx, const float *weight,
14 | float *grad_points);
15 |
16 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
17 | CHECK_CONTIGUOUS(unknowns);
18 | CHECK_CONTIGUOUS(knows);
19 | CHECK_IS_FLOAT(unknowns);
20 | CHECK_IS_FLOAT(knows);
21 |
22 | if (unknowns.is_cuda()) {
23 | CHECK_CUDA(knows);
24 | }
25 |
26 | at::Tensor idx =
27 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
28 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
29 | at::Tensor dist2 =
30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
31 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
32 |
33 | if (unknowns.is_cuda()) {
34 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
35 |                             unknowns.data_ptr<float>(), knows.data_ptr<float>(),
36 |                             dist2.data_ptr<float>(), idx.data_ptr<int>());
37 | } else {
38 | AT_ASSERT(false, "CPU not supported");
39 | }
40 |
41 | return {dist2, idx};
42 | }
43 |
44 |
45 | std::vector<at::Tensor> dist_nn(at::Tensor unknowns, at::Tensor knows) {
46 | CHECK_CONTIGUOUS(unknowns);
47 | CHECK_CONTIGUOUS(knows);
48 | CHECK_IS_FLOAT(unknowns);
49 | CHECK_IS_FLOAT(knows);
50 |
51 | if (unknowns.is_cuda()) {
52 | CHECK_CUDA(knows);
53 | }
54 |
55 | at::Tensor dist2 =
56 | torch::zeros({unknowns.size(0), unknowns.size(1), knows.size(1)},
57 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
58 |
59 | if (unknowns.is_cuda()) {
60 | dist_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
61 |                            unknowns.data_ptr<float>(), knows.data_ptr<float>(),
62 |                            dist2.data_ptr<float>());
63 | } else {
64 | AT_ASSERT(false, "CPU not supported");
65 | }
66 |
67 | return {dist2};
68 | }
69 |
70 |
71 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
72 | at::Tensor weight) {
73 | CHECK_CONTIGUOUS(points);
74 | CHECK_CONTIGUOUS(idx);
75 | CHECK_CONTIGUOUS(weight);
76 | CHECK_IS_FLOAT(points);
77 | CHECK_IS_INT(idx);
78 | CHECK_IS_FLOAT(weight);
79 |
80 | if (points.is_cuda()) {
81 | CHECK_CUDA(idx);
82 | CHECK_CUDA(weight);
83 | }
84 |
85 | at::Tensor output =
86 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
87 | at::device(points.device()).dtype(at::ScalarType::Float));
88 |
89 | if (points.is_cuda()) {
90 | three_interpolate_kernel_wrapper(
91 | points.size(0), points.size(1), points.size(2), idx.size(1),
92 |         points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
93 |         output.data_ptr<float>());
94 | } else {
95 | AT_ASSERT(false, "CPU not supported");
96 | }
97 |
98 | return output;
99 | }
100 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
101 | at::Tensor weight, const int m) {
102 | CHECK_CONTIGUOUS(grad_out);
103 | CHECK_CONTIGUOUS(idx);
104 | CHECK_CONTIGUOUS(weight);
105 | CHECK_IS_FLOAT(grad_out);
106 | CHECK_IS_INT(idx);
107 | CHECK_IS_FLOAT(weight);
108 |
109 | if (grad_out.is_cuda()) {
110 | CHECK_CUDA(idx);
111 | CHECK_CUDA(weight);
112 | }
113 |
114 | at::Tensor output =
115 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
116 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
117 |
118 | if (grad_out.is_cuda()) {
119 | three_interpolate_grad_kernel_wrapper(
120 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
121 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
122 |         weight.data_ptr<float>(), output.data_ptr<float>());
123 | } else {
124 | AT_ASSERT(false, "CPU not supported");
125 | }
126 |
127 | return output;
128 | }
129 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 | const float *__restrict__ unknown,
11 | const float *__restrict__ known,
12 | float *__restrict__ dist2,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | unknown += batch_index * n * 3;
16 | known += batch_index * m * 3;
17 | dist2 += batch_index * n * 3;
18 | idx += batch_index * n * 3;
19 |
20 | int index = threadIdx.x;
21 | int stride = blockDim.x;
22 | for (int j = index; j < n; j += stride) {
23 | float ux = unknown[j * 3 + 0];
24 | float uy = unknown[j * 3 + 1];
25 | float uz = unknown[j * 3 + 2];
26 |
27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 | int besti1 = 0, besti2 = 0, besti3 = 0;
29 | for (int k = 0; k < m; ++k) {
30 | float x = known[k * 3 + 0];
31 | float y = known[k * 3 + 1];
32 | float z = known[k * 3 + 2];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 | if (d < best1) {
35 | best3 = best2;
36 | besti3 = besti2;
37 | best2 = best1;
38 | besti2 = besti1;
39 | best1 = d;
40 | besti1 = k;
41 | } else if (d < best2) {
42 | best3 = best2;
43 | besti3 = besti2;
44 | best2 = d;
45 | besti2 = k;
46 | } else if (d < best3) {
47 | best3 = d;
48 | besti3 = k;
49 | }
50 | }
51 | dist2[j * 3 + 0] = best1;
52 | dist2[j * 3 + 1] = best2;
53 | dist2[j * 3 + 2] = best3;
54 |
55 | idx[j * 3 + 0] = besti1;
56 | idx[j * 3 + 1] = besti2;
57 | idx[j * 3 + 2] = besti3;
58 | }
59 | }
60 |
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 | const float *known, float *dist2, int *idx) {
63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 | dist2, idx);
66 |
67 | CUDA_CHECK_ERRORS();
68 | }
69 |
70 |
71 | // input: unknown(b, n, 3) known(b, m, 3)
72 | // output: dist2(b, n, m)
73 | __global__ void dist_nn_kernel(int b, int n, int m,
74 | const float *__restrict__ unknown,
75 | const float *__restrict__ known,
76 | float *__restrict__ dist2) {
77 | int batch_index = blockIdx.x;
78 | unknown += batch_index * n * 3;
79 | known += batch_index * m * 3;
80 | dist2 += batch_index * n * m;
81 |
82 | double best1 = 1e40;
83 | int index = threadIdx.x;
84 | int stride = blockDim.x;
85 | for (int j = index; j < n; j += stride) {
86 | float ux = unknown[j * 3 + 0];
87 | float uy = unknown[j * 3 + 1];
88 | float uz = unknown[j * 3 + 2];
89 |
90 | for (int k = 0; k < m; ++k) {
91 | float x = known[k * 3 + 0];
92 | float y = known[k * 3 + 1];
93 | float z = known[k * 3 + 2];
94 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
95 |
96 |       dist2[j * m + k] = d; }
97 | }
98 | }
99 |
100 | void dist_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
101 | const float *known, float *dist2) {
102 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
103 |   dist_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
104 | dist2);
105 |
106 | CUDA_CHECK_ERRORS();
107 | }
108 |
109 |
110 |
111 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
112 | // output: out(b, c, n)
113 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
114 | const float *__restrict__ points,
115 | const int *__restrict__ idx,
116 | const float *__restrict__ weight,
117 | float *__restrict__ out) {
118 | int batch_index = blockIdx.x;
119 | points += batch_index * m * c;
120 |
121 | idx += batch_index * n * 3;
122 | weight += batch_index * n * 3;
123 |
124 | out += batch_index * n * c;
125 |
126 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 | const int stride = blockDim.y * blockDim.x;
128 | for (int i = index; i < c * n; i += stride) {
129 | const int l = i / n;
130 | const int j = i % n;
131 | float w1 = weight[j * 3 + 0];
132 | float w2 = weight[j * 3 + 1];
133 | float w3 = weight[j * 3 + 2];
134 |
135 | int i1 = idx[j * 3 + 0];
136 | int i2 = idx[j * 3 + 1];
137 | int i3 = idx[j * 3 + 2];
138 |
139 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
140 | points[l * m + i3] * w3;
141 | }
142 | }
143 |
144 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
145 | const float *points, const int *idx,
146 | const float *weight, float *out) {
147 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
148 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
149 | b, c, m, n, points, idx, weight, out);
150 |
151 | CUDA_CHECK_ERRORS();
152 | }
153 |
154 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
155 | // output: grad_points(b, c, m)
156 |
157 | __global__ void three_interpolate_grad_kernel(
158 | int b, int c, int n, int m, const float *__restrict__ grad_out,
159 | const int *__restrict__ idx, const float *__restrict__ weight,
160 | float *__restrict__ grad_points) {
161 | int batch_index = blockIdx.x;
162 | grad_out += batch_index * n * c;
163 | idx += batch_index * n * 3;
164 | weight += batch_index * n * 3;
165 | grad_points += batch_index * m * c;
166 |
167 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
168 | const int stride = blockDim.y * blockDim.x;
169 | for (int i = index; i < c * n; i += stride) {
170 | const int l = i / n;
171 | const int j = i % n;
172 | float w1 = weight[j * 3 + 0];
173 | float w2 = weight[j * 3 + 1];
174 | float w3 = weight[j * 3 + 2];
175 |
176 | int i1 = idx[j * 3 + 0];
177 | int i2 = idx[j * 3 + 1];
178 | int i3 = idx[j * 3 + 2];
179 |
180 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
181 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
182 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
183 | }
184 | }
185 |
186 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
187 | const float *grad_out,
188 | const int *idx, const float *weight,
189 | float *grad_points) {
190 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
191 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
192 | b, c, n, m, grad_out, idx, weight, grad_points);
193 |
194 | CUDA_CHECK_ERRORS();
195 | }
196 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 |
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 | const float *points, const int *idx,
6 | float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 | const float *grad_out, const int *idx,
9 | float *grad_points);
10 |
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 | const float *dataset, float *temp,
13 | int *idxs);
14 |
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 | CHECK_CONTIGUOUS(points);
17 | CHECK_CONTIGUOUS(idx);
18 | CHECK_IS_FLOAT(points);
19 | CHECK_IS_INT(idx);
20 |
21 | if (points.is_cuda()) {
22 | CHECK_CUDA(idx);
23 | }
24 |
25 | at::Tensor output =
26 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 | at::device(points.device()).dtype(at::ScalarType::Float));
28 |
29 | if (points.is_cuda()) {
30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 |                                  idx.size(1), points.data_ptr<float>(),
32 |                                  idx.data_ptr<int>(), output.data_ptr<float>());
33 | } else {
34 | AT_ASSERT(false, "CPU not supported");
35 | }
36 |
37 | return output;
38 | }
39 |
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 | const int n) {
42 | CHECK_CONTIGUOUS(grad_out);
43 | CHECK_CONTIGUOUS(idx);
44 | CHECK_IS_FLOAT(grad_out);
45 | CHECK_IS_INT(idx);
46 |
47 | if (grad_out.is_cuda()) {
48 | CHECK_CUDA(idx);
49 | }
50 |
51 | at::Tensor output =
52 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 |
55 | if (grad_out.is_cuda()) {
56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 |                                       idx.size(1), grad_out.data_ptr<float>(),
58 |                                       idx.data_ptr<int>(),
59 |                                       output.data_ptr<float>());
60 | } else {
61 | AT_ASSERT(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67 | CHECK_CONTIGUOUS(points);
68 | CHECK_IS_FLOAT(points);
69 |
70 | at::Tensor output =
71 | torch::zeros({points.size(0), nsamples},
72 | at::device(points.device()).dtype(at::ScalarType::Int));
73 |
74 | at::Tensor tmp =
75 | torch::full({points.size(0), points.size(1)}, 1e10,
76 | at::device(points.device()).dtype(at::ScalarType::Float));
77 |
78 | if (points.is_cuda()) {
79 | furthest_point_sampling_kernel_wrapper(
80 |         points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81 |         tmp.data_ptr<float>(), output.data_ptr<int>());
82 | } else {
83 | AT_ASSERT(false, "CPU not supported");
84 | }
85 |
86 | return output;
87 | }
88 |
--------------------------------------------------------------------------------
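The `gather_points` op above has a simple PyTorch equivalent, which can double as a correctness check (a standalone sketch, not repo code; shapes follow the comments in the CUDA source):

```
import torch

B, C, N, npoint = 2, 4, 16, 5
points = torch.randn(B, C, N)
idx = torch.randint(0, N, (B, npoint))

# out[b, c, j] = points[b, c, idx[b, j]], i.e. a gather along the last dim
out = torch.gather(points, 2, idx.unsqueeze(1).expand(B, C, npoint))
print(out.shape)  # torch.Size([2, 4, 5])
```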
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 | const float *__restrict__ points,
10 | const int *__restrict__ idx,
11 | float *__restrict__ out) {
12 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 | int a = idx[i * m + j];
16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 | }
18 | }
19 | }
20 | }
21 |
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 | const float *points, const int *idx,
24 | float *out) {
25 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 |                                                              points, idx, out);
28 |
29 | CUDA_CHECK_ERRORS();
30 | }
31 |
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 | const float *__restrict__ grad_out,
36 | const int *__restrict__ idx,
37 | float *__restrict__ grad_points) {
38 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 | int a = idx[i * m + j];
42 | atomicAdd(grad_points + (i * c + l) * n + a,
43 | grad_out[(i * c + l) * m + j]);
44 | }
45 | }
46 | }
47 | }
48 |
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 | const float *grad_out, const int *idx,
51 | float *grad_points) {
52 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 |                               at::cuda::getCurrentCUDAStream()>>>(
54 |       b, c, n, npoints, grad_out, idx, grad_points);
55 |
56 | CUDA_CHECK_ERRORS();
57 | }
58 |
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 | int idx1, int idx2) {
61 | const float v1 = dists[idx1], v2 = dists[idx2];
62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 | dists[idx1] = max(v1, v2);
64 | dists_i[idx1] = v2 > v1 ? i2 : i1;
65 | }
66 |
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 | int b, int n, int m, const float *__restrict__ dataset,
72 | float *__restrict__ temp, int *__restrict__ idxs) {
73 | if (m <= 0) return;
74 | __shared__ float dists[block_size];
75 | __shared__ int dists_i[block_size];
76 |
77 | int batch_index = blockIdx.x;
78 | dataset += batch_index * n * 3;
79 | temp += batch_index * n;
80 | idxs += batch_index * m;
81 |
82 | int tid = threadIdx.x;
83 | const int stride = block_size;
84 |
85 | int old = 0;
86 | if (threadIdx.x == 0) idxs[0] = old;
87 |
88 | __syncthreads();
89 | for (int j = 1; j < m; j++) {
90 | int besti = 0;
91 | float best = -1;
92 | float x1 = dataset[old * 3 + 0];
93 | float y1 = dataset[old * 3 + 1];
94 | float z1 = dataset[old * 3 + 2];
95 | for (int k = tid; k < n; k += stride) {
96 | float x2, y2, z2;
97 | x2 = dataset[k * 3 + 0];
98 | y2 = dataset[k * 3 + 1];
99 | z2 = dataset[k * 3 + 2];
100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 | if (mag <= 1e-3) continue;
102 |
103 | float d =
104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 |
106 | float d2 = min(d, temp[k]);
107 | temp[k] = d2;
108 | besti = d2 > best ? k : besti;
109 | best = d2 > best ? d2 : best;
110 | }
111 | dists[tid] = best;
112 | dists_i[tid] = besti;
113 | __syncthreads();
114 |
115 | if (block_size >= 512) {
116 | if (tid < 256) {
117 | __update(dists, dists_i, tid, tid + 256);
118 | }
119 | __syncthreads();
120 | }
121 | if (block_size >= 256) {
122 | if (tid < 128) {
123 | __update(dists, dists_i, tid, tid + 128);
124 | }
125 | __syncthreads();
126 | }
127 | if (block_size >= 128) {
128 | if (tid < 64) {
129 | __update(dists, dists_i, tid, tid + 64);
130 | }
131 | __syncthreads();
132 | }
133 | if (block_size >= 64) {
134 | if (tid < 32) {
135 | __update(dists, dists_i, tid, tid + 32);
136 | }
137 | __syncthreads();
138 | }
139 | if (block_size >= 32) {
140 | if (tid < 16) {
141 | __update(dists, dists_i, tid, tid + 16);
142 | }
143 | __syncthreads();
144 | }
145 | if (block_size >= 16) {
146 | if (tid < 8) {
147 | __update(dists, dists_i, tid, tid + 8);
148 | }
149 | __syncthreads();
150 | }
151 | if (block_size >= 8) {
152 | if (tid < 4) {
153 | __update(dists, dists_i, tid, tid + 4);
154 | }
155 | __syncthreads();
156 | }
157 | if (block_size >= 4) {
158 | if (tid < 2) {
159 | __update(dists, dists_i, tid, tid + 2);
160 | }
161 | __syncthreads();
162 | }
163 | if (block_size >= 2) {
164 | if (tid < 1) {
165 | __update(dists, dists_i, tid, tid + 1);
166 | }
167 | __syncthreads();
168 | }
169 |
170 | old = dists_i[0];
171 | if (tid == 0) idxs[j] = old;
172 | }
173 | }
174 |
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 | const float *dataset, float *temp,
177 | int *idxs) {
178 | unsigned int n_threads = opt_n_threads(n);
179 |
180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 |
182 |   switch (n_threads) {
183 |   case 512:
184 |     furthest_point_sampling_kernel<512>
185 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 |     break;
187 |   case 256:
188 |     furthest_point_sampling_kernel<256>
189 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 |     break;
191 |   case 128:
192 |     furthest_point_sampling_kernel<128>
193 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 |     break;
195 |   case 64:
196 |     furthest_point_sampling_kernel<64>
197 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 |     break;
199 |   case 32:
200 |     furthest_point_sampling_kernel<32>
201 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 |     break;
203 |   case 16:
204 |     furthest_point_sampling_kernel<16>
205 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 |     break;
207 |   case 8:
208 |     furthest_point_sampling_kernel<8>
209 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 |     break;
211 |   case 4:
212 |     furthest_point_sampling_kernel<4>
213 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 |     break;
215 |   case 2:
216 |     furthest_point_sampling_kernel<2>
217 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 |     break;
219 |   case 1:
220 |     furthest_point_sampling_kernel<1>
221 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 |     break;
223 |   default:
224 |     furthest_point_sampling_kernel<512>
225 |         <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 |   }
227 |
228 | CUDA_CHECK_ERRORS();
229 | }
230 |
--------------------------------------------------------------------------------
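The reduction-heavy kernel above is easier to follow against a plain single-batch reference. A sketch of the same furthest-point-sampling logic in PyTorch (the CUDA kernel additionally skips near-origin padding points via its `mag <= 1e-3` check):

```
import torch

def furthest_point_sample_ref(xyz, m):
    # xyz: (n, 3) -> (m,) indices; greedily picks the point furthest from
    # the already-selected set, seeded with index 0 like the kernel.
    n = xyz.shape[0]
    idxs = torch.zeros(m, dtype=torch.long)
    temp = torch.full((n,), 1e10)  # running min distance to the selected set
    old = 0
    for j in range(1, m):
        d = ((xyz - xyz[old]) ** 2).sum(dim=1)
        temp = torch.minimum(temp, d)
        old = int(torch.argmax(temp))
        idxs[j] = old
    return idxs
```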
/pointnet2_ops_lib/pointnet2_ops/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "3.0.0"
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from pointnet2_ops import pointnet2_utils
7 |
8 |
9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
10 | layers = []
11 | for i in range(1, len(mlp_spec)):
12 | layers.append(
13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
14 | )
15 | if bn:
16 | layers.append(nn.BatchNorm2d(mlp_spec[i]))
17 | layers.append(nn.ReLU(True))
18 |
19 | return nn.Sequential(*layers)
20 |
21 |
22 | class _PointnetSAModuleBase(nn.Module):
23 | def __init__(self):
24 | super(_PointnetSAModuleBase, self).__init__()
25 | self.npoint = None
26 | self.groupers = None
27 | self.mlps = None
28 |
29 | def forward(
30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor]
31 | ) -> Tuple[torch.Tensor, torch.Tensor]:
32 | r"""
33 | Parameters
34 | ----------
35 | xyz : torch.Tensor
36 | (B, N, 3) tensor of the xyz coordinates of the features
37 | features : torch.Tensor
38 | (B, C, N) tensor of the descriptors of the the features
39 |
40 | Returns
41 | -------
42 | new_xyz : torch.Tensor
43 | (B, npoint, 3) tensor of the new features' xyz
44 | new_features : torch.Tensor
45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
46 | """
47 |
48 | new_features_list = []
49 |
50 | xyz_flipped = xyz.transpose(1, 2).contiguous()
51 | new_xyz = (
52 | pointnet2_utils.gather_operation(
53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
54 | )
55 | .transpose(1, 2)
56 | .contiguous()
57 | if self.npoint is not None
58 | else None
59 | )
60 |
61 | for i in range(len(self.groupers)):
62 | new_features = self.groupers[i](
63 | xyz, new_xyz, features
64 | ) # (B, C, npoint, nsample)
65 |
66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
67 | new_features = F.max_pool2d(
68 | new_features, kernel_size=[1, new_features.size(3)]
69 | ) # (B, mlp[-1], npoint, 1)
70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
71 |
72 | new_features_list.append(new_features)
73 |
74 | return new_xyz, torch.cat(new_features_list, dim=1)
75 |
76 |
77 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
78 | r"""Pointnet set abstrction layer with multiscale grouping
79 |
80 | Parameters
81 | ----------
82 | npoint : int
83 | Number of features
84 | radii : list of float32
85 | list of radii to group with
86 | nsamples : list of int32
87 | Number of samples in each ball query
88 | mlps : list of list of int32
89 | Spec of the pointnet before the global max_pool for each scale
90 | bn : bool
91 | Use batchnorm
92 | """
93 |
94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
96 | super(PointnetSAModuleMSG, self).__init__()
97 |
98 | assert len(radii) == len(nsamples) == len(mlps)
99 |
100 | self.npoint = npoint
101 | self.groupers = nn.ModuleList()
102 | self.mlps = nn.ModuleList()
103 | for i in range(len(radii)):
104 | radius = radii[i]
105 | nsample = nsamples[i]
106 | self.groupers.append(
107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
108 | if npoint is not None
109 | else pointnet2_utils.GroupAll(use_xyz)
110 | )
111 | mlp_spec = mlps[i]
112 | if use_xyz:
113 | mlp_spec[0] += 3
114 |
115 | self.mlps.append(build_shared_mlp(mlp_spec, bn))
116 |
117 |
118 | class PointnetSAModule(PointnetSAModuleMSG):
119 | r"""Pointnet set abstrction layer
120 |
121 | Parameters
122 | ----------
123 | npoint : int
124 | Number of features
125 | radius : float
126 | Radius of ball
127 | nsample : int
128 | Number of samples in the ball query
129 | mlp : list
130 | Spec of the pointnet before the global max_pool
131 | bn : bool
132 | Use batchnorm
133 | """
134 |
135 | def __init__(
136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
137 | ):
138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
139 | super(PointnetSAModule, self).__init__(
140 | mlps=[mlp],
141 | npoint=npoint,
142 | radii=[radius],
143 | nsamples=[nsample],
144 | bn=bn,
145 | use_xyz=use_xyz,
146 | )
147 |
148 |
149 | class PointnetFPModule(nn.Module):
150 | r"""Propigates the features of one set to another
151 |
152 | Parameters
153 | ----------
154 | mlp : list
155 | Pointnet module parameters
156 | bn : bool
157 | Use batchnorm
158 | """
159 |
160 | def __init__(self, mlp, bn=True):
161 | # type: (PointnetFPModule, List[int], bool) -> None
162 | super(PointnetFPModule, self).__init__()
163 | self.mlp = build_shared_mlp(mlp, bn=bn)
164 |
165 | def forward(self, unknown, known, unknow_feats, known_feats):
166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
167 | r"""
168 | Parameters
169 | ----------
170 | unknown : torch.Tensor
171 | (B, n, 3) tensor of the xyz positions of the unknown features
172 | known : torch.Tensor
173 | (B, m, 3) tensor of the xyz positions of the known features
174 |         unknow_feats : torch.Tensor
175 |             (B, C1, n) tensor of the features to be propagated to
176 |         known_feats : torch.Tensor
177 |             (B, C2, m) tensor of features to be propagated
178 |
179 | Returns
180 | -------
181 | new_features : torch.Tensor
182 | (B, mlp[-1], n) tensor of the features of the unknown features
183 | """
184 |
185 | if known is not None:
186 | dist, idx = pointnet2_utils.three_nn(unknown, known)
187 | dist_recip = 1.0 / (dist + 1e-8)
188 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
189 | weight = dist_recip / norm
190 |
191 | interpolated_feats = pointnet2_utils.three_interpolate(
192 | known_feats, idx, weight
193 | )
194 | else:
195 | interpolated_feats = known_feats.expand(
196 |                 *(list(known_feats.size()[0:2]) + [unknown.size(1)])
197 | )
198 |
199 | if unknow_feats is not None:
200 | new_features = torch.cat(
201 | [interpolated_feats, unknow_feats], dim=1
202 | ) # (B, C2 + C1, n)
203 | else:
204 | new_features = interpolated_feats
205 |
206 | new_features = new_features.unsqueeze(-1)
207 | new_features = self.mlp(new_features)
208 |
209 | return new_features.squeeze(-1)
210 |
--------------------------------------------------------------------------------
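A hedged usage sketch for the set-abstraction and feature-propagation modules above (assumes the compiled extension and a CUDA device; shapes follow the docstrings):

```
import torch
from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

xyz = torch.randn(2, 1024, 3).cuda()    # (B, N, 3)
feats = torch.randn(2, 6, 1024).cuda()  # (B, C, N)

# With use_xyz=True the first mlp width is bumped by 3 internally (6 -> 9)
sa = PointnetSAModule(mlp=[6, 32, 64], npoint=256, radius=0.2, nsample=32).cuda()
fp = PointnetFPModule(mlp=[64 + 6, 64]).cuda()

new_xyz, new_feats = sa(xyz, feats)            # (B, 256, 3), (B, 64, 256)
up_feats = fp(xyz, new_xyz, feats, new_feats)  # (B, 64, 1024)
```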
/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | from torch.autograd import Function
5 | from typing import *
6 |
7 | try:
8 | import pointnet2_ops._ext as _ext
9 | except ImportError:
10 | from torch.utils.cpp_extension import load
11 | import glob
12 | import os.path as osp
13 | import os
14 |
15 | warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
16 |
17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
19 | osp.join(_ext_src_root, "src", "*.cu")
20 | )
21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
22 |
23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
24 | _ext = load(
25 | "_ext",
26 | sources=_ext_sources,
27 | extra_include_paths=[osp.join(_ext_src_root, "include")],
28 | extra_cflags=["-O3"],
29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
30 | with_cuda=True,
31 | )
32 |
33 |
34 | class FurthestPointSampling(Function):
35 | @staticmethod
36 | def forward(ctx, xyz, npoint):
37 | # type: (Any, torch.Tensor, int) -> torch.Tensor
38 | r"""
39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
40 | minimum distance
41 |
42 | Parameters
43 | ----------
44 | xyz : torch.Tensor
45 | (B, N, 3) tensor where N > npoint
46 | npoint : int32
47 | number of features in the sampled set
48 |
49 | Returns
50 | -------
51 | torch.Tensor
52 | (B, npoint) tensor containing the set
53 | """
54 | out = _ext.furthest_point_sampling(xyz, npoint)
55 |
56 | ctx.mark_non_differentiable(out)
57 |
58 | return out
59 |
60 | @staticmethod
61 | def backward(ctx, grad_out):
62 | return ()
63 |
64 |
65 | furthest_point_sample = FurthestPointSampling.apply
66 |
67 |
68 | class GatherOperation(Function):
69 | @staticmethod
70 | def forward(ctx, features, idx):
71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
72 | r"""
73 |
74 | Parameters
75 | ----------
76 | features : torch.Tensor
77 | (B, C, N) tensor
78 |
79 | idx : torch.Tensor
80 | (B, npoint) tensor of the features to gather
81 |
82 | Returns
83 | -------
84 | torch.Tensor
85 | (B, C, npoint) tensor
86 | """
87 |
88 | ctx.save_for_backward(idx, features)
89 |
90 | return _ext.gather_points(features, idx)
91 |
92 | @staticmethod
93 | def backward(ctx, grad_out):
94 | idx, features = ctx.saved_tensors
95 | N = features.size(2)
96 |
97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
98 | return grad_features, None
99 |
100 |
101 | gather_operation = GatherOperation.apply
102 |
103 |
104 | class ThreeNN(Function):
105 | @staticmethod
106 | def forward(ctx, unknown, known):
107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
108 | r"""
109 | Find the three nearest neighbors of unknown in known
110 | Parameters
111 | ----------
112 |         unknown : torch.Tensor
113 |             (B, n, 3) tensor of the query points whose features are unknown
114 |         known : torch.Tensor
115 |             (B, m, 3) tensor of the reference points with known features
116 |
117 | Returns
118 | -------
119 | dist : torch.Tensor
120 | (B, n, 3) l2 distance to the three nearest neighbors
121 | idx : torch.Tensor
122 | (B, n, 3) index of 3 nearest neighbors
123 | """
124 | dist2, idx = _ext.three_nn(unknown, known)
125 | dist = torch.sqrt(dist2)
126 |
127 | ctx.mark_non_differentiable(dist, idx)
128 |
129 | return dist, idx
130 |
131 | @staticmethod
132 | def backward(ctx, grad_dist, grad_idx):
133 | return ()
134 |
135 |
136 | three_nn = ThreeNN.apply
137 |
138 |
139 |
140 | class DistNN(Function):
141 | @staticmethod
142 | def forward(ctx, unknown, known):
143 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
144 | r"""
145 | Find the three nearest neighbors of unknown in known
146 | Parameters
147 | ----------
148 |         unknown : torch.Tensor
149 |             (B, n, 3) tensor of the query points whose features are unknown
150 |         known : torch.Tensor
151 |             (B, m, 3) tensor of the reference points with known features
152 |
153 | Returns
154 | -------
155 |         dist : torch.Tensor
156 |             the distance tensor returned by _ext.dist_nn (only the
157 |             distances are returned here; no index tensor, unlike
158 |             ThreeNN above)
159 | """
160 | dist2 = _ext.dist_nn(unknown, known)
161 | # dist = torch.sqrt(dist2[0])
162 |
163 | # ctx.mark_non_differentiable(dist[0])
164 | ctx.mark_non_differentiable(dist2[0])
165 |
166 | return dist2[0]
167 |
168 | @staticmethod
169 | def backward(ctx, grad_dist):
170 | return ()
171 |
172 |
173 | dist_nn = DistNN.apply
174 |
175 | class ThreeInterpolate(Function):
176 | @staticmethod
177 | def forward(ctx, features, idx, weight):
178 |         # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
179 | r"""
180 |         Performs weighted linear interpolation of 3 features
181 | Parameters
182 | ----------
183 | features : torch.Tensor
184 | (B, c, m) Features descriptors to be interpolated from
185 | idx : torch.Tensor
186 | (B, n, 3) three nearest neighbors of the target features in features
187 | weight : torch.Tensor
188 | (B, n, 3) weights
189 |
190 | Returns
191 | -------
192 | torch.Tensor
193 | (B, c, n) tensor of the interpolated features
194 | """
195 | ctx.save_for_backward(idx, weight, features)
196 |
197 | return _ext.three_interpolate(features, idx, weight)
198 |
199 | @staticmethod
200 | def backward(ctx, grad_out):
201 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
202 | r"""
203 | Parameters
204 | ----------
205 | grad_out : torch.Tensor
206 |             (B, c, n) tensor with gradients of the outputs
207 |
208 | Returns
209 | -------
210 | grad_features : torch.Tensor
211 | (B, c, m) tensor with gradients of features
212 |
213 | None
214 |
215 | None
216 | """
217 | idx, weight, features = ctx.saved_tensors
218 | m = features.size(2)
219 |
220 | grad_features = _ext.three_interpolate_grad(
221 | grad_out.contiguous(), idx, weight, m
222 | )
223 |
224 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
225 |
226 |
227 | three_interpolate = ThreeInterpolate.apply
228 |
229 |
230 | class GroupingOperation(Function):
231 | @staticmethod
232 | def forward(ctx, features, idx):
233 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
234 | r"""
235 |
236 | Parameters
237 | ----------
238 | features : torch.Tensor
239 | (B, C, N) tensor of features to group
240 | idx : torch.Tensor
241 |             (B, npoint, nsample) tensor containing the indices of features to group with
242 |
243 | Returns
244 | -------
245 | torch.Tensor
246 | (B, C, npoint, nsample) tensor
247 | """
248 | ctx.save_for_backward(idx, features)
249 |
250 | return _ext.group_points(features, idx)
251 |
252 | @staticmethod
253 | def backward(ctx, grad_out):
254 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
255 | r"""
256 |
257 | Parameters
258 | ----------
259 | grad_out : torch.Tensor
260 | (B, C, npoint, nsample) tensor of the gradients of the output from forward
261 |
262 | Returns
263 | -------
264 | torch.Tensor
265 | (B, C, N) gradient of the features
266 | None
267 | """
268 | idx, features = ctx.saved_tensors
269 | N = features.size(2)
270 |
271 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
272 |
273 | return grad_features, torch.zeros_like(idx)
274 |
275 |
276 | grouping_operation = GroupingOperation.apply
277 |
278 |
279 | class BallQuery(Function):
280 | @staticmethod
281 | def forward(ctx, radius, nsample, xyz, new_xyz):
282 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
283 | r"""
284 |
285 | Parameters
286 | ----------
287 | radius : float
288 | radius of the balls
289 | nsample : int
290 | maximum number of features in the balls
291 | xyz : torch.Tensor
292 | (B, N, 3) xyz coordinates of the features
293 | new_xyz : torch.Tensor
294 | (B, npoint, 3) centers of the ball query
295 |
296 | Returns
297 | -------
298 | torch.Tensor
299 |             (B, npoint, nsample) tensor with the indices of the features that form the query balls
300 | """
301 | output = _ext.ball_query(new_xyz, xyz, radius, nsample)
302 |
303 | ctx.mark_non_differentiable(output)
304 |
305 | return output
306 |
307 | @staticmethod
308 | def backward(ctx, grad_out):
309 | return ()
310 |
311 |
312 | ball_query = BallQuery.apply
313 |
314 |
315 |
316 | class PointSIFTSelect(Function):
317 | @staticmethod
318 | def forward(ctx, radius, xyz):
319 | # type: (Any, float, torch.Tensor) -> torch.Tensor
320 | r"""
321 |
322 | Parameters
323 | ----------
324 |         radius : float
325 |             radius of the cube neighborhood searched around each
326 |             input point; unlike ball_query, neighbors are selected
327 |             for every point, so this op takes no nsample or
328 |             new_xyz arguments
329 |         xyz : torch.Tensor
330 |             (B, N, 3) xyz coordinates of the features
331 |
332 |
333 | Returns
334 | -------
335 | torch.Tensor
336 |             tensor with the indices of the neighboring features selected around each point
337 | """
338 | output = _ext.cube_select_sift(xyz, radius)
339 |
340 | ctx.mark_non_differentiable(output)
341 |
342 | return output
343 |
344 | @staticmethod
345 | def backward(ctx, grad_out):
346 | return ()
347 |
348 |
349 | pointsift_select = PointSIFTSelect.apply
350 |
351 | class QueryAndGroup(nn.Module):
352 | r"""
353 | Groups with a ball query of radius
354 |
355 | Parameters
356 |     ----------
357 | radius : float32
358 | Radius of ball
359 | nsample : int32
360 | Maximum number of features to gather in the ball
361 | """
362 |
363 | def __init__(self, radius, nsample, use_xyz=True):
364 | # type: (QueryAndGroup, float, int, bool) -> None
365 | super(QueryAndGroup, self).__init__()
366 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
367 |
368 | def forward(self, xyz, new_xyz, features=None):
369 |         # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
370 | r"""
371 | Parameters
372 | ----------
373 | xyz : torch.Tensor
374 | xyz coordinates of the features (B, N, 3)
375 | new_xyz : torch.Tensor
376 |             centroids (B, npoint, 3)
377 | features : torch.Tensor
378 | Descriptors of the features (B, C, N)
379 |
380 | Returns
381 | -------
382 | new_features : torch.Tensor
383 | (B, 3 + C, npoint, nsample) tensor
384 | """
385 |
386 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
387 | xyz_trans = xyz.transpose(1, 2).contiguous()
388 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
389 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
390 |
391 | if features is not None:
392 | grouped_features = grouping_operation(features, idx)
393 | if self.use_xyz:
394 | new_features = torch.cat(
395 | [grouped_xyz, grouped_features], dim=1
396 | ) # (B, C + 3, npoint, nsample)
397 | else:
398 | new_features = grouped_features
399 | else:
400 | assert (
401 | self.use_xyz
402 | ), "Cannot have not features and not use xyz as a feature!"
403 | new_features = grouped_xyz
404 |
405 | return new_features
406 |
407 |
408 | class GroupAll(nn.Module):
409 | r"""
410 | Groups all features
411 |
412 | Parameters
413 |     ----------
414 | """
415 |
416 | def __init__(self, use_xyz=True):
417 | # type: (GroupAll, bool) -> None
418 | super(GroupAll, self).__init__()
419 | self.use_xyz = use_xyz
420 |
421 | def forward(self, xyz, new_xyz, features=None):
422 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
423 | r"""
424 | Parameters
425 | ----------
426 | xyz : torch.Tensor
427 | xyz coordinates of the features (B, N, 3)
428 | new_xyz : torch.Tensor
429 | Ignored
430 | features : torch.Tensor
431 | Descriptors of the features (B, C, N)
432 |
433 | Returns
434 | -------
435 | new_features : torch.Tensor
436 | (B, C + 3, 1, N) tensor
437 | """
438 |
439 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
440 | if features is not None:
441 | grouped_features = features.unsqueeze(2)
442 | if self.use_xyz:
443 | new_features = torch.cat(
444 | [grouped_xyz, grouped_features], dim=1
445 | ) # (B, 3 + C, 1, N)
446 | else:
447 | new_features = grouped_features
448 | else:
449 | new_features = grouped_xyz
450 |
451 | return new_features
452 |
--------------------------------------------------------------------------------
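A hedged usage sketch for the grouping utilities above (assumes the compiled extension and a CUDA device):

```
import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 1024, 3).cuda()     # all points
new_xyz = xyz[:, :128, :].contiguous()  # query centers, any (B, npoint, 3)
feats = torch.rand(2, 16, 1024).cuda()  # (B, C, N) descriptors

# Indices of up to 32 points within radius 0.2 of each center
idx = pointnet2_utils.ball_query(0.2, 32, xyz, new_xyz)  # (2, 128, 32)

# QueryAndGroup bundles ball_query + grouping + recentering on the query point
grouper = pointnet2_utils.QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
grouped = grouper(xyz, new_xyz, feats)                   # (2, 3 + 16, 128, 32)
```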
/pointnet2_ops_lib/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path as osp
4 |
5 | from setuptools import find_packages, setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | this_dir = osp.dirname(osp.abspath(__file__))
9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
11 | osp.join(_ext_src_root, "src", "*.cu")
12 | )
13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
14 |
15 | requirements = ["torch>=1.4"]
16 |
17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read())
18 |
19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
20 | setup(
21 | name="pointnet2_ops",
22 | version=__version__,
23 | author="Erik Wijmans",
24 | packages=find_packages(),
25 | install_requires=requirements,
26 | ext_modules=[
27 | CUDAExtension(
28 | name="pointnet2_ops._ext",
29 | sources=_ext_sources,
30 | extra_compile_args={
31 | "cxx": ["-O3"],
32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
33 | },
34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
35 | )
36 | ],
37 | cmdclass={"build_ext": BuildExtension},
38 | include_package_data=True,
39 | )
40 |
--------------------------------------------------------------------------------
/tool/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WingkeungM/GACNet_PyTorch/c5b3701e7029cf3c6a6ff2622255fc762de6690d/tool/.DS_Store
--------------------------------------------------------------------------------
/tool/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import logging
5 | import argparse
6 |
7 | import numpy as np
8 |
9 | from tqdm import tqdm
10 | from pathlib import Path
11 | from data.S3DIS.S3DISDataLoader import S3DISDatasetWholeScene
12 | from model.gacnet import GACNet, build_graph_pyramid
13 |
14 | g_classes = [x.rstrip() for x in open('./data/S3DIS/s3dis_names.txt')]
15 | g_class2label = {cls: i for i,cls in enumerate(g_classes)}
16 | g_class2color = {'ceiling': [0,255,0],
17 | 'floor': [0,0,255],
18 | 'wall': [0,255,255],
19 | 'beam': [255,255,0],
20 | 'column': [255,0,255],
21 | 'window': [100,100,255],
22 | 'door': [200,200,100],
23 | 'table': [170,120,200],
24 | 'chair': [255,0,0],
25 | 'sofa': [200,100,100],
26 | 'bookcase': [10,200,100],
27 | 'board': [200,200,200],
28 | 'clutter': [50,50,50]}
29 | g_easy_view_labels = [7,8,9,10,11,1]
30 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes}
31 |
32 |
33 | graph_inf = {'stride_list': [1024, 256, 64, 32], # number of points kept at each graph pyramid level
34 |              'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # neighbor search radius at each level
35 |              'maxsample_list': [12, 21, 21, 21, 12] # max number of neighbors sampled per point at each level
36 | }
37 |
38 | # number of units for each mlp layer
39 | forward_parm = [
40 | [ [32,32,64], [64] ],
41 | [ [64,64,128], [128] ],
42 | [ [128,128,256], [256] ],
43 | [ [256,256,512], [512] ],
44 | [ [256,256], [256] ]
45 | ]
46 |
47 | # for feature interpolation stage
48 | upsample_parm = [
49 | [128, 128],
50 | [128, 128],
51 | [256, 256],
52 | [256, 256]
53 | ]
54 |
55 | # number of units in the fully connected layer
56 | fullconect_parm = 128
57 |
58 | net_inf = {'forward_parm': forward_parm,
59 | 'upsample_parm': upsample_parm,
60 | 'fullconect_parm': fullconect_parm
61 | }
62 |
63 |
64 |
65 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
66 | ROOT_DIR = os.path.dirname(BASE_DIR)
67 | sys.path.append(os.path.join(ROOT_DIR, 'model'))
68 |
69 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter']
70 | class2label = {cls: i for i,cls in enumerate(classes)}
71 | seg_classes = class2label
72 | seg_label_to_cat = {}
73 | for i,cat in enumerate(seg_classes.keys()):
74 | seg_label_to_cat[i] = cat
75 |
76 | def parse_args():
77 | '''PARAMETERS'''
78 | parser = argparse.ArgumentParser('Model')
79 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
80 | parser.add_argument('--num_point', type=int, default=4096, help='Point Number [default: 4096]')
81 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]')
82 | parser.add_argument('--visual', action='store_true', default=False, help='Whether visualize result [default: False]')
83 | parser.add_argument('--log_dir', type=str, default='logs_GACNet', help='Experiment root')
84 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
85 | parser.add_argument('--num_votes', type=int, default=5, help='Aggregate segmentation scores with voting [default: 5]')
86 | parser.add_argument('--num_class', type=int, default=13, help='Class number of the dataset [default: 13]')
87 |     parser.add_argument('--datapath', type=str, default='/workspace/datasets/stanford_indoor3d/', help='Path of the dataset')
88 | parser.add_argument('--test_model', type=str, default='/workspace/experiment/checkpoints/GACNet.pth/', help='Path of the test model')
89 |
90 | return parser.parse_args()
91 |
92 | def add_vote(vote_label_pool, point_idx, pred_label, weight):
93 | B = pred_label.shape[0]
94 | N = pred_label.shape[1]
95 | for b in range(B):
96 | for n in range(N):
97 | if weight[b,n]:
98 | vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1
99 | return vote_label_pool
100 |
101 | def main(args):
102 | def log_string(str):
103 | logger.info(str)
104 | print(str)
105 |
106 | '''HYPER PARAMETER'''
107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
108 | experiment_dir = './log/' + args.log_dir
109 | visual_dir = experiment_dir + '/visual/'
110 | visual_dir = Path(visual_dir)
111 |     visual_dir.mkdir(parents=True, exist_ok=True)
112 |
113 | '''LOG'''
114 | args = parse_args()
115 | logger = logging.getLogger("Model")
116 | logger.setLevel(logging.INFO)
117 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
118 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
119 | file_handler.setLevel(logging.INFO)
120 | file_handler.setFormatter(formatter)
121 | logger.addHandler(file_handler)
122 | log_string('PARAMETER ...')
123 | log_string(args)
124 |
125 |
126 | TEST_DATASET_WHOLE_SCENE = S3DISDatasetWholeScene(root=args.datapath, block_points=args.num_point, split='test',
127 | stride=0.5, block_size=1.0, padding=0.001)
128 | log_string("The number of test data is: %d" %len(TEST_DATASET_WHOLE_SCENE))
129 |
130 | '''MODEL LOADING'''
131 | classifier = GACNet(args.num_class, graph_inf, net_inf).cuda()
132 | checkpoint = torch.load(args.test_model)
133 | classifier.load_state_dict(checkpoint)
134 |
135 | with torch.no_grad():
136 | scene_id = TEST_DATASET_WHOLE_SCENE.file_list
137 | scene_id = [x[:-5] for x in scene_id]
138 | num_batches = len(TEST_DATASET_WHOLE_SCENE)
139 |
140 | total_seen_class = [0 for _ in range(args.num_class)]
141 | total_correct_class = [0 for _ in range(args.num_class)]
142 | total_iou_deno_class = [0 for _ in range(args.num_class)]
143 |
144 | total_pre_class = [0 for _ in range(args.num_class)]
145 | precision_per_class = [0 for _ in range(args.num_class)]
146 | recall_per_class = [0 for _ in range(args.num_class)]
147 |
148 | log_string('---- EVALUATION WHOLE SCENE----')
149 |
150 | for batch_idx in range(num_batches):
151 | print("visualize [%d/%d] %s ..." % (batch_idx+1, num_batches, scene_id[batch_idx]))
152 | total_seen_class_tmp = [0 for _ in range(args.num_class)]
153 | total_correct_class_tmp = [0 for _ in range(args.num_class)]
154 | total_iou_deno_class_tmp = [0 for _ in range(args.num_class)]
155 |
156 | total_pre_class_tmp = [0 for _ in range(args.num_class)]
157 | precision_per_class_tmp = [0 for _ in range(args.num_class)]
158 | recall_per_class_tmp = [0 for _ in range(args.num_class)]
159 | if args.visual:
160 | fout = open(os.path.join(visual_dir, scene_id[batch_idx] + '_pred.obj'), 'w')
161 | fout_gt = open(os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj'), 'w')
162 |
163 | whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[batch_idx]
164 | whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx]
165 | vote_label_pool = np.zeros((whole_scene_label.shape[0], args.num_class))#[num_points,num_classes]
166 | for _ in tqdm(range(args.num_votes), total=args.num_votes):
167 | scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx]
168 |                 num_blocks = scene_data.shape[0]  # how many blocks of num_point points the scene was split into
169 |                 s_batch_num = (num_blocks + args.batch_size - 1) // args.batch_size  # how many batches those blocks fill
170 | batch_data = np.zeros((args.batch_size, args.num_point, 9))
171 |
172 | batch_label = np.zeros((args.batch_size, args.num_point))
173 | batch_point_index = np.zeros((args.batch_size, args.num_point))
174 | batch_smpw = np.zeros((args.batch_size, args.num_point))
175 | for sbatch in range(s_batch_num):
176 | start_idx = sbatch * args.batch_size
177 | end_idx = min((sbatch + 1) * args.batch_size, num_blocks)
178 | real_batch_size = end_idx - start_idx
179 | batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...]
180 | batch_label[0:real_batch_size, ...] = scene_label[start_idx:end_idx, ...]
181 | batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...]
182 | batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...]
183 |
184 | torch_data = torch.Tensor(batch_data)
185 | torch_data= torch_data.float().cuda()
186 |
187 | classifier = classifier.eval()
188 | graph_prd, coarse_map = build_graph_pyramid(torch_data[:, :, 0:3], graph_inf)
189 | seg_pred = classifier(torch_data[:, :, :6], graph_prd, coarse_map)
190 |
191 | batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy()
192 |
193 | vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...],
194 | batch_pred_label[0:real_batch_size, ...],
195 | batch_smpw[0:real_batch_size, ...])
196 |
197 | pred_label = np.argmax(vote_label_pool, 1)
198 |
199 | for l in range(args.num_class):
200 | total_seen_class_tmp[l] += np.sum((whole_scene_label == l))
201 | total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l))
202 | total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l)))
203 |
204 | total_pre_class[l] += np.sum((pred_label == l))
205 |
206 | total_seen_class[l] += total_seen_class_tmp[l]
207 | total_correct_class[l] += total_correct_class_tmp[l]
208 | total_iou_deno_class[l] += total_iou_deno_class_tmp[l]
209 |
210 |             iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=np.float64) + 1e-6)
211 | print(iou_map)
212 | arr = np.array(total_seen_class_tmp)
213 | tmp_iou = np.mean(iou_map[arr != 0])
214 | log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou))
215 | # print('----------------------------')
216 | #
217 | # filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt')
218 | # with open(filename, 'w') as pl_save:
219 | # for i in pred_label:
220 | # pl_save.write(str(int(i)) + '\n')
221 | # pl_save.close()
222 | # for i in range(whole_scene_label.shape[0]):
223 | # color = g_label2color[pred_label[i]]
224 | # color_gt = g_label2color[whole_scene_label[i]]
225 | # if args.visual:
226 | # fout.write('v %f %f %f %d %d %d\n' % (
227 | # whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color[0], color[1],
228 | # color[2]))
229 | # fout_gt.write(
230 | # 'v %f %f %f %d %d %d\n' % (
231 | # whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color_gt[0],
232 | # color_gt[1], color_gt[2]))
233 | # if args.visual:
234 | # fout.close()
235 | # fout_gt.close()
236 |
237 |         IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6)
238 | iou_per_class_str = '------- IoU --------\n'
239 | F1_per_class_str = '------- F1 --------\n'
240 | for l in range(args.num_class):
241 |             precision_per_class[l] = total_correct_class[l] / (total_pre_class[l] + 1e-6)
242 |             recall_per_class[l] = total_correct_class[l] / (total_seen_class[l] + 1e-6)
243 | iou_per_class_str += 'class %s, IoU: %.3f \n' % (
244 | seg_label_to_cat[l] + ' ' * (10 - len(seg_label_to_cat[l])),
245 | total_correct_class[l] / float(total_iou_deno_class[l]))
246 | F1_per_class_str += 'class %s ,F1: %.3f \n' % (
247 | seg_label_to_cat[l] + ' ' * (10 - len(seg_label_to_cat[l])),
248 | 2 * (precision_per_class[l] * recall_per_class[l]) / (
249 | precision_per_class[l] + recall_per_class[l] + 1e-6))
250 | log_string(iou_per_class_str)
251 | log_string(F1_per_class_str)
252 | log_string('eval point avg class IoU: %f' % np.mean(IoU))
253 | log_string('eval whole scene point avg class acc: %f' % (
254 |             np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
255 | log_string('eval whole scene point accuracy: %f' % (
256 | np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6)))
257 |
258 | print("Done!")
259 |
260 | if __name__ == '__main__':
261 | args = parse_args()
262 | main(args)
263 |
--------------------------------------------------------------------------------
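The voting scheme in `add_vote` is easy to see on a toy example: each pass adds one vote per (point, predicted class), and the final label is the row-wise argmax of the pool (a standalone sketch, not repo code):

```
import numpy as np

vote_label_pool = np.zeros((4, 3))         # 4 points, 3 classes
for pred in ([0, 1, 1, 2], [0, 1, 2, 2]):  # two voting rounds
    for point, cls in enumerate(pred):
        vote_label_pool[point, cls] += 1
print(np.argmax(vote_label_pool, axis=1))  # [0 1 1 2] (ties break to the lower class)
```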
/tool/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import logging
5 | import warnings
6 | import datetime
7 | import argparse
8 | import torch.utils.data
9 | import torch.nn.parallel
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import torch.nn.functional as F
14 |
15 | from tqdm import tqdm
16 | from pathlib import Path
17 | from collections import defaultdict
18 | from torch.autograd import Variable
19 |
20 | from util import transform
21 | from data.S3DIS.S3DISDataLoader import S3DISDataset
22 | from model.gacnet import GACNet, build_graph_pyramid
23 |
24 |
25 |
26 |
27 | warnings.filterwarnings("ignore")
28 |
29 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter']
30 | class2label = {cls: i for i,cls in enumerate(classes)}
31 | seg_classes = class2label
32 | seg_label_to_cat = {}
33 | for i,cat in enumerate(seg_classes.keys()):
34 | seg_label_to_cat[i] = cat
35 |
36 | graph_inf = {'stride_list': [1024, 256, 64, 32], # number of points kept at each graph pyramid level
37 |              'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # neighbor search radius at each level
38 |              'maxsample_list': [12, 21, 21, 21, 12] # max number of neighbors sampled per point at each level
39 | }
40 |
41 | # number of units for each mlp layer
42 | forward_parm = [
43 | [ [32,32,64], [64] ],
44 | [ [64,64,128], [128] ],
45 | [ [128,128,256], [256] ],
46 | [ [256,256,512], [512] ],
47 | [ [256,256], [256] ]
48 | ]
49 |
50 | # for feature interpolation stage
51 | upsample_parm = [
52 | [128, 128],
53 | [128, 128],
54 | [256, 256],
55 | [256, 256]
56 | ]
57 |
58 | # number of units in the fully connected layer
59 | fullconect_parm = 128
60 |
61 | net_inf = {'forward_parm': forward_parm,
62 | 'upsample_parm': upsample_parm,
63 | 'fullconect_parm': fullconect_parm
64 | }
65 |
66 |
67 |
68 | def parse_args():
69 | parser = argparse.ArgumentParser('GACNet')
70 | parser.add_argument('--gpu', type=str, default='1', help='specify gpu device')
71 | parser.add_argument('--multi_gpu', type=str, default=None, help='whether use multi gpu training')
72 |
73 |     parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training [default: 100]')
74 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers [default: 4]')
75 |     parser.add_argument('--batchSize', type=int, default=2, help='input batch size [default: 2]')
76 |
77 | parser.add_argument('--alpha', type=float, default=0.2, help='alpha for leakyRelu [default: 0.2]')
78 | parser.add_argument('--dropout', type=float, default=0, help='dropout [default: 0]')
79 |
80 | parser.add_argument('--optimizer', type=str, default='SGD', help='type of optimizer')
81 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay for Adam')
82 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate for training [default: 0.001 for Adam, 0.01 for SGD]')
83 |
84 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model')
85 |     parser.add_argument('--log_dir', type=str, default='logs_gacnet/', help='directory name for training logs')
86 |
87 | parser.add_argument('--datapath', type=str, default='/workspace/VisualPointCloud/testgacnet', help='path of the dataset')
88 | parser.add_argument('--numpoint', type=int, default=4096, help='number of point for input [default: 4096]')
89 | parser.add_argument('--numclass', type=int, default=13, help='class number of the dataset [default: 13]')
90 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
91 | return parser.parse_args()
92 |
93 | def main(args):
94 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
95 | '''CREATE DIR'''
96 | experiment_dir = Path('./experiment/')
97 | experiment_dir.mkdir(exist_ok=True)
98 | file_dir = Path(str(experiment_dir) +'/'+ str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + '_' + args.log_dir))
99 | file_dir.mkdir(exist_ok=True)
100 | checkpoints_dir = file_dir.joinpath('checkpoints/')
101 | checkpoints_dir.mkdir(exist_ok=True)
102 | log_dir = file_dir.joinpath(args.log_dir)
103 | log_dir.mkdir(exist_ok=True)
104 |
105 | '''LOG'''
106 | args = parse_args()
107 | logger = logging.getLogger('GACNet')
108 | logger.setLevel(logging.INFO)
109 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
110 | file_handler = logging.FileHandler(str(log_dir) + '/train_gacnet.txt')
111 | file_handler.setLevel(logging.INFO)
112 | file_handler.setFormatter(formatter)
113 | logger.addHandler(file_handler)
114 | logger.info('PARAMETER ...')
115 | logger.info(args)
116 | print('Load data...')
117 |
118 | train_transform = transform.Compose([transform.ToTensor()])
119 | print("start loading training data ...")
120 | TRAIN_DATASET = S3DISDataset(split='train', data_root=args.datapath, num_point=args.numpoint, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
121 | dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchSize, shuffle=True, num_workers=args.workers,
122 | pin_memory=True, drop_last=True, worker_init_fn = lambda x: np.random.seed(x+int(time.time())))
123 |
124 | val_transform = transform.Compose([transform.ToTensor()])
125 | print("start loading test data ...")
126 | TEST_DATASET = S3DISDataset(split='test', data_root=args.datapath, num_point=args.numpoint, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
127 | testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
128 |
129 | weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda()
130 |
131 | blue = lambda x: '\033[94m' + x + '\033[0m'
132 | model = GACNet(args.numclass, graph_inf, net_inf)
133 |
134 | if args.pretrain is not None:
135 | model.load_state_dict(torch.load(args.pretrain))
136 | print('load model %s'%args.pretrain)
137 | logger.info('load model %s'%args.pretrain)
138 | else:
139 | print('Training from scratch')
140 | logger.info('Training from scratch')
141 | pretrain = args.pretrain
142 | init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0
143 |
144 | def adjust_learning_rate(optimizer, step):
145 | """Sets the learning rate to the initial LR decayed by 30 every 20000 steps"""
146 | lr = args.learning_rate * (0.3 ** (step // 20000))
147 | for param_group in optimizer.param_groups:
148 | param_group['lr'] = lr
149 |
150 | if args.optimizer == 'SGD':
151 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
152 | elif args.optimizer == 'Adam':
153 | optimizer = torch.optim.Adam(
154 | model.parameters(),
155 | lr=args.learning_rate,
156 | betas=(0.9, 0.999),
157 | eps=1e-08,
158 | weight_decay=args.decay_rate
159 | )
160 |
161 | '''GPU selection and multi-GPU'''
162 | if args.multi_gpu is not None:
163 | device_ids = [int(x) for x in args.multi_gpu.split(',')]
164 | torch.backends.cudnn.benchmark = True
165 | model.cuda(device_ids[0])
166 | model = torch.nn.DataParallel(model, device_ids=device_ids)
167 | else:
168 | model.cuda()
169 |
170 | history = defaultdict(lambda: list())
171 | best_acc = 0
172 | best_meaniou = 0
173 | step = 0
174 |
175 | for epoch in range(init_epoch,args.epoch):
176 | for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9):
177 | points, target = data
178 | points, target = Variable(points.float()), Variable(target.long())
179 | points, target = points.cuda(), target.cuda()
180 | optimizer.zero_grad()
181 | model = model.train()
182 |
183 | graph_prd, coarse_map = build_graph_pyramid(points[:, :, 0:3], graph_inf)
184 | pred = model(points[:, :, :6], graph_prd, coarse_map)
185 |
186 | pred = pred.contiguous().view(-1, args.numclass)
187 | target = target.view(-1, 1)[:, 0]
188 | loss = F.nll_loss(pred, target, weight=weights)
189 | history['loss'].append(loss.cpu().data.numpy())
190 | loss.backward()
191 | optimizer.step()
192 | step += 1
193 | adjust_learning_rate(optimizer, step)
194 |
195 | test_metrics, test_hist_acc, cat_mean_iou = test_seg(model, testdataloader, seg_label_to_cat)
196 | mean_iou = np.mean(cat_mean_iou)
197 |
198 | print('Epoch %d %s accuracy: %f meanIOU: %f' % (
199 | epoch, blue('test'), test_metrics['accuracy'],mean_iou))
200 | logger.info('Epoch %d %s accuracy: %f meanIOU: %f' % (
201 | epoch, 'test', test_metrics['accuracy'],mean_iou))
202 | if test_metrics['accuracy'] > best_acc:
203 | best_acc = test_metrics['accuracy']
204 | torch.save(model.state_dict(), '%s/GACNet_%.3d_%.4f_%.4f.pth' % (checkpoints_dir, epoch, best_acc, best_meaniou))
205 | logger.info(cat_mean_iou)
206 | logger.info('Save model..')
207 | print('Save model..')
208 | print(cat_mean_iou)
209 | if mean_iou > best_meaniou:
210 | best_meaniou = mean_iou
211 | torch.save(model.state_dict(), '%s/GACNet_%.3d_%.4f_%.4f.pth' % (checkpoints_dir, epoch, best_acc, best_meaniou))
212 | logger.info(cat_mean_iou)
213 | logger.info('Save model..')
214 | print('Save model..')
215 | print(cat_mean_iou)
216 | print('Best accuracy is: %.5f'%best_acc)
217 | logger.info('Best accuracy is: %.5f'%best_acc)
218 | print('Best meanIOU is: %.5f'%best_meaniou)
219 | logger.info('Best meanIOU is: %.5f'%best_meaniou)
220 |
221 | def test_seg(model, loader, catdict, num_classes = 13):
222 | iou_tabel = np.zeros((len(catdict),3))
223 | metrics = defaultdict(lambda:list())
224 | hist_acc = []
225 | for batch_id, (points, target) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9):
226 | batchsize, num_point, _ = points.size()
227 |
228 | points, target = Variable(points.float()), Variable(target.long())
229 |
230 | points, target = points.cuda(), target.cuda()
231 |
232 | graph_prd, coarse_map = build_graph_pyramid(points[:, :, 0:3], graph_inf)
233 | pred = model(points[:, :, :6], graph_prd, coarse_map)
234 |
235 | mean_iou, iou_tabel = compute_iou(pred, target, iou_tabel)
236 | pred = pred.contiguous().view(-1, num_classes)
237 | target = target.view(-1, 1)[:, 0]
238 | pred_choice = pred.data.max(1)[1]
239 | correct = pred_choice.eq(target.data).cpu().sum()
240 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point))
241 | metrics['iou'].append(mean_iou)
242 | iou_tabel[:,2] = iou_tabel[:,0] /(iou_tabel[:,1]+0.01)
243 | hist_acc += metrics['accuracy']
244 | metrics['accuracy'] = np.mean(metrics['accuracy'])
245 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou'])
246 | iou_tabel['Category_IOU'] = [cat_value for cat_value in catdict.values()]
247 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean()
248 |
249 | return metrics, hist_acc, cat_iou
250 |
251 | def compute_iou(pred,target,iou_tabel=None):
252 | ious = []
253 | target = target.cpu().data.numpy()
254 | for j in range(pred.size(0)):
255 | batch_pred = pred[j]
256 | batch_target = target[j]
257 | batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy()
258 | for cat in np.unique(batch_target):
259 | intersection = np.sum((batch_target == cat) & (batch_choice == cat))
260 | union = float(np.sum((batch_target == cat) | (batch_choice == cat)))
261 | iou = intersection/union
262 | ious.append(iou)
263 | iou_tabel[cat,0] += iou
264 | iou_tabel[cat,1] += 1
265 | return np.mean(ious), iou_tabel
266 |
267 | if __name__ == '__main__':
268 | args = parse_args()
269 | main(args)
270 |
--------------------------------------------------------------------------------
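A worked example of the per-class IoU that `compute_iou` accumulates: for each class present in the target, IoU = |pred ∩ target| / |pred ∪ target| (a standalone sketch):

```
import numpy as np

target = np.array([0, 0, 1, 1, 2])
pred   = np.array([0, 1, 1, 1, 2])
for cat in np.unique(target):
    inter = np.sum((target == cat) & (pred == cat))
    union = np.sum((target == cat) | (pred == cat))
    print(cat, inter / union)  # class 0: 1/2, class 1: 2/3, class 2: 1/1
```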
/util/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import h5py
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def make_dataset(split='train', data_root=None, data_list=None):
9 | if not os.path.isfile(data_list):
10 | raise (RuntimeError("Point list file do not exist: " + data_list + "\n"))
11 | point_list = []
12 | list_read = open(data_list).readlines()
13 | print("Totally {} samples in {} set.".format(len(list_read), split))
14 | for line in list_read:
15 | point_list.append(os.path.join(data_root, line.strip()))
16 | return point_list
17 |
18 |
19 | class PointData(Dataset):
20 | def __init__(self, split='train', data_root=None, data_list=None, transform=None, num_point=None, random_index=False):
21 | assert split in ['train', 'val', 'test']
22 | self.split = split
23 | self.data_list = make_dataset(split, data_root, data_list)
24 | self.transform = transform
25 | self.num_point = num_point
26 | self.random_index = random_index
27 |
28 | def __len__(self):
29 | return len(self.data_list)
30 |
31 | def __getitem__(self, index):
32 | data_path = self.data_list[index]
33 | f = h5py.File(data_path, 'r')
34 | data = f['data'][:]
35 |         if self.split == 'test':
36 |             label = 255  # placeholder
37 | else:
38 | label = f['label'][:]
39 | f.close()
40 | if self.num_point is None:
41 | self.num_point = data.shape[0]
42 | idxs = np.arange(data.shape[0])
43 | if self.random_index:
44 | np.random.shuffle(idxs)
45 | idxs = idxs[0:self.num_point]
46 | data = data[idxs, :]
47 | if label.size != 1: # seg data
48 | label = label[idxs]
49 | if self.transform is not None:
50 | data, label = self.transform(data, label)
51 | return data, label
52 |
53 |
54 | if __name__ == '__main__':
55 | data_root = '/mnt/sda1/hszhao/dataset/3d/s3dis'
56 | data_list = '/mnt/sda1/hszhao/dataset/3d/s3dis/list/train12346.txt'
57 | point_data = PointData('train', data_root, data_list)
58 | print('point data size:', point_data.__len__())
59 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape)
60 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape)
61 |
--------------------------------------------------------------------------------
/util/ply.py:
--------------------------------------------------------------------------------
1 | # Basic libs
2 | import numpy as np
3 | import sys
4 |
5 |
6 | # Define PLY types
7 | ply_dtypes = dict([
8 | (b'int8', 'i1'),
9 | (b'char', 'i1'),
10 | (b'uint8', 'u1'),
11 | (b'uchar', 'u1'),
12 | (b'int16', 'i2'),
13 | (b'short', 'i2'),
14 | (b'uint16', 'u2'),
15 | (b'ushort', 'u2'),
16 | (b'int32', 'i4'),
17 | (b'int', 'i4'),
18 | (b'uint32', 'u4'),
19 | (b'uint', 'u4'),
20 | (b'float32', 'f4'),
21 | (b'float', 'f4'),
22 | (b'float64', 'f8'),
23 | (b'double', 'f8')
24 | ])
25 |
26 | # Numpy reader format
27 | valid_formats = {'ascii': '', 'binary_big_endian': '>',
28 | 'binary_little_endian': '<'}
29 |
30 |
31 | # ----------------------------------------------------------------------------------------------------------------------
32 | #
33 | # Functions
34 | # \***************/
35 | #
36 |
37 |
38 | def parse_header(plyfile, ext):
39 | # Variables
40 | line = []
41 | properties = []
42 | num_points = None
43 |
44 | while b'end_header' not in line and line != b'':
45 | line = plyfile.readline()
46 |
47 | if b'element' in line:
48 | line = line.split()
49 | num_points = int(line[2])
50 |
51 | elif b'property' in line:
52 | line = line.split()
53 | properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
54 |
55 | return num_points, properties
56 |
57 |
58 | def parse_mesh_header(plyfile, ext):
59 | # Variables
60 | line = []
61 | vertex_properties = []
62 | num_points = None
63 | num_faces = None
64 | current_element = None
65 |
66 |
67 | while b'end_header' not in line and line != b'':
68 | line = plyfile.readline()
69 |
70 | # Find point element
71 | if b'element vertex' in line:
72 | current_element = 'vertex'
73 | line = line.split()
74 | num_points = int(line[2])
75 |
76 | elif b'element face' in line:
77 | current_element = 'face'
78 | line = line.split()
79 | num_faces = int(line[2])
80 |
81 | elif b'property' in line:
82 | if current_element == 'vertex':
83 | line = line.split()
84 | vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
85 |             elif current_element == 'face':
86 |                 if not line.startswith(b'property list uchar int'):
87 |                     raise ValueError('Unsupported faces property : ' + str(line))
88 |
89 | return num_points, num_faces, vertex_properties
90 |
91 |
92 | def read_ply(filename, triangular_mesh=False):
93 | """
94 | Read ".ply" files
95 |
96 | Parameters
97 | ----------
98 | filename : string
99 | the name of the file to read.
100 |
101 | Returns
102 | -------
103 | result : array
104 | data stored in the file
105 |
106 | Examples
107 | --------
108 | Store data in file
109 |
110 | >>> points = np.random.rand(5, 3)
111 |     >>> values = np.random.randint(2, size=5)
112 | >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])
113 |
114 | Read the file
115 |
116 | >>> data = read_ply('example.ply')
117 | >>> values = data['values']
118 | array([0, 0, 1, 1, 0])
119 |
120 | >>> points = np.vstack((data['x'], data['y'], data['z'])).T
121 | array([[ 0.466 0.595 0.324]
122 | [ 0.538 0.407 0.654]
123 | [ 0.850 0.018 0.988]
124 | [ 0.395 0.394 0.363]
125 | [ 0.873 0.996 0.092]])
126 |
127 | """
128 |
129 | with open(filename, 'rb') as plyfile:
130 |
131 |
132 | # Check if the file start with ply
133 | if b'ply' not in plyfile.readline():
134 |             raise ValueError('The file does not start with the word ply')
135 |
136 | # get binary_little/big or ascii
137 | fmt = plyfile.readline().split()[1].decode()
138 | if fmt == "ascii":
139 | raise ValueError('The file is not binary')
140 |
141 | # get extension for building the numpy dtypes
142 | ext = valid_formats[fmt]
143 |
144 | # PointCloud reader vs mesh reader
145 | if triangular_mesh:
146 |
147 | # Parse header
148 | num_points, num_faces, properties = parse_mesh_header(plyfile, ext)
149 |
150 | # Get point data
151 | vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)
152 |
153 | # Get face data
154 | face_properties = [('k', ext + 'u1'),
155 | ('v1', ext + 'i4'),
156 | ('v2', ext + 'i4'),
157 | ('v3', ext + 'i4')]
158 | faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)
159 |
160 | # Return vertex data and concatenated faces
161 | faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T
162 | data = [vertex_data, faces]
163 |
164 | else:
165 |
166 | # Parse header
167 | num_points, properties = parse_header(plyfile, ext)
168 |
169 | # Get data
170 | data = np.fromfile(plyfile, dtype=properties, count=num_points)
171 |
172 | return data
173 |
174 |
175 | def header_properties(field_list, field_names):
176 |
177 | # List of lines to write
178 | lines = []
179 |
180 | # First line describing element vertex
181 | lines.append('element vertex %d' % field_list[0].shape[0])
182 |
183 | # Properties lines
184 | i = 0
185 | for fields in field_list:
186 | for field in fields.T:
187 | lines.append('property %s %s' % (field.dtype.name, field_names[i]))
188 | i += 1
189 |
190 | return lines
191 |
192 |
193 | def write_ply(filename, field_list, field_names, triangular_faces=None):
194 | """
195 | Write ".ply" files
196 |
197 | Parameters
198 | ----------
199 | filename : string
200 | the name of the file to which the data is saved. A '.ply' extension will be appended to the
201 |         file name if it does not already have one.
202 |
203 | field_list : list, tuple, numpy array
204 | the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a
205 | tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered
206 | as one field.
207 |
208 | field_names : list
209 |         the name of each field, as a list of strings. Has to be the same length as the number of
210 | fields.
211 |
212 | Examples
213 | --------
214 | >>> points = np.random.rand(10, 3)
215 | >>> write_ply('example1.ply', points, ['x', 'y', 'z'])
216 |
217 | >>> values = np.random.randint(2, size=10)
218 | >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])
219 |
220 | >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)
221 |     >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
222 | >>> write_ply('example3.ply', [points, colors, values], field_names)
223 |
224 | """
225 |
226 | # Format list input to the right form
227 |     field_list = list(field_list) if isinstance(field_list, (list, tuple)) else [field_list]
228 | for i, field in enumerate(field_list):
229 | if field.ndim < 2:
230 | field_list[i] = field.reshape(-1, 1)
231 | if field.ndim > 2:
232 | print('fields have more than 2 dimensions')
233 | return False
234 |
235 | # check all fields have the same number of data
236 | n_points = [field.shape[0] for field in field_list]
237 | if not np.all(np.equal(n_points, n_points[0])):
238 | print('wrong field dimensions')
239 | return False
240 |
241 | # Check if field_names and field_list have same nb of column
242 | n_fields = np.sum([field.shape[1] for field in field_list])
243 | if (n_fields != len(field_names)):
244 | print('wrong number of field names')
245 | return False
246 |
247 | # Add extension if not there
248 | if not filename.endswith('.ply'):
249 | filename += '.ply'
250 |
251 | # open in text mode to write the header
252 | with open(filename, 'w') as plyfile:
253 |
254 | # First magical word
255 | header = ['ply']
256 |
257 | # Encoding format
258 | header.append('format binary_' + sys.byteorder + '_endian 1.0')
259 |
260 | # Points properties description
261 | header.extend(header_properties(field_list, field_names))
262 |
263 |         # Add faces if needed
264 | if triangular_faces is not None:
265 | header.append('element face {:d}'.format(triangular_faces.shape[0]))
266 | header.append('property list uchar int vertex_indices')
267 |
268 | # End of header
269 | header.append('end_header')
270 |
271 | # Write all lines
272 | for line in header:
273 | plyfile.write("%s\n" % line)
274 |
275 | # open in binary/append to use tofile
276 | with open(filename, 'ab') as plyfile:
277 |
278 | # Create a structured array
279 | i = 0
280 | type_list = []
281 | for fields in field_list:
282 | for field in fields.T:
283 | type_list += [(field_names[i], field.dtype.str)]
284 | i += 1
285 | data = np.empty(field_list[0].shape[0], dtype=type_list)
286 | i = 0
287 | for fields in field_list:
288 | for field in fields.T:
289 | data[field_names[i]] = field
290 | i += 1
291 |
292 | data.tofile(plyfile)
293 |
294 | if triangular_faces is not None:
295 | triangular_faces = triangular_faces.astype(np.int32)
296 | type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]
297 | data = np.empty(triangular_faces.shape[0], dtype=type_list)
298 | data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)
299 | data['0'] = triangular_faces[:, 0]
300 | data['1'] = triangular_faces[:, 1]
301 | data['2'] = triangular_faces[:, 2]
302 | data.tofile(plyfile)
303 |
304 | return True
305 |
306 |
307 | def describe_element(name, df):
308 | """ Takes the columns of the dataframe and builds a ply-like description
309 |
310 | Parameters
311 | ----------
312 | name: str
313 | df: pandas DataFrame
314 |
315 | Returns
316 | -------
317 | element: list[str]
318 | """
319 | property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}
320 | element = ['element ' + name + ' ' + str(len(df))]
321 |
322 | if name == 'face':
323 | element.append("property list uchar int points_indices")
324 |
325 | else:
326 | for i in range(len(df.columns)):
327 | # get first letter of dtype to infer format
328 | f = property_formats[str(df.dtypes[i])[0]]
329 | element.append('property ' + f + ' ' + df.columns.values[i])
330 |
331 | return element
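332 |
333 |
334 | # ----------------------------------------------------------------------------------------------------------------------
335 | #
336 | #           Usage sketch (illustrative addition, not part of the original module)
337 | #
338 |
339 | # For reference, parse_header consumes a binary header shaped like this:
340 | #
341 | #     ply
342 | #     format binary_little_endian 1.0
343 | #     element vertex 5
344 | #     property float32 x
345 | #     property float32 y
346 | #     property float32 z
347 | #     property int32 values
348 | #     end_header
349 |
350 | # A minimal write/read round trip; 'example.ply' is a hypothetical file name.
351 | if __name__ == '__main__':
352 |     points = np.random.rand(5, 3).astype(np.float32)
353 |     values = np.random.randint(2, size=5).astype(np.int32)
354 |     write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])
355 |     data = read_ply('example.ply')
356 |     print(np.vstack((data['x'], data['y'], data['z'])).T)
357 |     print(data['values'])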
--------------------------------------------------------------------------------
/util/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 |
6 | class Compose(object):
7 | def __init__(self, transforms):
8 | self.transforms = transforms
9 |
10 | def __call__(self, data, label):
11 | for t in self.transforms:
12 | data, label = t(data, label)
13 | return data, label
14 |
15 |
16 | class ToTensor(object):
17 | def __call__(self, data, label):
18 | data = torch.from_numpy(data)
19 | if not isinstance(data, torch.FloatTensor):
20 | data = data.float()
21 | label = torch.from_numpy(label)
22 | if not isinstance(label, torch.LongTensor):
23 | label = label.long()
24 | return data, label
25 |
26 |
27 | class RandomRotate(object):
28 | def __init__(self, rotate_angle=None, along_z=False):
29 | self.rotate_angle = rotate_angle
30 | self.along_z = along_z
31 |
32 | def __call__(self, data, label):
33 | if self.rotate_angle is None:
34 | rotate_angle = np.random.uniform() * 2 * np.pi
35 | else:
36 | rotate_angle = self.rotate_angle
37 | cosval, sinval = np.cos(rotate_angle), np.sin(rotate_angle)
38 | if self.along_z:
39 | rotation_matrix = np.array([[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]])
40 | else:
41 | rotation_matrix = np.array([[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]])
42 | data[:, 0:3] = np.dot(data[:, 0:3], rotation_matrix)
43 | if data.shape[1] > 3: # use normal
44 | data[:, 3:6] = np.dot(data[:, 3:6], rotation_matrix)
45 | return data, label
46 |
47 |
48 | class RandomRotatePerturbation(object):
49 | def __init__(self, angle_sigma=0.06, angle_clip=0.18):
50 | self.angle_sigma = angle_sigma
51 | self.angle_clip = angle_clip
52 |
53 | def __call__(self, data, label):
54 | angles = np.clip(self.angle_sigma*np.random.randn(3), -self.angle_clip, self.angle_clip)
55 | Rx = np.array([[1, 0, 0],
56 | [0, np.cos(angles[0]), -np.sin(angles[0])],
57 | [0, np.sin(angles[0]), np.cos(angles[0])]])
58 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
59 | [0, 1, 0],
60 | [-np.sin(angles[1]), 0, np.cos(angles[1])]])
61 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
62 | [np.sin(angles[2]), np.cos(angles[2]), 0],
63 | [0, 0, 1]])
64 | R = np.dot(Rz, np.dot(Ry, Rx))
65 | data[:, 0:3] = np.dot(data[:, 0:3], R)
66 | if data.shape[1] > 3: # use normal
67 | data[:, 3:6] = np.dot(data[:, 3:6], R)
68 | return data, label
69 |
70 |
71 | class RandomScale(object):
72 | def __init__(self, scale_low=0.8, scale_high=1.25):
73 | self.scale_low = scale_low
74 | self.scale_high = scale_high
75 |
76 | def __call__(self, data, label):
77 | scale = np.random.uniform(self.scale_low, self.scale_high)
78 | data[:, 0:3] *= scale
79 | return data, label
80 |
81 |
82 | class RandomShift(object):
83 | def __init__(self, shift_range=0.1):
84 | self.shift_range = shift_range
85 |
86 | def __call__(self, data, label):
87 | shift = np.random.uniform(-self.shift_range, self.shift_range, 3)
88 | data[:, 0:3] += shift
89 | return data, label
90 |
91 |
92 | class RandomJitter(object):
93 | def __init__(self, sigma=0.01, clip=0.05):
94 | self.sigma = sigma
95 | self.clip = clip
96 |
97 | def __call__(self, data, label):
98 | assert (self.clip > 0)
99 | jitter = np.clip(self.sigma * np.random.randn(data.shape[0], 3), -1 * self.clip, self.clip)
100 | data[:, 0:3] += jitter
101 | return data, label
102 |
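103 |
104 | # A minimal usage sketch (illustrative, not part of the original module): chain the
105 | # augmentations with Compose, as a point-cloud training pipeline typically would.
106 | # The N x 6 layout (xyz + normals) and the sizes below are assumptions for this demo.
107 | if __name__ == '__main__':
108 |     transform = Compose([RandomRotate(along_z=True),
109 |                          RandomScale(),
110 |                          RandomShift(),
111 |                          RandomJitter(),
112 |                          ToTensor()])
113 |     data = np.random.rand(4096, 6).astype(np.float32)
114 |     label = np.random.randint(13, size=4096)
115 |     data, label = transform(data, label)
116 |     print(data.shape, data.dtype, label.dtype)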
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn.modules.conv import _ConvNd
8 | from torch.nn.modules.batchnorm import _BatchNorm
9 | import torch.nn.init as initer
10 |
11 |
12 | class AverageMeter(object):
13 | """Computes and stores the average and current value"""
14 | def __init__(self):
15 | self.reset()
16 |
17 | def reset(self):
18 | self.val = 0
19 | self.avg = 0
20 | self.sum = 0
21 | self.count = 0
22 |
23 | def update(self, val, n=1):
24 | self.val = val
25 | self.sum += val * n
26 | self.count += n
27 | self.avg = self.sum / self.count
28 |
29 |
30 | def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1, clip=1e-6):
31 |     """Sets the learning rate to the base LR decayed by `multiplier` every `step_epoch` epochs (floored at `clip`)"""
32 | lr = max(base_lr * (multiplier ** (epoch // step_epoch)), clip)
33 | for param_group in optimizer.param_groups:
34 | param_group['lr'] = lr
35 |
36 |
37 | def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9):
38 | """poly learning rate policy"""
39 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
40 | for param_group in optimizer.param_groups:
41 | param_group['lr'] = lr
42 |
43 |
44 | def intersectionAndUnion(output, target, K, ignore_index=255):
45 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
46 | assert (output.ndim in [1, 2, 3])
47 | assert output.shape == target.shape
48 | output = output.reshape(output.size).copy()
49 | target = target.reshape(target.size)
50 |     output[np.where(target == ignore_index)[0]] = ignore_index
51 | intersection = output[np.where(output == target)[0]]
52 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1))
53 | area_output, _ = np.histogram(output, bins=np.arange(K+1))
54 | area_target, _ = np.histogram(target, bins=np.arange(K+1))
55 | area_union = area_output + area_target - area_intersection
56 | return area_intersection, area_union, area_target
57 |
58 |
59 | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
60 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
61 | assert (output.dim() in [1, 2, 3])
62 | assert output.shape == target.shape
63 | output = output.view(-1)
64 | target = target.view(-1)
65 | output[target == ignore_index] = ignore_index
66 | intersection = output[output == target]
67 | # https://github.com/pytorch/pytorch/issues/1382
68 | area_intersection = torch.histc(intersection.float().cpu(), bins=K, min=0, max=K-1)
69 | area_output = torch.histc(output.float().cpu(), bins=K, min=0, max=K-1)
70 | area_target = torch.histc(target.float().cpu(), bins=K, min=0, max=K-1)
71 | area_union = area_output + area_target - area_intersection
72 | return area_intersection.cuda(), area_union.cuda(), area_target.cuda()
73 |
74 |
75 | def check_mkdir(dir_name):
76 | if not os.path.exists(dir_name):
77 | os.mkdir(dir_name)
78 |
79 |
80 | def check_makedirs(dir_name):
81 | if not os.path.exists(dir_name):
82 | os.makedirs(dir_name)
83 |
84 |
85 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
86 | """
87 | :param model: Pytorch Model which is nn.Module
88 | :param conv: 'kaiming' or 'xavier'
89 | :param batchnorm: 'normal' or 'constant'
90 | :param linear: 'kaiming' or 'xavier'
91 | :param lstm: 'kaiming' or 'xavier'
92 | """
93 | for m in model.modules():
94 |         if isinstance(m, _ConvNd):
95 | if conv == 'kaiming':
96 | initer.kaiming_normal_(m.weight)
97 | elif conv == 'xavier':
98 | initer.xavier_normal_(m.weight)
99 | else:
100 | raise ValueError("init type of conv error.\n")
101 | if m.bias is not None:
102 | initer.constant_(m.bias, 0)
103 |
104 | elif isinstance(m, _BatchNorm):
105 | if batchnorm == 'normal':
106 | initer.normal_(m.weight, 1.0, 0.02)
107 | elif batchnorm == 'constant':
108 | initer.constant_(m.weight, 1.0)
109 | else:
110 | raise ValueError("init type of batchnorm error.\n")
111 | initer.constant_(m.bias, 0.0)
112 |
113 | elif isinstance(m, nn.Linear):
114 | if linear == 'kaiming':
115 | initer.kaiming_normal_(m.weight)
116 | elif linear == 'xavier':
117 | initer.xavier_normal_(m.weight)
118 | else:
119 | raise ValueError("init type of linear error.\n")
120 | if m.bias is not None:
121 | initer.constant_(m.bias, 0)
122 |
123 | elif isinstance(m, nn.LSTM):
124 | for name, param in m.named_parameters():
125 | if 'weight' in name:
126 | if lstm == 'kaiming':
127 | initer.kaiming_normal_(param)
128 | elif lstm == 'xavier':
129 | initer.xavier_normal_(param)
130 | else:
131 | raise ValueError("init type of lstm error.\n")
132 | elif 'bias' in name:
133 | initer.constant_(param, 0)
134 |
135 |
136 | def convert_to_syncbn(model):
137 | def recursive_set(cur_module, name, module):
138 | if len(name.split('.')) > 1:
139 | recursive_set(getattr(cur_module, name[:name.find('.')]), name[name.find('.')+1:], module)
140 | else:
141 | setattr(cur_module, name, module)
142 | from lib.sync_bn import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
143 | for name, m in model.named_modules():
144 | if isinstance(m, nn.BatchNorm1d):
145 | recursive_set(model, name, SynchronizedBatchNorm1d(m.num_features, m.eps, m.momentum, m.affine))
146 | elif isinstance(m, nn.BatchNorm2d):
147 | recursive_set(model, name, SynchronizedBatchNorm2d(m.num_features, m.eps, m.momentum, m.affine))
148 | elif isinstance(m, nn.BatchNorm3d):
149 | recursive_set(model, name, SynchronizedBatchNorm3d(m.num_features, m.eps, m.momentum, m.affine))
150 |
151 |
152 | def colorize(gray, palette):
153 |     # gray: 2-D numpy array of labels; palette: flat list of 3*N RGB values
154 | color = Image.fromarray(gray.astype(np.uint8)).convert('P')
155 | color.putpalette(palette)
156 | return color
157 |
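158 |
159 | # A minimal evaluation sketch (illustrative, not part of the original module):
160 | # accumulate per-class intersection and union over batches with AverageMeter,
161 | # then report mIoU. K=13 matches the S3DIS classes; the arrays are random stand-ins.
162 | if __name__ == '__main__':
163 |     K = 13
164 |     intersection_meter, union_meter = AverageMeter(), AverageMeter()
165 |     for _ in range(3):  # stand-in for a loop over validation batches
166 |         pred = np.random.randint(K, size=4096)
167 |         gt = np.random.randint(K, size=4096)
168 |         intersection, union, _ = intersectionAndUnion(pred, gt, K)
169 |         intersection_meter.update(intersection)
170 |         union_meter.update(union)
171 |     iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
172 |     print('mIoU: {:.4f}'.format(np.mean(iou_class)))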
--------------------------------------------------------------------------------