├── LICENSE ├── README.md ├── data └── S3DIS │ ├── S3DISDataLoader.py │ └── s3dis_names.txt ├── flowchart.jpg ├── model ├── gacnet.py └── gacnet_utils.py ├── pointnet2_ops_lib ├── MANIFEST.in ├── pointnet2_ops │ ├── __init__.py │ ├── _ext-src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── pointnet2_modules.py │ └── pointnet2_utils.py └── setup.py ├── tool ├── .DS_Store ├── test.py └── train.py └── util ├── dataset.py ├── ply.py ├── transform.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yongqiang Mao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GACNet_PyTorch 2 | This is the PyTorch implementation of GACNet on S3DIS. 3 |
4 | ![GACNet flowchart](flowchart.jpg) 5 |
6 | 7 | If you find this repository useful, please consider giving it a star :star:. 8 | ## Dependencies 9 | - Python 3.6 10 | - PyTorch 1.7 11 | - CUDA 11 12 | 13 | 14 | ## Install pointnet2-ops 15 | 16 | ``` 17 | cd pointnet2_ops_lib 18 | python setup.py install 19 | ``` 20 | 21 | ## Train Model 22 | ``` 23 | cd tool 24 | python train.py 25 | ``` 26 | 27 | ## Test Model 28 | ``` 29 | cd tool 30 | python test.py 31 | ``` 32 | 33 | ## References 34 | This repo is built on the TensorFlow implementation of [GACNet](https://github.com/wleigithub/GACNet). Thanks for their great work! 35 | -------------------------------------------------------------------------------- /data/S3DIS/S3DISDataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 5 | class2label = {cls: i for i,cls in enumerate(classes)} 6 | def pc_normalize(data, center): 7 | mindata = np.min(data[:, :3], axis=0) 8 | maxdata = np.max(data[:, :3], axis=0) 9 | center = (center - mindata) / (maxdata - mindata) 10 | data_xyz_norm = (data[:, :3] - mindata) / (maxdata - mindata) 11 | return data_xyz_norm, center 12 | 13 | class S3DISDataset(Dataset): 14 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, 15 | test_area=5, block_size=1.0, sample_rate=1.0, transform=None): 16 | super().__init__() 17 | self.num_point = num_point 18 | self.block_size = block_size 19 | self.transform = transform 20 | rooms = sorted(os.listdir(data_root)) 21 | rooms = [room for room in rooms if 'Area_' in room] 22 | if split == 'train': 23 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room] 24 | else: 25 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room] 26 | self.room_points, self.room_labels = [], [] 27 | self.room_coord_min, self.room_coord_max = [], [] 28 | num_point_all = [] 29 | labelweights = np.zeros(13) 30 | for room_name in rooms_split: 31 | room_path = os.path.join(data_root, room_name) 32 | room_data = np.load(room_path) # xyzrgbl, N*7 33 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N 34 | tmp, _ = np.histogram(labels, range(14)) 35 | labelweights += tmp 36 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 37 | 38 | self.room_points.append(points), self.room_labels.append(labels) 39 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 40 | num_point_all.append(labels.size) 41 | labelweights = labelweights.astype(np.float32) 42 | labelweights = labelweights / np.sum(labelweights) 43 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 44 | print(self.labelweights) 45 | 46 | 47 | 48 | 49 | sample_prob = num_point_all / np.sum(num_point_all) 50 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point) 51 | room_idxs = [] 52 | for index in range(len(rooms_split)): 53 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter))) 54 | self.room_idxs = np.array(room_idxs) 55 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split)) 56 | 57 | def __getitem__(self, idx): 58 | room_idx = self.room_idxs[idx] 59 | points = self.room_points[room_idx] # N * 6 60 | labels = self.room_labels[room_idx] # N 61 | N_points = points.shape[0] 62 | 63 | while (True): 64 | center = points[np.random.choice(N_points)][:3]
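# draw random block centers until the block around the center contains
# enough points (> 1024) to be worth cropping a training sample from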
65 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0] 66 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0] 67 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0] 68 | if point_idxs.size > 1024: 69 | break 70 | 71 | if point_idxs.size >= self.num_point: 72 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False) 73 | else: 74 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True) 75 | 76 | # normalize 77 | selected_points = points[selected_point_idxs, :] # num_point * 6 78 | current_points = np.zeros((self.num_point, 9)) # num_point * 9 79 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0] 80 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1] 81 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2] 82 | selected_points[:, 0] = selected_points[:, 0] - center[0] 83 | selected_points[:, 1] = selected_points[:, 1] - center[1] 84 | selected_points[:, 3:6] /= 255.0 85 | 86 | current_points[:, 0:6] = selected_points 87 | current_labels = labels[selected_point_idxs] 88 | if self.transform is not None: 89 | current_points, current_labels = self.transform(current_points, current_labels) 90 | return current_points, current_labels 91 | 92 | def __len__(self): 93 | return len(self.room_idxs) 94 | 95 | class S3DISDatasetWholeScene(): 96 | # prepares whole-scene blocks so that every point receives a prediction 97 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001): 98 | self.block_points = block_points 99 | self.block_size = block_size 100 | self.padding = padding 101 | self.root = root 102 | self.split = split 103 | self.stride = stride 104 | self.scene_points_num = [] 105 | assert split in ['train', 'test'] 106 | if self.split == 'train': 107 | self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) == -1] 108 | else: 109 | self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) != -1] 110 | self.scene_points_list = [] 111 | self.semantic_labels_list = [] 112 | self.room_coord_min, self.room_coord_max = [], [] 113 | for file in self.file_list: 114 | data = np.load(root + file) 115 | points = data[:, :3] 116 | self.scene_points_list.append(data[:, :6]) 117 | self.semantic_labels_list.append(data[:, 6]) 118 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 119 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 120 | assert len(self.scene_points_list) == len(self.semantic_labels_list) 121 | 122 | labelweights = np.zeros(13) 123 | for seg in self.semantic_labels_list: 124 | tmp, _ = np.histogram(seg, range(14)) 125 | self.scene_points_num.append(seg.shape[0]) 126 | labelweights += tmp 127 | labelweights = labelweights.astype(np.float32) 128 | labelweights = labelweights / np.sum(labelweights) 129 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 130 | 131 | def __getitem__(self, index): 132 | point_set_ini = self.scene_points_list[index] 133 | points = point_set_ini[:, :6] 134 | labels = self.semantic_labels_list[index] 135 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 136 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1) 137 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1)
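# whole-scene inference: slide a block_size x block_size window over the room
# with step `stride`; every point lands in at least one window, so each point
# gets at least one prediction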
138 | 139 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([]) 140 | for index_y in range(0, grid_y): 141 | for index_x in range(0, grid_x): 142 | s_x = coord_min[0] + index_x * self.stride 143 | e_x = min(s_x + self.block_size, coord_max[0]) 144 | s_x = e_x - self.block_size 145 | 146 | s_y = coord_min[1] + index_y * self.stride 147 | e_y = min(s_y + self.block_size, coord_max[1]) 148 | s_y = e_y - self.block_size 149 | 150 | point_idxs = np.where( 151 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & ( 152 | points[:, 1] <= e_y + self.padding))[0] 153 | if point_idxs.size == 0: 154 | continue 155 | num_batch = int(np.ceil(point_idxs.size / self.block_points)) 156 | point_size = int(num_batch * self.block_points) 157 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True 158 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace) 159 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat)) 160 | np.random.shuffle(point_idxs) 161 | 162 | data_batch = points[point_idxs, :] 163 | normalized_xyz = np.zeros((point_size, 3)) 164 | 165 | normalized_xyz[:, 0] = data_batch[:, 0] / coord_max[0] 166 | normalized_xyz[:, 1] = data_batch[:, 1] / coord_max[1] 167 | normalized_xyz[:, 2] = data_batch[:, 2] / coord_max[2] 168 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0) 169 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0) 170 | data_batch[:, 2] = data_batch[:, 2] # z is intentionally left uncentered; only x and y are re-centered on the block 171 | data_batch[:, 3:6] /= 255.0 172 | data_batch = np.concatenate((data_batch, normalized_xyz), axis=1) 173 | 174 | label_batch = labels[point_idxs].astype(int) 175 | batch_weight = self.labelweights[label_batch] 176 | 177 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch 178 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch 179 | sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight 180 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs 181 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1])) 182 | label_room = label_room.reshape((-1, self.block_points)) 183 | sample_weight = sample_weight.reshape((-1, self.block_points)) 184 | index_room = index_room.reshape((-1, self.block_points)) 185 | return data_room, label_room, sample_weight, index_room 186 | 187 | def __len__(self): 188 | return len(self.scene_points_list) 189 | 190 | if __name__ == '__main__': 191 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/' 192 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01 193 | 194 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None) 195 | print('point data size:', point_data.__len__()) 196 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape) 197 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape) 198 | import torch, time, random 199 | manual_seed = 123 200 | random.seed(manual_seed) 201 | np.random.seed(manual_seed) 202 | torch.manual_seed(manual_seed) 203 | torch.cuda.manual_seed_all(manual_seed) 204 | def worker_init_fn(worker_id): 205 | random.seed(manual_seed + worker_id)
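# each DataLoader worker is reseeded with manual_seed + worker_id so that
# augmentation randomness differs across workers but stays reproducible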
206 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn) 207 | for idx in range(4): 208 | end = time.time() 209 | for i, (input, target) in enumerate(train_loader): 210 | print('time: {}/{}--{}'.format(i+1, len(train_loader), time.time() - end)) 211 | end = time.time() -------------------------------------------------------------------------------- /data/S3DIS/s3dis_names.txt: -------------------------------------------------------------------------------- 1 | ceiling 2 | floor 3 | wall 4 | beam 5 | column 6 | window 7 | door 8 | chair 9 | table 10 | bookcase 11 | sofa 12 | board 13 | clutter 14 | -------------------------------------------------------------------------------- /flowchart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WingkeungM/GACNet_PyTorch/c5b3701e7029cf3c6a6ff2622255fc762de6690d/flowchart.jpg -------------------------------------------------------------------------------- /model/gacnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import pointnet2_ops_lib.pointnet2_ops.pointnet2_utils as pointnet2_utils 4 | 5 | from model.gacnet_utils import * 6 | 7 | 8 | graph_inf = {'stride_list': [1024, 256, 64, 32], # number of points kept at each downsampling level 9 | 'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # radius for neighbor search at each level 10 | 'maxsample_list': [12, 21, 21, 21, 12] # max number of neighbors per point at each level 11 | } 12 | 13 | # number of units for each mlp layer 14 | forward_parm = [ 15 | [ [32,32,64], [64] ], 16 | [ [64,64,128], [128] ], 17 | [ [128,128,256], [256] ], 18 | [ [256,256,512], [512] ], 19 | [ [256,256], [256] ] 20 | ] 21 | 22 | # for the feature interpolation stage 23 | upsample_parm = [ 24 | [128, 128], 25 | [128, 128], 26 | [256, 256], 27 | [256, 256] 28 | ] 29 | 30 | # width of the fully connected classification head 31 | fullconect_parm = 128 32 | 33 | net_inf = {'forward_parm': forward_parm, 34 | 'upsample_parm': upsample_parm, 35 | 'fullconect_parm': fullconect_parm 36 | } 37 | 38 | 39 | 40 | 
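# Worked example of the pyramid these settings produce (for a 4096-point
# input, as in the __main__ check in this file):
#   level 0: 4096 vertices, search radius 0.1, up to 12 neighbors
#   level 1: 1024 vertices, search radius 0.2, up to 21 neighbors
#   level 2:  256 vertices, search radius 0.4, up to 21 neighbors
#   level 3:   64 vertices, search radius 0.8, up to 21 neighbors
#   level 4:   32 vertices, search radius 1.6, up to 12 neighbors
# coarse_map then stores, for each of the 4 pooling steps, the neighbor ids
# used to max-pool features from the finer level onto the coarser one.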
41 | class GACNet(nn.Module): 42 | def __init__(self, num_classes, graph_inf, net_inf): 43 | super(GACNet, self).__init__() 44 | self.num_classes = num_classes 45 | self.forward_parm, self.upsample_parm, self.fullconect_parm = \ 46 | net_inf['forward_parm'], net_inf['upsample_parm'], net_inf['fullconect_parm'] 47 | 48 | self.stride_inf = graph_inf['stride_list'] 49 | 50 | self.graph_attention_layer1 = GraphAttentionConvLayer(4, self.forward_parm[0][0], self.forward_parm[0][1]) 51 | self.graph_attention_layer2 = GraphAttentionConvLayer(64, self.forward_parm[1][0], self.forward_parm[1][1]) 52 | self.graph_attention_layer3 = GraphAttentionConvLayer(128, self.forward_parm[2][0], self.forward_parm[2][1]) 53 | self.graph_attention_layer4 = GraphAttentionConvLayer(256, self.forward_parm[3][0], self.forward_parm[3][1]) 54 | 55 | self.graph_pooling_layer = GraphPoolingLayer() 56 | self.mid_graph_attention_layers = GraphAttentionConvLayer(512, self.forward_parm[-1][0], self.forward_parm[-1][1]) 57 | 58 | self.point_upsample_layer1 = PointUpsampleLayer(512 + 256, self.upsample_parm[3]) 59 | self.point_upsample_layer2 = PointUpsampleLayer(256 + 256, self.upsample_parm[2]) 60 | self.point_upsample_layer3 = PointUpsampleLayer(128 + 256, self.upsample_parm[1]) 61 | self.point_upsample_layer4 = PointUpsampleLayer(64 + 128, self.upsample_parm[0]) 62 | 63 | self.graph_attention_layer_for_featurerefine = GraphAttentionConvLayerforFeatureRefine(self.num_classes) 64 | 65 | self.conv1 = nn.Conv1d(128, 128, 1) 66 | self.bn1 = nn.BatchNorm1d(128) 67 | self.drop = nn.Dropout(0.5) 68 | self.conv2 = nn.Conv1d(128, self.num_classes, 1) 69 | 70 | def forward(self, features, graph_prd, coarse_map): 71 | inif = features[:, :, 0:6] # (x,y,z,r,g,b) 72 | features = features[:, :, 2:] # (z, r, g, b, plus any initial geometric features) 73 | 74 | feature_prd = [] 75 | 76 | features = self.graph_attention_layer1(graph_prd[0], features) 77 | feature_prd.append(features) 78 | features = self.graph_pooling_layer(features, coarse_map[0]) 79 | 80 | features = self.graph_attention_layer2(graph_prd[1], features) 81 | feature_prd.append(features) 82 | features = self.graph_pooling_layer(features, coarse_map[1]) 83 | 84 | features = self.graph_attention_layer3(graph_prd[2], features) 85 | feature_prd.append(features) 86 | features = self.graph_pooling_layer(features, coarse_map[2]) 87 | 88 | features = self.graph_attention_layer4(graph_prd[3], features) 89 | feature_prd.append(features) 90 | features = self.graph_pooling_layer(features, coarse_map[3]) 91 | 92 | features = self.mid_graph_attention_layers(graph_prd[-1], features) 93 | 94 | features = self.point_upsample_layer1(graph_prd[3]['vertex'], graph_prd[3 + 1]['vertex'], feature_prd[3], features) 95 | features = self.point_upsample_layer2(graph_prd[2]['vertex'], graph_prd[2 + 1]['vertex'], feature_prd[2], features) 96 | features = self.point_upsample_layer3(graph_prd[1]['vertex'], graph_prd[1 + 1]['vertex'], feature_prd[1], features) 97 | features = self.point_upsample_layer4(graph_prd[0]['vertex'], graph_prd[0 + 1]['vertex'], feature_prd[0], features) 98 | 99 | features = features.permute(0, 2, 1) 100 | features = F.relu(self.bn1(self.conv1(features))) 101 | features = self.drop(features) 102 | features = self.conv2(features) 103 | features = features.permute(0, 2, 1) 104 | 105 | features = self.graph_attention_layer_for_featurerefine(inif, features, graph_prd[0]['adjids']) 106 | features = F.log_softmax(features, dim=2) 107 | return features 108 | 109 | 
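# Channel bookkeeping for GACNet.forward, assuming a 6-dim input
# (x, y, z, r, g, b): features[:, :, 2:] yields the 4 channels (z, r, g, b)
# expected by graph_attention_layer1; the encoder then widens features
# 4 -> 64 -> 128 -> 256 -> 512 while the pyramid shrinks 4096 -> 32 vertices,
# and the four upsample layers interpolate back out to all input points.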
110 | def build_graph_pyramid(xyz, graph_inf): 111 | """ Builds a pyramid of graphs and pooling operations corresponding to a progressively coarsened point cloud. 112 | Inputs: 113 | xyz: (batchsize, num_point, nfeature) 114 | graph_inf: parameters for graph building (the graph_inf dict defined at the top of this file) 115 | Outputs: 116 | graph_prd: graph pyramid containing the vertices and their edges (adjacency ids) at each layer 117 | coarse_map: records the correspondence between adjacent graph layers (for graph coarsening/pooling) 118 | """ 119 | stride_list, radius_list, maxsample_list = graph_inf['stride_list'], graph_inf['radius_list'], graph_inf['maxsample_list'] 120 | 121 | graph_prd = [] 122 | graph = {} 123 | coarse_map = [] 124 | 125 | xyz = xyz.contiguous() 126 | ids = pointnet2_utils.ball_query(radius_list[0], maxsample_list[0], xyz, xyz) 127 | graph['vertex'], graph['adjids'] = xyz, ids 128 | graph_prd.append(graph.copy()) 129 | 130 | for stride, radius, maxsample in zip(stride_list, radius_list[1:], maxsample_list[1:]): 131 | xyz, coarse_map_ids = graph_coarse(xyz, ids, stride) 132 | coarse_map.append(coarse_map_ids.int()) 133 | ids = pointnet2_utils.ball_query(radius, maxsample, xyz, xyz) 134 | graph['vertex'], graph['adjids'] = xyz, ids 135 | graph_prd.append(graph.copy()) 136 | 137 | return graph_prd, coarse_map 138 | 139 | if __name__ == '__main__': 140 | import os 141 | import torch 142 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 143 | input = torch.randn((8,6,4096)).permute(0, 2, 1).cuda() 144 | graph_prd, coarse_map = build_graph_pyramid(input[:, :, :3], graph_inf) 145 | model = GACNet(50, graph_inf, net_inf).cuda() 146 | logits = model(input, graph_prd, coarse_map) 147 | 148 | -------------------------------------------------------------------------------- /model/gacnet_utils.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import pointnet2_ops_lib.pointnet2_ops.pointnet2_utils as pointnet2_utils 7 | 8 | 9 | gac_par = [ 10 | [32, 16], # MLP for xyz 11 | [16, 16], # MLP for feature 12 | [64] # hidden units of the MLP for merging 13 | ] 14 | 15 | 16 | class MLP1D(nn.Module): 17 | def __init__(self, inchannel, mlp_1d): 18 | super(MLP1D, self).__init__() 19 | self.mlp_conv1ds = nn.ModuleList() 20 | self.mlp_bn1ds = nn.ModuleList() 21 | self.mlp_1d = mlp_1d 22 | last_channel = inchannel 23 | for i, outchannel in enumerate(self.mlp_1d): 24 | self.mlp_conv1ds.append(nn.Conv1d(last_channel, outchannel, 1)) 25 | self.mlp_bn1ds.append(nn.BatchNorm1d(outchannel)) 26 | last_channel = outchannel 27 | 28 | def forward(self, features): 29 | features = features.permute(0, 2, 1) 30 | for i, conv in enumerate(self.mlp_conv1ds): 31 | bn = self.mlp_bn1ds[i] 32 | features = F.relu(bn(conv(features))) 33 | features = features.permute(0, 2, 1) 34 | return features 35 | 36 | class MLP2D(nn.Module): 37 | def __init__(self, inchannel, mlp_2d): 38 | super(MLP2D, self).__init__() 39 | self.mlp_conv2ds = nn.ModuleList() 40 | self.mlp_bn2ds = nn.ModuleList() 41 | self.mlp_2d = mlp_2d 42 | last_channel = inchannel 43 | for i, outchannel in enumerate(self.mlp_2d): 44 | self.mlp_conv2ds.append(nn.Conv2d(last_channel, outchannel, 1)) 45 | self.mlp_bn2ds.append(nn.BatchNorm2d(outchannel)) 46 | last_channel = outchannel 47 | 48 | def forward(self, features): 49 | features = features.permute(0, 3, 2, 1) 50 | for i, conv in enumerate(self.mlp_conv2ds): 51 | bn = self.mlp_bn2ds[i] 52 | features = F.relu(bn(conv(features))) 53 | features = features.permute(0, 3, 2, 1) 54 | return features 55 | 56 | 57 | 58 | class CoeffGeneration(nn.Module): 59 | def __init__(self, 
inchannel): 60 | super(CoeffGeneration, self).__init__() 61 | self.inchannel = inchannel 62 | self.MlP2D = MLP2D(self.inchannel, gac_par[1]) 63 | self.MlP2D_2 = MLP2D(32, gac_par[2]) 64 | self.conv1 = nn.Conv2d(64, 16 + self.inchannel, 1) 65 | self.bn1 = nn.BatchNorm2d(16 + self.inchannel) 66 | 67 | def forward(self, grouped_features, features, grouped_xyz, mode='with_feature'): 68 | if mode == 'with_feature': 69 | coeff = grouped_features - features.unsqueeze(dim=2) 70 | coeff = self.MlP2D(coeff) 71 | coeff = torch.cat((grouped_xyz, coeff), dim=-1) 72 | if mode == 'edge_only': 73 | coeff = grouped_xyz 74 | if mode == 'feature_only': 75 | coeff = grouped_features - features.unsqueeze(dim=2) 76 | coeff = self.MlP2D(coeff) 77 | 78 | grouped_features = torch.cat((grouped_xyz, grouped_features), dim=-1) 79 | coeff = self.MlP2D_2(coeff) 80 | coeff = coeff.permute(0, 3, 2, 1) 81 | coeff = self.bn1(self.conv1(coeff)) 82 | coeff = coeff.permute(0, 3, 2, 1) 83 | coeff = F.softmax(coeff, dim=2) 84 | 85 | grouped_features = coeff * grouped_features 86 | grouped_features = torch.sum(grouped_features, dim=2) 87 | return grouped_features 88 | 89 | class GraphAttentionConvLayer(nn.Module): 90 | def __init__(self, feature_inchannel, mlp1, mlp2): 91 | super(GraphAttentionConvLayer, self).__init__() 92 | self.mlp1 = mlp1 93 | self.mlp2 = mlp2 94 | self.feature_inchannel = feature_inchannel 95 | self.edge_mapping = MLP2D(3, gac_par[0]) 96 | self.MLP1D = MLP1D(self.feature_inchannel, self.mlp1) 97 | self.MLP1D_2 = MLP1D(2 * self.mlp1[-1] + 16, self.mlp2) 98 | self.coeff_generation = CoeffGeneration(self.mlp1[-1]) 99 | 100 | def forward(self, graph, features): 101 | xyz, ids = graph['vertex'], graph['adjids'] 102 | grouped_xyz = pointnet2_utils.grouping_operation(xyz.permute(0, 2, 1).contiguous(), ids).permute(0, 2, 3, 1).contiguous() 103 | grouped_xyz -= xyz.unsqueeze(dim=2) 104 | grouped_xyz = self.edge_mapping(grouped_xyz) 105 | 106 | features = self.MLP1D(features) 107 | grouped_features = pointnet2_utils.grouping_operation(features.permute(0, 2, 1).contiguous(), ids).permute(0, 2, 3, 1).contiguous() 108 | 109 | new_features = self.coeff_generation(grouped_features, features, grouped_xyz) 110 | if self.mlp2 is not None and features is not None: 111 | new_features = torch.cat((features, new_features), dim=-1) 112 | new_features = self.MLP1D_2(new_features) 113 | return new_features 114 | 115 | 116 | class GraphPoolingLayer(nn.Module): 117 | def __init__(self, pooling='max'): 118 | super(GraphPoolingLayer, self).__init__() 119 | self.pooling = pooling 120 | 121 | def forward(self, features, coarse_map): 122 | grouped_features = pointnet2_utils.grouping_operation(features.permute(0, 2, 1).contiguous(), coarse_map).permute(0, 2, 3, 1).contiguous() 123 | if self.pooling == 'max': 124 | new_features = torch.max(grouped_features, dim=2)[0] 125 | return new_features 126 | 127 | class PointUpsampleLayer(nn.Module): 128 | def __init__(self, inchannel, upsample_parm): 129 | super(PointUpsampleLayer, self).__init__() 130 | self.inchannel = inchannel 131 | self.upsample_list = upsample_parm 132 | self.MLP1D = MLP1D(self.inchannel, self.upsample_list) 133 | 134 | 135 | def forward(self, xyz1, xyz2, features1, features2): 136 | B, N, C = xyz1.shape 137 | _, S, _ = xyz2.shape 138 | 139 | features1 = features1.permute(0, 2, 1) 140 | features2 = features2.permute(0, 2, 1) 141 | assert xyz1.is_contiguous() 142 | assert xyz2.is_contiguous() 143 | dist, idx = pointnet2_utils.three_nn(xyz1, xyz2) 144 | dist_recip = 1.0 / (dist 
+ 1e-8) 145 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 146 | weight = dist_recip / norm 147 | assert features2.is_contiguous() 148 | interpolated_features = pointnet2_utils.three_interpolate(features2, idx, weight) 149 | 150 | new_features = torch.cat((interpolated_features, features1), dim=1).permute(0, 2, 1) 151 | 152 | if self.upsample_list is not None: 153 | new_features = self.MLP1D(new_features) 154 | return new_features 155 | 156 | class GraphAttentionConvLayerforFeatureRefine(nn.Module): 157 | def __init__(self, inchannel): 158 | super(GraphAttentionConvLayerforFeatureRefine, self).__init__() 159 | self.inchannel = inchannel 160 | self.edge_mapping = MLP2D(6, gac_par[0]) 161 | self.coeff_generation = CoeffGeneration(self.inchannel) 162 | self.conv1 = nn.Conv1d(16 + 2 * self.inchannel, self.inchannel, 1) 163 | 164 | def forward(self, initf, features, ids): 165 | initf = initf.permute(0, 2, 1).contiguous() 166 | grouped_initf = pointnet2_utils.grouping_operation(initf, ids).permute(0, 2, 3, 1).contiguous() 167 | grouped_initf -= initf.permute(0, 2, 1).unsqueeze(dim=2) 168 | grouped_initf = self.edge_mapping(grouped_initf) 169 | features = features.permute(0, 2, 1).contiguous() 170 | grouped_features = pointnet2_utils.grouping_operation(features, ids).permute(0, 2, 3, 1).contiguous() 171 | features = features.permute(0, 2, 1) 172 | 173 | new_features = self.coeff_generation(grouped_features, features, grouped_initf) 174 | 175 | new_features = torch.cat((features, new_features), dim=-1) 176 | new_features = new_features.permute(0, 2, 1) 177 | new_features = self.conv1(new_features) 178 | new_features = new_features.permute(0, 2, 1) 179 | 180 | return new_features 181 | 182 | 183 | 184 | def graph_coarse(xyz_org, ids_full, stride): 185 | """ Coarse graph with down sampling, and find their corresponding vertexes at previous (or father) level. 
""" 186 | if stride > 1: 187 | sub_pts_ids = pointnet2_utils.furthest_point_sample(xyz_org, stride) 188 | sub_xyz = pointnet2_utils.gather_operation(xyz_org.permute(0, 2, 1).contiguous(), sub_pts_ids).permute(0, 2, 1).contiguous() 189 | 190 | ids = pointnet2_utils.grouping_operation(ids_full.permute(0, 2, 1).float().contiguous(), 191 | sub_pts_ids.unsqueeze(dim=-1).contiguous()).long().squeeze(-1).permute(0, 2, 1).contiguous() # (batchsize, num_point, maxsample) 192 | 193 | return sub_xyz, ids 194 | else: 195 | return xyz_org, ids_full 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft pointnet2_ops/_ext-src 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | at::Tensor cube_select_sift(at::Tensor xyz, const float radius); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, 
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows); 7 | std::vector<at::Tensor> dist_nn(at::Tensor unknowns, at::Tensor knows); 8 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 9 | at::Tensor weight); 10 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 11 | at::Tensor weight, const int m); 12 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <ATen/cuda/CUDAContext.h> 3 | #include <torch/extension.h> 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | void cube_select_sift_kernel_wrapper(int b, int n, float radius, 8 | const float *xyz, int *idx_out); 9 | 10 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 11 | const int nsample) { 12 | CHECK_CONTIGUOUS(new_xyz); 13 | CHECK_CONTIGUOUS(xyz); 14 | CHECK_IS_FLOAT(new_xyz); 15 | CHECK_IS_FLOAT(xyz); 16 | 17 | if (new_xyz.is_cuda()) { 18 | CHECK_CUDA(xyz); 19 | } 20 | 21 | at::Tensor idx = 22 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 23 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 24 | 25 | if (new_xyz.is_cuda()) { 26 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 27 | radius, nsample, new_xyz.data_ptr<float>(), 28 | xyz.data_ptr<float>(), idx.data_ptr<int>()); 29 | } else { 30 | AT_ASSERT(false, "CPU not supported"); 31 | } 32 | 33 | return idx; 34 | } 35 | 36 | 37 | at::Tensor cube_select_sift(at::Tensor xyz, const float radius) { 38 | CHECK_CONTIGUOUS(xyz); 39 | CHECK_IS_FLOAT(xyz); 40 | 41 | if (xyz.is_cuda()) { 42 | CHECK_CUDA(xyz); 43 | } 44 | 45 | at::Tensor idx_out = 46 | torch::zeros({xyz.size(0), xyz.size(1), 8}, 47 | at::device(xyz.device()).dtype(at::ScalarType::Int)); 48 | 49 | if (xyz.is_cuda()) { 50 | cube_select_sift_kernel_wrapper(xyz.size(0), xyz.size(1), radius, 51 | xyz.data_ptr<float>(), idx_out.data_ptr<int>()); 52 | } else { 53 | AT_ASSERT(false, "CPU not supported"); 54 | } 55 | 56 | return idx_out; 57 | } 58 | 
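Once the extension is built (`cd pointnet2_ops_lib && python setup.py install`, as in the README), these kernels are reachable from Python as `pointnet2_ops.pointnet2_utils`. A minimal usage sketch — the random tensors are placeholders, and shapes follow the conventions in the kernel comments:

```
import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 4096, 3).cuda()     # (B, N, 3) point coordinates
feats = torch.rand(2, 16, 4096).cuda()  # (B, C, N) per-point features

# fixed-radius neighborhoods: (B, N, 12) neighbor indices into xyz
idx = pointnet2_utils.ball_query(0.1, 12, xyz, xyz)
grouped = pointnet2_utils.grouping_operation(feats, idx)  # (B, C, N, 12)

# farthest point sampling, then gather the sampled coordinates
sub_ids = pointnet2_utils.furthest_point_sample(xyz, 1024)  # (B, 1024)
sub_xyz = pointnet2_utils.gather_operation(
    xyz.permute(0, 2, 1).contiguous(), sub_ids).permute(0, 2, 1).contiguous()

# inverse-distance interpolation from the 1024 coarse points back to all
# 4096 points, as done in PointUpsampleLayer
coarse_feats = torch.rand(2, 16, 1024).cuda()  # (B, C, M) features on sub_xyz
dist, nn_idx = pointnet2_utils.three_nn(xyz, sub_xyz)  # both (B, N, 3)
weight = 1.0 / (dist + 1e-8)
weight = weight / weight.sum(dim=2, keepdim=True)
up = pointnet2_utils.three_interpolate(coarse_feats, nn_idx, weight)  # (B, C, N)
```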
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | 56 | 57 | 58 | __global__ void cube_select_sift_kernel(int b, int n,float radius, const float*__restrict__ xyz, int*__restrict__ idx_out) { 59 | int batch_idx = blockIdx.x; 60 | xyz += batch_idx * n * 3; 61 | idx_out += batch_idx * n * 8; 62 | float temp_dist[8]; 63 | float judge_dist = radius * radius; 64 | for(int i = threadIdx.x; i < n;i += blockDim.x) { 65 | float x = xyz[i * 3]; 66 | float y = xyz[i * 3 + 1]; 67 | float z = xyz[i * 3 + 2]; 68 | for(int j = 0;j < 8;j ++) { 69 | temp_dist[j] = 1e8; 70 | idx_out[i * 8 + j] = i; // if not found, just return itself.. 
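// (each of the 8 slots corresponds to one octant around point i; the loop
//  below picks, per octant, the nearest in-radius neighbor via the sign
//  pattern of (tx - x, ty - y, tz - z))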
71 | } 72 | for(int j = 0;j < n;j ++) { 73 | if(i == j) continue; 74 | float tx = xyz[j * 3]; 75 | float ty = xyz[j * 3 + 1]; 76 | float tz = xyz[j * 3 + 2]; 77 | float dist = (x - tx) * (x - tx) + (y - ty) * (y - ty) + (z - tz) * (z - tz); 78 | if(dist > judge_dist) continue; 79 | int _x = (tx > x); 80 | int _y = (ty > y); 81 | int _z = (tz > z); 82 | int temp_idx = _x * 4 + _y * 2 + _z; 83 | if(dist < temp_dist[temp_idx]) { 84 | idx_out[i * 8 + temp_idx] = j; 85 | temp_dist[temp_idx] = dist; 86 | } 87 | } 88 | } 89 | } 90 | 91 | void cube_select_sift_kernel_wrapper(int b, int n, float radius, 92 | const float *xyz, int *idx_out) { 93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 94 | cube_select_sift_kernel<<>>( 95 | b, n, radius, xyz, idx_out); 96 | 97 | CUDA_CHECK_ERRORS(); 98 | } 99 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("dist_nn", &dist_nn); 13 | m.def("three_interpolate", &three_interpolate); 14 | m.def("three_interpolate_grad", &three_interpolate_grad); 15 | 16 | m.def("ball_query", &ball_query); 17 | m.def("cube_select_sift", &cube_select_sift); 18 | 19 | m.def("group_points", &group_points); 20 | m.def("group_points_grad", &group_points_grad); 21 | } 22 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | 
group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void dist_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 7 | const 
float *known, float *dist2); 8 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 9 | const float *points, const int *idx, 10 | const float *weight, float *out); 11 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 12 | const float *grad_out, 13 | const int *idx, const float *weight, 14 | float *grad_points); 15 | 16 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 17 | CHECK_CONTIGUOUS(unknowns); 18 | CHECK_CONTIGUOUS(knows); 19 | CHECK_IS_FLOAT(unknowns); 20 | CHECK_IS_FLOAT(knows); 21 | 22 | if (unknowns.is_cuda()) { 23 | CHECK_CUDA(knows); 24 | } 25 | 26 | at::Tensor idx = 27 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 28 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 29 | at::Tensor dist2 = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 32 | 33 | if (unknowns.is_cuda()) { 34 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 35 | unknowns.data_ptr(), knows.data_ptr(), 36 | dist2.data_ptr(), idx.data_ptr()); 37 | } else { 38 | AT_ASSERT(false, "CPU not supported"); 39 | } 40 | 41 | return {dist2, idx}; 42 | } 43 | 44 | 45 | std::vector dist_nn(at::Tensor unknowns, at::Tensor knows) { 46 | CHECK_CONTIGUOUS(unknowns); 47 | CHECK_CONTIGUOUS(knows); 48 | CHECK_IS_FLOAT(unknowns); 49 | CHECK_IS_FLOAT(knows); 50 | 51 | if (unknowns.is_cuda()) { 52 | CHECK_CUDA(knows); 53 | } 54 | 55 | at::Tensor dist2 = 56 | torch::zeros({unknowns.size(0), unknowns.size(1), knows.size(1)}, 57 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 58 | 59 | if (unknowns.is_cuda()) { 60 | dist_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 61 | unknowns.data_ptr(), knows.data_ptr(), 62 | dist2.data_ptr()); 63 | } else { 64 | AT_ASSERT(false, "CPU not supported"); 65 | } 66 | 67 | return {dist2}; 68 | } 69 | 70 | 71 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 72 | at::Tensor weight) { 73 | CHECK_CONTIGUOUS(points); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(points); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (points.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 87 | at::device(points.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (points.is_cuda()) { 90 | three_interpolate_kernel_wrapper( 91 | points.size(0), points.size(1), points.size(2), idx.size(1), 92 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 93 | output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 101 | at::Tensor weight, const int m) { 102 | CHECK_CONTIGUOUS(grad_out); 103 | CHECK_CONTIGUOUS(idx); 104 | CHECK_CONTIGUOUS(weight); 105 | CHECK_IS_FLOAT(grad_out); 106 | CHECK_IS_INT(idx); 107 | CHECK_IS_FLOAT(weight); 108 | 109 | if (grad_out.is_cuda()) { 110 | CHECK_CUDA(idx); 111 | CHECK_CUDA(weight); 112 | } 113 | 114 | at::Tensor output = 115 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 116 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 117 | 118 | if (grad_out.is_cuda()) { 119 | three_interpolate_grad_kernel_wrapper( 120 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 121 | grad_out.data_ptr(), idx.data_ptr(), 122 | weight.data_ptr(), 
output.data_ptr()); 123 | } else { 124 | AT_ASSERT(false, "CPU not supported"); 125 | } 126 | 127 | return output; 128 | } 129 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | 71 | // input: unknown(b, n, 3) known(b, m, 3) 72 | // output: dist2(b, n, m) 73 | __global__ void dist_nn_kernel(int b, int n, int m, 74 | const float *__restrict__ unknown, 75 | const float *__restrict__ known, 76 | float *__restrict__ dist2) { 77 | int batch_index = blockIdx.x; 78 | unknown += batch_index * n * 3; 79 | known += batch_index * m * 3; 80 | dist2 += batch_index * n * m; 81 | 82 | double best1 = 1e40; 83 | int index = threadIdx.x; 84 | int stride = blockDim.x; 85 | for (int j = index; j < n; j += stride) { 86 | float ux = unknown[j * 3 + 0]; 87 | float uy = unknown[j * 3 + 1]; 88 | float uz = unknown[j * 3 + 2]; 89 | 90 | for (int k = 0; k < m; ++k) { 91 | float x = known[k * 3 + 0]; 92 | float y = known[k * 3 + 1]; 93 | float z = known[k * 3 + 2]; 94 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 95 | 96 | dist2[j + k] = d;} 97 | } 98 | } 99 | 100 | void dist_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 101 | const float *known, float *dist2) { 102 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 103 | dist_nn_kernel<<>>(b, n, m, unknown, known, 104 | dist2); 105 | 106 | CUDA_CHECK_ERRORS(); 107 | } 108 | 109 | 110 | 111 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 112 | // output: 
out(b, c, n) 113 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 114 | const float *__restrict__ points, 115 | const int *__restrict__ idx, 116 | const float *__restrict__ weight, 117 | float *__restrict__ out) { 118 | int batch_index = blockIdx.x; 119 | points += batch_index * m * c; 120 | 121 | idx += batch_index * n * 3; 122 | weight += batch_index * n * 3; 123 | 124 | out += batch_index * n * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 140 | points[l * m + i3] * w3; 141 | } 142 | } 143 | 144 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 145 | const float *points, const int *idx, 146 | const float *weight, float *out) { 147 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 148 | three_interpolate_kernel<<>>( 149 | b, c, m, n, points, idx, weight, out); 150 | 151 | CUDA_CHECK_ERRORS(); 152 | } 153 | 154 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 155 | // output: grad_points(b, c, m) 156 | 157 | __global__ void three_interpolate_grad_kernel( 158 | int b, int c, int n, int m, const float *__restrict__ grad_out, 159 | const int *__restrict__ idx, const float *__restrict__ weight, 160 | float *__restrict__ grad_points) { 161 | int batch_index = blockIdx.x; 162 | grad_out += batch_index * n * c; 163 | idx += batch_index * n * 3; 164 | weight += batch_index * n * 3; 165 | grad_points += batch_index * m * c; 166 | 167 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 168 | const int stride = blockDim.y * blockDim.x; 169 | for (int i = index; i < c * n; i += stride) { 170 | const int l = i / n; 171 | const int j = i % n; 172 | float w1 = weight[j * 3 + 0]; 173 | float w2 = weight[j * 3 + 1]; 174 | float w3 = weight[j * 3 + 2]; 175 | 176 | int i1 = idx[j * 3 + 0]; 177 | int i2 = idx[j * 3 + 1]; 178 | int i3 = idx[j * 3 + 2]; 179 | 180 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 181 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 182 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 183 | } 184 | } 185 | 186 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 187 | const float *grad_out, 188 | const int *idx, const float *weight, 189 | float *grad_points) { 190 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 191 | three_interpolate_grad_kernel<<>>( 192 | b, c, n, m, grad_out, idx, weight, grad_points); 193 | 194 | CUDA_CHECK_ERRORS(); 195 | } 196 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 
| int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | 
for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, 53 | at::cuda::getCurrentCUDAStream()>>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Output idxs (b, m) 69 | template <unsigned int block_size> 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ?
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 201 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6
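# Aside: a minimal NumPy sketch (not part of this file) of what the CUDA
# furthest_point_sampling kernel above computes per batch, assuming xyz is an
# (N, 3) float array and `import numpy as np`:
#
#     def farthest_point_sample_np(xyz, npoint):
#         dist = np.full(xyz.shape[0], np.inf)   # min sq-dist to the picked set
#         idxs = np.zeros(npoint, dtype=np.int64)  # the kernel also seeds index 0
#         old = 0
#         for j in range(1, npoint):
#             d = np.sum((xyz - xyz[old]) ** 2, axis=1)  # sq-dist to newest pick
#             dist = np.minimum(dist, d)   # mirrors temp[k] = min(d, temp[k])
#             old = int(np.argmax(dist))   # mirrors the block-wide max-reduction
#             idxs[j] = old
#         return idxs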
| from pointnet2_ops import pointnet2_utils 7 | 8 | 9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True): 10 | layers = [] 11 | for i in range(1, len(mlp_spec)): 12 | layers.append( 13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) 14 | ) 15 | if bn: 16 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 17 | layers.append(nn.ReLU(True)) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class _PointnetSAModuleBase(nn.Module): 23 | def __init__(self): 24 | super(_PointnetSAModuleBase, self).__init__() 25 | self.npoint = None 26 | self.groupers = None 27 | self.mlps = None 28 | 29 | def forward( 30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor] 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | r""" 33 | Parameters 34 | ---------- 35 | xyz : torch.Tensor 36 | (B, N, 3) tensor of the xyz coordinates of the features 37 | features : torch.Tensor 38 | (B, C, N) tensor of the descriptors of the the features 39 | 40 | Returns 41 | ------- 42 | new_xyz : torch.Tensor 43 | (B, npoint, 3) tensor of the new features' xyz 44 | new_features : torch.Tensor 45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 46 | """ 47 | 48 | new_features_list = [] 49 | 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | new_xyz = ( 52 | pointnet2_utils.gather_operation( 53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 54 | ) 55 | .transpose(1, 2) 56 | .contiguous() 57 | if self.npoint is not None 58 | else None 59 | ) 60 | 61 | for i in range(len(self.groupers)): 62 | new_features = self.groupers[i]( 63 | xyz, new_xyz, features 64 | ) # (B, C, npoint, nsample) 65 | 66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 67 | new_features = F.max_pool2d( 68 | new_features, kernel_size=[1, new_features.size(3)] 69 | ) # (B, mlp[-1], npoint, 1) 70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 71 | 72 | new_features_list.append(new_features) 73 | 74 | return new_xyz, torch.cat(new_features_list, dim=1) 75 | 76 | 77 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 78 | r"""Pointnet set abstrction layer with multiscale grouping 79 | 80 | Parameters 81 | ---------- 82 | npoint : int 83 | Number of features 84 | radii : list of float32 85 | list of radii to group with 86 | nsamples : list of int32 87 | Number of samples in each ball query 88 | mlps : list of list of int32 89 | Spec of the pointnet before the global max_pool for each scale 90 | bn : bool 91 | Use batchnorm 92 | """ 93 | 94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 96 | super(PointnetSAModuleMSG, self).__init__() 97 | 98 | assert len(radii) == len(nsamples) == len(mlps) 99 | 100 | self.npoint = npoint 101 | self.groupers = nn.ModuleList() 102 | self.mlps = nn.ModuleList() 103 | for i in range(len(radii)): 104 | radius = radii[i] 105 | nsample = nsamples[i] 106 | self.groupers.append( 107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 108 | if npoint is not None 109 | else pointnet2_utils.GroupAll(use_xyz) 110 | ) 111 | mlp_spec = mlps[i] 112 | if use_xyz: 113 | mlp_spec[0] += 3 114 | 115 | self.mlps.append(build_shared_mlp(mlp_spec, bn)) 116 | 117 | 118 | class PointnetSAModule(PointnetSAModuleMSG): 119 | r"""Pointnet set abstrction layer 120 | 121 | Parameters 122 | ---------- 123 | npoint : int 124 | Number of features 125 | radius : float 126 | Radius of ball 127 | 
nsample : int 128 | Number of samples in the ball query 129 | mlp : list 130 | Spec of the pointnet before the global max_pool 131 | bn : bool 132 | Use batchnorm 133 | """ 134 | 135 | def __init__( 136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 137 | ): 138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 139 | super(PointnetSAModule, self).__init__( 140 | mlps=[mlp], 141 | npoint=npoint, 142 | radii=[radius], 143 | nsamples=[nsample], 144 | bn=bn, 145 | use_xyz=use_xyz, 146 | ) 147 | 148 | 149 | class PointnetFPModule(nn.Module): 150 | r"""Propigates the features of one set to another 151 | 152 | Parameters 153 | ---------- 154 | mlp : list 155 | Pointnet module parameters 156 | bn : bool 157 | Use batchnorm 158 | """ 159 | 160 | def __init__(self, mlp, bn=True): 161 | # type: (PointnetFPModule, List[int], bool) -> None 162 | super(PointnetFPModule, self).__init__() 163 | self.mlp = build_shared_mlp(mlp, bn=bn) 164 | 165 | def forward(self, unknown, known, unknow_feats, known_feats): 166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 167 | r""" 168 | Parameters 169 | ---------- 170 | unknown : torch.Tensor 171 | (B, n, 3) tensor of the xyz positions of the unknown features 172 | known : torch.Tensor 173 | (B, m, 3) tensor of the xyz positions of the known features 174 | unknow_feats : torch.Tensor 175 | (B, C1, n) tensor of the features to be propigated to 176 | known_feats : torch.Tensor 177 | (B, C2, m) tensor of features to be propigated 178 | 179 | Returns 180 | ------- 181 | new_features : torch.Tensor 182 | (B, mlp[-1], n) tensor of the features of the unknown features 183 | """ 184 | 185 | if known is not None: 186 | dist, idx = pointnet2_utils.three_nn(unknown, known) 187 | dist_recip = 1.0 / (dist + 1e-8) 188 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 189 | weight = dist_recip / norm 190 | 191 | interpolated_feats = pointnet2_utils.three_interpolate( 192 | known_feats, idx, weight 193 | ) 194 | else: 195 | interpolated_feats = known_feats.expand( 196 | *(known_feats.size()[0:2] + [unknown.size(1)]) 197 | ) 198 | 199 | if unknow_feats is not None: 200 | new_features = torch.cat( 201 | [interpolated_feats, unknow_feats], dim=1 202 | ) # (B, C2 + C1, n) 203 | else: 204 | new_features = interpolated_feats 205 | 206 | new_features = new_features.unsqueeze(-1) 207 | new_features = self.mlp(new_features) 208 | 209 | return new_features.squeeze(-1) 210 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | from torch.autograd import Function 5 | from typing import * 6 | 7 | try: 8 | import pointnet2_ops._ext as _ext 9 | except ImportError: 10 | from torch.utils.cpp_extension import load 11 | import glob 12 | import os.path as osp 13 | import os 14 | 15 | warnings.warn("Unable to load pointnet2_ops cpp extension. 
JIT Compiling.") 16 | 17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") 18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 19 | osp.join(_ext_src_root, "src", "*.cu") 20 | ) 21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 22 | 23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 24 | _ext = load( 25 | "_ext", 26 | sources=_ext_sources, 27 | extra_include_paths=[osp.join(_ext_src_root, "include")], 28 | extra_cflags=["-O3"], 29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], 30 | with_cuda=True, 31 | ) 32 | 33 | 34 | class FurthestPointSampling(Function): 35 | @staticmethod 36 | def forward(ctx, xyz, npoint): 37 | # type: (Any, torch.Tensor, int) -> torch.Tensor 38 | r""" 39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 40 | minimum distance 41 | 42 | Parameters 43 | ---------- 44 | xyz : torch.Tensor 45 | (B, N, 3) tensor where N > npoint 46 | npoint : int32 47 | number of features in the sampled set 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, npoint) tensor containing the set 53 | """ 54 | out = _ext.furthest_point_sampling(xyz, npoint) 55 | 56 | ctx.mark_non_differentiable(out) 57 | 58 | return out 59 | 60 | @staticmethod 61 | def backward(ctx, grad_out): 62 | return () 63 | 64 | 65 | furthest_point_sample = FurthestPointSampling.apply 66 | 67 | 68 | class GatherOperation(Function): 69 | @staticmethod 70 | def forward(ctx, features, idx): 71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 72 | r""" 73 | 74 | Parameters 75 | ---------- 76 | features : torch.Tensor 77 | (B, C, N) tensor 78 | 79 | idx : torch.Tensor 80 | (B, npoint) tensor of the features to gather 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | (B, C, npoint) tensor 86 | """ 87 | 88 | ctx.save_for_backward(idx, features) 89 | 90 | return _ext.gather_points(features, idx) 91 | 92 | @staticmethod 93 | def backward(ctx, grad_out): 94 | idx, features = ctx.saved_tensors 95 | N = features.size(2) 96 | 97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 98 | return grad_features, None 99 | 100 | 101 | gather_operation = GatherOperation.apply 102 | 103 | 104 | class ThreeNN(Function): 105 | @staticmethod 106 | def forward(ctx, unknown, known): 107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 108 | r""" 109 | Find the three nearest neighbors of unknown in known 110 | Parameters 111 | ---------- 112 | unknown : torch.Tensor 113 | (B, n, 3) tensor of known features 114 | known : torch.Tensor 115 | (B, m, 3) tensor of unknown features 116 | 117 | Returns 118 | ------- 119 | dist : torch.Tensor 120 | (B, n, 3) l2 distance to the three nearest neighbors 121 | idx : torch.Tensor 122 | (B, n, 3) index of 3 nearest neighbors 123 | """ 124 | dist2, idx = _ext.three_nn(unknown, known) 125 | dist = torch.sqrt(dist2) 126 | 127 | ctx.mark_non_differentiable(dist, idx) 128 | 129 | return dist, idx 130 | 131 | @staticmethod 132 | def backward(ctx, grad_dist, grad_idx): 133 | return () 134 | 135 | 136 | three_nn = ThreeNN.apply 137 | 138 | 139 | 140 | class DistNN(Function): 141 | @staticmethod 142 | def forward(ctx, unknown, known): 143 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 144 | r""" 145 | Find the three nearest neighbors of unknown in known 146 | Parameters 147 | ---------- 148 | unknown : torch.Tensor 149 | (B, n, 3) tensor of known features 150 | known : 
torch.Tensor 151 | (B, m, 3) tensor of unknown features 152 | 153 | Returns 154 | ------- 155 | dist : torch.Tensor 156 | (B, n, 3) l2 distance to the three nearest neighbors 157 | idx : torch.Tensor 158 | (B, n, 3) index of 3 nearest neighbors 159 | """ 160 | dist2 = _ext.dist_nn(unknown, known) 161 | # dist = torch.sqrt(dist2[0]) 162 | 163 | # ctx.mark_non_differentiable(dist[0]) 164 | ctx.mark_non_differentiable(dist2[0]) 165 | 166 | return dist2[0] 167 | 168 | @staticmethod 169 | def backward(ctx, grad_dist): 170 | return () 171 | 172 | 173 | dist_nn = DistNN.apply 174 | 175 | class ThreeInterpolate(Function): 176 | @staticmethod 177 | def forward(ctx, features, idx, weight): 178 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 179 | r""" 180 | Performs weight linear interpolation on 3 features 181 | Parameters 182 | ---------- 183 | features : torch.Tensor 184 | (B, c, m) Features descriptors to be interpolated from 185 | idx : torch.Tensor 186 | (B, n, 3) three nearest neighbors of the target features in features 187 | weight : torch.Tensor 188 | (B, n, 3) weights 189 | 190 | Returns 191 | ------- 192 | torch.Tensor 193 | (B, c, n) tensor of the interpolated features 194 | """ 195 | ctx.save_for_backward(idx, weight, features) 196 | 197 | return _ext.three_interpolate(features, idx, weight) 198 | 199 | @staticmethod 200 | def backward(ctx, grad_out): 201 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 202 | r""" 203 | Parameters 204 | ---------- 205 | grad_out : torch.Tensor 206 | (B, c, n) tensor with gradients of ouputs 207 | 208 | Returns 209 | ------- 210 | grad_features : torch.Tensor 211 | (B, c, m) tensor with gradients of features 212 | 213 | None 214 | 215 | None 216 | """ 217 | idx, weight, features = ctx.saved_tensors 218 | m = features.size(2) 219 | 220 | grad_features = _ext.three_interpolate_grad( 221 | grad_out.contiguous(), idx, weight, m 222 | ) 223 | 224 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) 225 | 226 | 227 | three_interpolate = ThreeInterpolate.apply 228 | 229 | 230 | class GroupingOperation(Function): 231 | @staticmethod 232 | def forward(ctx, features, idx): 233 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 234 | r""" 235 | 236 | Parameters 237 | ---------- 238 | features : torch.Tensor 239 | (B, C, N) tensor of features to group 240 | idx : torch.Tensor 241 | (B, npoint, nsample) tensor containing the indicies of features to group with 242 | 243 | Returns 244 | ------- 245 | torch.Tensor 246 | (B, C, npoint, nsample) tensor 247 | """ 248 | ctx.save_for_backward(idx, features) 249 | 250 | return _ext.group_points(features, idx) 251 | 252 | @staticmethod 253 | def backward(ctx, grad_out): 254 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 255 | r""" 256 | 257 | Parameters 258 | ---------- 259 | grad_out : torch.Tensor 260 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 261 | 262 | Returns 263 | ------- 264 | torch.Tensor 265 | (B, C, N) gradient of the features 266 | None 267 | """ 268 | idx, features = ctx.saved_tensors 269 | N = features.size(2) 270 | 271 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 272 | 273 | return grad_features, torch.zeros_like(idx) 274 | 275 | 276 | grouping_operation = GroupingOperation.apply 277 | 278 | 279 | class BallQuery(Function): 280 | @staticmethod 281 | def forward(ctx, radius, nsample, xyz, new_xyz): 282 | # type: (Any, float, int, torch.Tensor, 
torch.Tensor) -> torch.Tensor 283 | r""" 284 | 285 | Parameters 286 | ---------- 287 | radius : float 288 | radius of the balls 289 | nsample : int 290 | maximum number of features in the balls 291 | xyz : torch.Tensor 292 | (B, N, 3) xyz coordinates of the features 293 | new_xyz : torch.Tensor 294 | (B, npoint, 3) centers of the ball query 295 | 296 | Returns 297 | ------- 298 | torch.Tensor 299 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 300 | """ 301 | output = _ext.ball_query(new_xyz, xyz, radius, nsample) 302 | 303 | ctx.mark_non_differentiable(output) 304 | 305 | return output 306 | 307 | @staticmethod 308 | def backward(ctx, grad_out): 309 | return () 310 | 311 | 312 | ball_query = BallQuery.apply 313 | 314 | 315 | 316 | class PointSIFTSelect(Function): 317 | @staticmethod 318 | def forward(ctx, radius, xyz): 319 | # type: (Any, float, torch.Tensor) -> torch.Tensor 320 | r""" 321 | 322 | Parameters 323 | ---------- 324 | radius : float 325 | radius of the balls 326 | nsample : int 327 | maximum number of features in the balls 328 | xyz : torch.Tensor 329 | (B, N, 3) xyz coordinates of the features 330 | new_xyz : torch.Tensor 331 | (B, npoint, 3) centers of the ball query 332 | 333 | Returns 334 | ------- 335 | torch.Tensor 336 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 337 | """ 338 | output = _ext.cube_select_sift(xyz, radius) 339 | 340 | ctx.mark_non_differentiable(output) 341 | 342 | return output 343 | 344 | @staticmethod 345 | def backward(ctx, grad_out): 346 | return () 347 | 348 | 349 | pointsift_select = PointSIFTSelect.apply 350 | 351 | class QueryAndGroup(nn.Module): 352 | r""" 353 | Groups with a ball query of radius 354 | 355 | Parameters 356 | --------- 357 | radius : float32 358 | Radius of ball 359 | nsample : int32 360 | Maximum number of features to gather in the ball 361 | """ 362 | 363 | def __init__(self, radius, nsample, use_xyz=True): 364 | # type: (QueryAndGroup, float, int, bool) -> None 365 | super(QueryAndGroup, self).__init__() 366 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 367 | 368 | def forward(self, xyz, new_xyz, features=None): 369 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 370 | r""" 371 | Parameters 372 | ---------- 373 | xyz : torch.Tensor 374 | xyz coordinates of the features (B, N, 3) 375 | new_xyz : torch.Tensor 376 | centriods (B, npoint, 3) 377 | features : torch.Tensor 378 | Descriptors of the features (B, C, N) 379 | 380 | Returns 381 | ------- 382 | new_features : torch.Tensor 383 | (B, 3 + C, npoint, nsample) tensor 384 | """ 385 | 386 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 387 | xyz_trans = xyz.transpose(1, 2).contiguous() 388 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 389 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 390 | 391 | if features is not None: 392 | grouped_features = grouping_operation(features, idx) 393 | if self.use_xyz: 394 | new_features = torch.cat( 395 | [grouped_xyz, grouped_features], dim=1 396 | ) # (B, C + 3, npoint, nsample) 397 | else: 398 | new_features = grouped_features 399 | else: 400 | assert ( 401 | self.use_xyz 402 | ), "Cannot have not features and not use xyz as a feature!" 
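# Usage sketch (illustrative radius/nsample values; shapes assume use_xyz=True):
# given xyz (B, N, 3), new_xyz (B, npoint, 3) and features (B, C, N),
#     grouper = QueryAndGroup(radius=0.2, nsample=32)
#     out = grouper(xyz, new_xyz, features)  # (B, C + 3, npoint, nsample)
# grouped_xyz was re-centered on each centroid above, so the per-group MLPs see
# translation-invariant local offsets.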
403 | new_features = grouped_xyz 404 | 405 | return new_features 406 | 407 | 408 | class GroupAll(nn.Module): 409 | r""" 410 | Groups all features 411 | 412 | Parameters 413 | --------- 414 | """ 415 | 416 | def __init__(self, use_xyz=True): 417 | # type: (GroupAll, bool) -> None 418 | super(GroupAll, self).__init__() 419 | self.use_xyz = use_xyz 420 | 421 | def forward(self, xyz, new_xyz, features=None): 422 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 423 | r""" 424 | Parameters 425 | ---------- 426 | xyz : torch.Tensor 427 | xyz coordinates of the features (B, N, 3) 428 | new_xyz : torch.Tensor 429 | Ignored 430 | features : torch.Tensor 431 | Descriptors of the features (B, C, N) 432 | 433 | Returns 434 | ------- 435 | new_features : torch.Tensor 436 | (B, C + 3, 1, N) tensor 437 | """ 438 | 439 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 440 | if features is not None: 441 | grouped_features = features.unsqueeze(2) 442 | if self.use_xyz: 443 | new_features = torch.cat( 444 | [grouped_xyz, grouped_features], dim=1 445 | ) # (B, 3 + C, 1, N) 446 | else: 447 | new_features = grouped_features 448 | else: 449 | new_features = grouped_xyz 450 | 451 | return new_features 452 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src") 10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 11 | osp.join(_ext_src_root, "src", "*.cu") 12 | ) 13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read()) 18 | 19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="pointnet2_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="pointnet2_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /tool/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WingkeungM/GACNet_PyTorch/c5b3701e7029cf3c6a6ff2622255fc762de6690d/tool/.DS_Store -------------------------------------------------------------------------------- /tool/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | from data.S3DIS.S3DISDataLoader import S3DISDatasetWholeScene 12 | from model.gacnet import GACNet, build_graph_pyramid 13 | 14 | g_classes = [x.rstrip() for x in open('./data/s3dis/s3dis_names.txt')] 15 | g_class2label = {cls: i for i,cls in 
enumerate(g_classes)} 16 | g_class2color = {'ceiling': [0,255,0], 17 | 'floor': [0,0,255], 18 | 'wall': [0,255,255], 19 | 'beam': [255,255,0], 20 | 'column': [255,0,255], 21 | 'window': [100,100,255], 22 | 'door': [200,200,100], 23 | 'table': [170,120,200], 24 | 'chair': [255,0,0], 25 | 'sofa': [200,100,100], 26 | 'bookcase': [10,200,100], 27 | 'board': [200,200,200], 28 | 'clutter': [50,50,50]} 29 | g_easy_view_labels = [7,8,9,10,11,1] 30 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes} 31 | 32 | 33 | graph_inf = {'stride_list': [1024, 256, 64, 32], #can be seen as the downsampling rate 34 | 'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # radius for neighbor points searching 35 | 'maxsample_list': [12, 21, 21, 21, 12] #number of neighbor points for each layer 36 | } 37 | 38 | # number of units for each mlp layer 39 | forward_parm = [ 40 | [ [32,32,64], [64] ], 41 | [ [64,64,128], [128] ], 42 | [ [128,128,256], [256] ], 43 | [ [256,256,512], [512] ], 44 | [ [256,256], [256] ] 45 | ] 46 | 47 | # for feature interpolation stage 48 | upsample_parm = [ 49 | [128, 128], 50 | [128, 128], 51 | [256, 256], 52 | [256, 256] 53 | ] 54 | 55 | # parameters for fully connection layer 56 | fullconect_parm = 128 57 | 58 | net_inf = {'forward_parm': forward_parm, 59 | 'upsample_parm': upsample_parm, 60 | 'fullconect_parm': fullconect_parm 61 | } 62 | 63 | 64 | 65 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 66 | ROOT_DIR = BASE_DIR 67 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 68 | 69 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 70 | class2label = {cls: i for i,cls in enumerate(classes)} 71 | seg_classes = class2label 72 | seg_label_to_cat = {} 73 | for i,cat in enumerate(seg_classes.keys()): 74 | seg_label_to_cat[i] = cat 75 | 76 | def parse_args(): 77 | '''PARAMETERS''' 78 | parser = argparse.ArgumentParser('Model') 79 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 80 | parser.add_argument('--num_point', type=int, default=4096, help='Point Number [default: 4096]') 81 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]') 82 | parser.add_argument('--visual', action='store_true', default=False, help='Whether visualize result [default: False]') 83 | parser.add_argument('--log_dir', type=str, default='logs_GACNet', help='Experiment root') 84 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]') 85 | parser.add_argument('--num_votes', type=int, default=5, help='Aggregate segmentation scores with voting [default: 5]') 86 | parser.add_argument('--num_class', type=int, default=13, help='Class number of the dataset [default: 13]') 87 | parser.add_argument('--datapath', type=str, default='/workspace/datasets/stanford_indoor3d/', help='Path of the sataset') 88 | parser.add_argument('--test_model', type=str, default='/workspace/experiment/checkpoints/GACNet.pth/', help='Path of the test model') 89 | 90 | return parser.parse_args() 91 | 92 | def add_vote(vote_label_pool, point_idx, pred_label, weight): 93 | B = pred_label.shape[0] 94 | N = pred_label.shape[1] 95 | for b in range(B): 96 | for n in range(N): 97 | if weight[b,n]: 98 | vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1 99 | return vote_label_pool 100 | 101 | def main(args): 102 | def log_string(str): 103 | logger.info(str) 104 | print(str) 105 | 106 | '''HYPER 
PARAMETER''' 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 108 | experiment_dir = './log/' + args.log_dir 109 | visual_dir = experiment_dir + '/visual/' 110 | visual_dir = Path(visual_dir) 111 | visual_dir.mkdir(exist_ok=True) 112 | 113 | '''LOG''' 114 | args = parse_args() 115 | logger = logging.getLogger("Model") 116 | logger.setLevel(logging.INFO) 117 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 118 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) 119 | file_handler.setLevel(logging.INFO) 120 | file_handler.setFormatter(formatter) 121 | logger.addHandler(file_handler) 122 | log_string('PARAMETER ...') 123 | log_string(args) 124 | 125 | 126 | TEST_DATASET_WHOLE_SCENE = S3DISDatasetWholeScene(root=args.datapath, block_points=args.num_point, split='test', 127 | stride=0.5, block_size=1.0, padding=0.001) 128 | log_string("The number of test data is: %d" %len(TEST_DATASET_WHOLE_SCENE)) 129 | 130 | '''MODEL LOADING''' 131 | classifier = GACNet(args.num_class, graph_inf, net_inf).cuda() 132 | checkpoint = torch.load(args.test_model) 133 | classifier.load_state_dict(checkpoint) 134 | 135 | with torch.no_grad(): 136 | scene_id = TEST_DATASET_WHOLE_SCENE.file_list 137 | scene_id = [x[:-5] for x in scene_id] 138 | num_batches = len(TEST_DATASET_WHOLE_SCENE) 139 | 140 | total_seen_class = [0 for _ in range(args.num_class)] 141 | total_correct_class = [0 for _ in range(args.num_class)] 142 | total_iou_deno_class = [0 for _ in range(args.num_class)] 143 | 144 | total_pre_class = [0 for _ in range(args.num_class)] 145 | precision_per_class = [0 for _ in range(args.num_class)] 146 | recall_per_class = [0 for _ in range(args.num_class)] 147 | 148 | log_string('---- EVALUATION WHOLE SCENE ----') 149 | 150 | for batch_idx in range(num_batches): 151 | print("visualize [%d/%d] %s ..." % (batch_idx+1, num_batches, scene_id[batch_idx])) 152 | total_seen_class_tmp = [0 for _ in range(args.num_class)] 153 | total_correct_class_tmp = [0 for _ in range(args.num_class)] 154 | total_iou_deno_class_tmp = [0 for _ in range(args.num_class)] 155 | 156 | total_pre_class_tmp = [0 for _ in range(args.num_class)] 157 | precision_per_class_tmp = [0 for _ in range(args.num_class)] 158 | recall_per_class_tmp = [0 for _ in range(args.num_class)] 159 | if args.visual: 160 | fout = open(os.path.join(visual_dir, scene_id[batch_idx] + '_pred.obj'), 'w') 161 | fout_gt = open(os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj'), 'w') 162 | 163 | whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[batch_idx] 164 | whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx] 165 | vote_label_pool = np.zeros((whole_scene_label.shape[0], args.num_class))#[num_points,num_classes] 166 | for _ in tqdm(range(args.num_votes), total=args.num_votes): 167 | scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx] 168 | num_blocks = scene_data.shape[0]  # how many blocks of num_point points does this scene split into? 169 | s_batch_num = (num_blocks + args.batch_size - 1) // args.batch_size  # how many batches of batch_size does that require?
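# How the voting works: each of the num_votes passes re-blocks the scene, and
# add_vote accumulates one count per (point, class) wherever weight > 0:
#     vote_label_pool[point_idx[b, n], pred_label[b, n]] += 1
# After all votes, pred_label = np.argmax(vote_label_pool, 1) takes a per-point
# majority over every block and vote that contained the point.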
170 | batch_data = np.zeros((args.batch_size, args.num_point, 9)) 171 | 172 | batch_label = np.zeros((args.batch_size, args.num_point)) 173 | batch_point_index = np.zeros((args.batch_size, args.num_point)) 174 | batch_smpw = np.zeros((args.batch_size, args.num_point)) 175 | for sbatch in range(s_batch_num): 176 | start_idx = sbatch * args.batch_size 177 | end_idx = min((sbatch + 1) * args.batch_size, num_blocks) 178 | real_batch_size = end_idx - start_idx 179 | batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...] 180 | batch_label[0:real_batch_size, ...] = scene_label[start_idx:end_idx, ...] 181 | batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...] 182 | batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...] 183 | 184 | torch_data = torch.Tensor(batch_data) 185 | torch_data= torch_data.float().cuda() 186 | 187 | classifier = classifier.eval() 188 | graph_prd, coarse_map = build_graph_pyramid(torch_data[:, :, 0:3], graph_inf) 189 | seg_pred = classifier(torch_data[:, :, :6], graph_prd, coarse_map) 190 | 191 | batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy() 192 | 193 | vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...], 194 | batch_pred_label[0:real_batch_size, ...], 195 | batch_smpw[0:real_batch_size, ...]) 196 | 197 | pred_label = np.argmax(vote_label_pool, 1) 198 | 199 | for l in range(args.num_class): 200 | total_seen_class_tmp[l] += np.sum((whole_scene_label == l)) 201 | total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l)) 202 | total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l))) 203 | 204 | total_pre_class[l] += np.sum((pred_label == l)) 205 | 206 | total_seen_class[l] += total_seen_class_tmp[l] 207 | total_correct_class[l] += total_correct_class_tmp[l] 208 | total_iou_deno_class[l] += total_iou_deno_class_tmp[l] 209 | 210 | iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=np.float) + 1e-6) 211 | print(iou_map) 212 | arr = np.array(total_seen_class_tmp) 213 | tmp_iou = np.mean(iou_map[arr != 0]) 214 | log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou)) 215 | # print('----------------------------') 216 | # 217 | # filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt') 218 | # with open(filename, 'w') as pl_save: 219 | # for i in pred_label: 220 | # pl_save.write(str(int(i)) + '\n') 221 | # pl_save.close() 222 | # for i in range(whole_scene_label.shape[0]): 223 | # color = g_label2color[pred_label[i]] 224 | # color_gt = g_label2color[whole_scene_label[i]] 225 | # if args.visual: 226 | # fout.write('v %f %f %f %d %d %d\n' % ( 227 | # whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color[0], color[1], 228 | # color[2])) 229 | # fout_gt.write( 230 | # 'v %f %f %f %d %d %d\n' % ( 231 | # whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color_gt[0], 232 | # color_gt[1], color_gt[2])) 233 | # if args.visual: 234 | # fout.close() 235 | # fout_gt.close() 236 | 237 | IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float) + 1e-6) 238 | iou_per_class_str = '------- IoU --------\n' 239 | F1_per_class_str = '------- F1 --------\n' 240 | for l in range(args.num_class): 241 | precision_per_class[l] = total_correct_class[l] / total_pre_class[l] 242 | recall_per_class[l] = total_correct_class[l] / total_seen_class[l] 243 | iou_per_class_str += 'class %s, IoU: %.3f \n' % ( 244 
| seg_label_to_cat[l] + ' ' * (10 - len(seg_label_to_cat[l])), 245 | total_correct_class[l] / float(total_iou_deno_class[l])) 246 | F1_per_class_str += 'class %s ,F1: %.3f \n' % ( 247 | seg_label_to_cat[l] + ' ' * (10 - len(seg_label_to_cat[l])), 248 | 2 * (precision_per_class[l] * recall_per_class[l]) / ( 249 | precision_per_class[l] + recall_per_class[l] + 1e-6)) 250 | log_string(iou_per_class_str) 251 | log_string(F1_per_class_str) 252 | log_string('eval point avg class IoU: %f' % np.mean(IoU)) 253 | log_string('eval whole scene point avg class acc: %f' % ( 254 | np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float) + 1e-6)))) 255 | log_string('eval whole scene point accuracy: %f' % ( 256 | np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6))) 257 | 258 | print("Done!") 259 | 260 | if __name__ == '__main__': 261 | args = parse_args() 262 | main(args) 263 | -------------------------------------------------------------------------------- /tool/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import warnings 6 | import datetime 7 | import argparse 8 | import torch.utils.data 9 | import torch.nn.parallel 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch.nn.functional as F 14 | 15 | from tqdm import tqdm 16 | from pathlib import Path 17 | from collections import defaultdict 18 | from torch.autograd import Variable 19 | 20 | from util import transform 21 | from data.S3DIS.S3DISDataLoader import S3DISDataset 22 | from model.gacnet import GACNet, build_graph_pyramid 23 | 24 | 25 | 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 30 | class2label = {cls: i for i,cls in enumerate(classes)} 31 | seg_classes = class2label 32 | seg_label_to_cat = {} 33 | for i,cat in enumerate(seg_classes.keys()): 34 | seg_label_to_cat[i] = cat 35 | 36 | graph_inf = {'stride_list': [1024, 256, 64, 32], #can be seen as the downsampling rate 37 | 'radius_list': [0.1, 0.2, 0.4, 0.8, 1.6], # radius for neighbor points searching 38 | 'maxsample_list': [12, 21, 21, 21, 12] #number of neighbor points for each layer 39 | } 40 | 41 | # number of units for each mlp layer 42 | forward_parm = [ 43 | [ [32,32,64], [64] ], 44 | [ [64,64,128], [128] ], 45 | [ [128,128,256], [256] ], 46 | [ [256,256,512], [512] ], 47 | [ [256,256], [256] ] 48 | ] 49 | 50 | # for feature interpolation stage 51 | upsample_parm = [ 52 | [128, 128], 53 | [128, 128], 54 | [256, 256], 55 | [256, 256] 56 | ] 57 | 58 | # parameters for fully connection layer 59 | fullconect_parm = 128 60 | 61 | net_inf = {'forward_parm': forward_parm, 62 | 'upsample_parm': upsample_parm, 63 | 'fullconect_parm': fullconect_parm 64 | } 65 | 66 | 67 | 68 | def parse_args(): 69 | parser = argparse.ArgumentParser('GACNet') 70 | parser.add_argument('--gpu', type=str, default='1', help='specify gpu device') 71 | parser.add_argument('--multi_gpu', type=str, default=None, help='whether use multi gpu training') 72 | 73 | parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training [default: 200]') 74 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers [default: 4]') 75 | parser.add_argument('--batchSize', type=int, default=2, help='input batch size [default: 24]') 76 | 77 | parser.add_argument('--alpha', 
type=float, default=0.2, help='alpha for leakyRelu [default: 0.2]') 78 | parser.add_argument('--dropout', type=float, default=0, help='dropout [default: 0]') 79 | 80 | parser.add_argument('--optimizer', type=str, default='SGD', help='type of optimizer') 81 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay for Adam') 82 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate for training [default: 0.001 for Adam, 0.01 for SGD]') 83 | 84 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model') 85 | parser.add_argument('--log_dir', type=str, default='logs_gacnet/',help='decay rate of learning rate') 86 | 87 | parser.add_argument('--datapath', type=str, default='/workspace/VisualPointCloud/testgacnet', help='path of the dataset') 88 | parser.add_argument('--numpoint', type=int, default=4096, help='number of point for input [default: 4096]') 89 | parser.add_argument('--numclass', type=int, default=13, help='class number of the dataset [default: 13]') 90 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]') 91 | return parser.parse_args() 92 | 93 | def main(args): 94 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3' 95 | '''CREATE DIR''' 96 | experiment_dir = Path('./experiment/') 97 | experiment_dir.mkdir(exist_ok=True) 98 | file_dir = Path(str(experiment_dir) +'/'+ str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + '_' + args.log_dir)) 99 | file_dir.mkdir(exist_ok=True) 100 | checkpoints_dir = file_dir.joinpath('checkpoints/') 101 | checkpoints_dir.mkdir(exist_ok=True) 102 | log_dir = file_dir.joinpath(args.log_dir) 103 | log_dir.mkdir(exist_ok=True) 104 | 105 | '''LOG''' 106 | args = parse_args() 107 | logger = logging.getLogger('GACNet') 108 | logger.setLevel(logging.INFO) 109 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 110 | file_handler = logging.FileHandler(str(log_dir) + '/train_gacnet.txt') 111 | file_handler.setLevel(logging.INFO) 112 | file_handler.setFormatter(formatter) 113 | logger.addHandler(file_handler) 114 | logger.info('PARAMETER ...') 115 | logger.info(args) 116 | print('Load data...') 117 | 118 | train_transform = transform.Compose([transform.ToTensor()]) 119 | print("start loading training data ...") 120 | TRAIN_DATASET = S3DISDataset(split='train', data_root=args.datapath, num_point=args.numpoint, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None) 121 | dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchSize, shuffle=True, num_workers=args.workers, 122 | pin_memory=True, drop_last=True, worker_init_fn = lambda x: np.random.seed(x+int(time.time()))) 123 | 124 | val_transform = transform.Compose([transform.ToTensor()]) 125 | print("start loading test data ...") 126 | TEST_DATASET = S3DISDataset(split='test', data_root=args.datapath, num_point=args.numpoint, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None) 127 | testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 128 | 129 | weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda() 130 | 131 | blue = lambda x: '\033[94m' + x + '\033[0m' 132 | model = GACNet(args.numclass, graph_inf, net_inf) 133 | 134 | if args.pretrain is not None: 135 | model.load_state_dict(torch.load(args.pretrain)) 136 | print('load model 
%s'%args.pretrain) 137 | logger.info('load model %s'%args.pretrain) 138 | else: 139 | print('Training from scratch') 140 | logger.info('Training from scratch') 141 | pretrain = args.pretrain 142 | init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0 143 | 144 | def adjust_learning_rate(optimizer, step): 145 | """Sets the learning rate to the initial LR decayed by 30 every 20000 steps""" 146 | lr = args.learning_rate * (0.3 ** (step // 20000)) 147 | for param_group in optimizer.param_groups: 148 | param_group['lr'] = lr 149 | 150 | if args.optimizer == 'SGD': 151 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9) 152 | elif args.optimizer == 'Adam': 153 | optimizer = torch.optim.Adam( 154 | model.parameters(), 155 | lr=args.learning_rate, 156 | betas=(0.9, 0.999), 157 | eps=1e-08, 158 | weight_decay=args.decay_rate 159 | ) 160 | 161 | '''GPU selection and multi-GPU''' 162 | if args.multi_gpu is not None: 163 | device_ids = [int(x) for x in args.multi_gpu.split(',')] 164 | torch.backends.cudnn.benchmark = True 165 | model.cuda(device_ids[0]) 166 | model = torch.nn.DataParallel(model, device_ids=device_ids) 167 | else: 168 | model.cuda() 169 | 170 | history = defaultdict(lambda: list()) 171 | best_acc = 0 172 | best_meaniou = 0 173 | step = 0 174 | 175 | for epoch in range(init_epoch,args.epoch): 176 | for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9): 177 | points, target = data 178 | points, target = Variable(points.float()), Variable(target.long()) 179 | points, target = points.cuda(), target.cuda() 180 | optimizer.zero_grad() 181 | model = model.train() 182 | 183 | graph_prd, coarse_map = build_graph_pyramid(points[:, :, 0:3], graph_inf) 184 | pred = model(points[:, :, :6], graph_prd, coarse_map) 185 | 186 | pred = pred.contiguous().view(-1, args.numclass) 187 | target = target.view(-1, 1)[:, 0] 188 | loss = F.nll_loss(pred, target, weight=weights) 189 | history['loss'].append(loss.cpu().data.numpy()) 190 | loss.backward() 191 | optimizer.step() 192 | step += 1 193 | adjust_learning_rate(optimizer, step) 194 | 195 | test_metrics, test_hist_acc, cat_mean_iou = test_seg(model, testdataloader, seg_label_to_cat) 196 | mean_iou = np.mean(cat_mean_iou) 197 | 198 | print('Epoch %d %s accuracy: %f meanIOU: %f' % ( 199 | epoch, blue('test'), test_metrics['accuracy'],mean_iou)) 200 | logger.info('Epoch %d %s accuracy: %f meanIOU: %f' % ( 201 | epoch, 'test', test_metrics['accuracy'],mean_iou)) 202 | if test_metrics['accuracy'] > best_acc: 203 | best_acc = test_metrics['accuracy'] 204 | torch.save(model.state_dict(), '%s/GACNet_%.3d_%.4f_%.4f.pth' % (checkpoints_dir, epoch, best_acc, best_meaniou)) 205 | logger.info(cat_mean_iou) 206 | logger.info('Save model..') 207 | print('Save model..') 208 | print(cat_mean_iou) 209 | if mean_iou > best_meaniou: 210 | best_meaniou = mean_iou 211 | torch.save(model.state_dict(), '%s/GACNet_%.3d_%.4f_%.4f.pth' % (checkpoints_dir, epoch, best_acc, best_meaniou)) 212 | logger.info(cat_mean_iou) 213 | logger.info('Save model..') 214 | print('Save model..') 215 | print(cat_mean_iou) 216 | print('Best accuracy is: %.5f'%best_acc) 217 | logger.info('Best accuracy is: %.5f'%best_acc) 218 | print('Best meanIOU is: %.5f'%best_meaniou) 219 | logger.info('Best meanIOU is: %.5f'%best_meaniou) 220 | 221 | def test_seg(model, loader, catdict, num_classes = 13): 222 | iou_tabel = np.zeros((len(catdict),3)) 223 | metrics = defaultdict(lambda:list()) 224 | hist_acc = [] 225 | for batch_id, 
(points, target) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9): 226 | batchsize, num_point, _ = points.size() 227 | 228 | points, target = Variable(points.float()), Variable(target.long()) 229 | 230 | points, target = points.cuda(), target.cuda() 231 | 232 | graph_prd, coarse_map = build_graph_pyramid(points[:, :, 0:3], graph_inf) 233 | pred = model(points[:, :, :6], graph_prd, coarse_map) 234 | 235 | mean_iou, iou_tabel = compute_iou(pred, target, iou_tabel) 236 | pred = pred.contiguous().view(-1, num_classes) 237 | target = target.view(-1, 1)[:, 0] 238 | pred_choice = pred.data.max(1)[1] 239 | correct = pred_choice.eq(target.data).cpu().sum() 240 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point)) 241 | metrics['iou'].append(mean_iou) 242 | iou_tabel[:,2] = iou_tabel[:,0] /(iou_tabel[:,1]+0.01) 243 | hist_acc += metrics['accuracy'] 244 | metrics['accuracy'] = np.mean(metrics['accuracy']) 245 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou']) 246 | iou_tabel['Category_IOU'] = [cat_value for cat_value in catdict.values()] 247 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean() 248 | 249 | return metrics, hist_acc, cat_iou 250 | 251 | def compute_iou(pred,target,iou_tabel=None): 252 | ious = [] 253 | target = target.cpu().data.numpy() 254 | for j in range(pred.size(0)): 255 | batch_pred = pred[j] 256 | batch_target = target[j] 257 | batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy() 258 | for cat in np.unique(batch_target): 259 | intersection = np.sum((batch_target == cat) & (batch_choice == cat)) 260 | union = float(np.sum((batch_target == cat) | (batch_choice == cat))) 261 | iou = intersection/union 262 | ious.append(iou) 263 | iou_tabel[cat,0] += iou 264 | iou_tabel[cat,1] += 1 265 | return np.mean(ious), iou_tabel 266 | 267 | if __name__ == '__main__': 268 | args = parse_args() 269 | main(args) 270 | -------------------------------------------------------------------------------- /util/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def make_dataset(split='train', data_root=None, data_list=None): 9 | if not os.path.isfile(data_list): 10 | raise (RuntimeError("Point list file does not exist: " + data_list + "\n")) 11 | point_list = [] 12 | list_read = open(data_list).readlines() 13 | print("Totally {} samples in {} set.".format(len(list_read), split)) 14 | for line in list_read: 15 | point_list.append(os.path.join(data_root, line.strip())) 16 | return point_list 17 | 18 | 19 | class PointData(Dataset): 20 | def __init__(self, split='train', data_root=None, data_list=None, transform=None, num_point=None, random_index=False): 21 | assert split in ['train', 'val', 'test'] 22 | self.split = split 23 | self.data_list = make_dataset(split, data_root, data_list) 24 | self.transform = transform 25 | self.num_point = num_point 26 | self.random_index = random_index 27 | 28 | def __len__(self): 29 | return len(self.data_list) 30 | 31 | def __getitem__(self, index): 32 | data_path = self.data_list[index] 33 | f = h5py.File(data_path, 'r') 34 | data = f['data'][:] 35 | if self.split == 'test': 36 | label = 255 # place holder 37 | else: 38 | label = f['label'][:] 39 | f.close() 40 | if self.num_point is None: 41 | self.num_point = data.shape[0] 42 | idxs = np.arange(data.shape[0]) 43 | if self.random_index: 44 | np.random.shuffle(idxs) 45 | idxs = idxs[0:self.num_point] 46 |
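# Note: idxs was optionally shuffled above, so the slice keeps either the first
# num_point points (deterministic) or a uniform random num_point-point subset
# when random_index=True; labels are indexed with the same idxs below so points
# and labels stay aligned.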
data = data[idxs, :] 47 | if label.size != 1: # seg data 48 | label = label[idxs] 49 | if self.transform is not None: 50 | data, label = self.transform(data, label) 51 | return data, label 52 | 53 | 54 | if __name__ == '__main__': 55 | data_root = '/mnt/sda1/hszhao/dataset/3d/s3dis' 56 | data_list = '/mnt/sda1/hszhao/dataset/3d/s3dis/list/train12346.txt' 57 | point_data = PointData('train', data_root, data_list) 58 | print('point data size:', point_data.__len__()) 59 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape) 60 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape) 61 | -------------------------------------------------------------------------------- /util/ply.py: -------------------------------------------------------------------------------- 1 | # Basic libs 2 | import numpy as np 3 | import sys 4 | 5 | 6 | # Define PLY types 7 | ply_dtypes = dict([ 8 | (b'int8', 'i1'), 9 | (b'char', 'i1'), 10 | (b'uint8', 'u1'), 11 | (b'uchar', 'u1'), 12 | (b'int16', 'i2'), 13 | (b'short', 'i2'), 14 | (b'uint16', 'u2'), 15 | (b'ushort', 'u2'), 16 | (b'int32', 'i4'), 17 | (b'int', 'i4'), 18 | (b'uint32', 'u4'), 19 | (b'uint', 'u4'), 20 | (b'float32', 'f4'), 21 | (b'float', 'f4'), 22 | (b'float64', 'f8'), 23 | (b'double', 'f8') 24 | ]) 25 | 26 | # Numpy reader format 27 | valid_formats = {'ascii': '', 'binary_big_endian': '>', 28 | 'binary_little_endian': '<'} 29 | 30 | 31 | # ---------------------------------------------------------------------------------------------------------------------- 32 | # 33 | # Functions 34 | # \***************/ 35 | # 36 | 37 | 38 | def parse_header(plyfile, ext): 39 | # Variables 40 | line = [] 41 | properties = [] 42 | num_points = None 43 | 44 | while b'end_header' not in line and line != b'': 45 | line = plyfile.readline() 46 | 47 | if b'element' in line: 48 | line = line.split() 49 | num_points = int(line[2]) 50 | 51 | elif b'property' in line: 52 | line = line.split() 53 | properties.append((line[2].decode(), ext + ply_dtypes[line[1]])) 54 | 55 | return num_points, properties 56 | 57 | 58 | def parse_mesh_header(plyfile, ext): 59 | # Variables 60 | line = [] 61 | vertex_properties = [] 62 | num_points = None 63 | num_faces = None 64 | current_element = None 65 | 66 | 67 | while b'end_header' not in line and line != b'': 68 | line = plyfile.readline() 69 | 70 | # Find point element 71 | if b'element vertex' in line: 72 | current_element = 'vertex' 73 | line = line.split() 74 | num_points = int(line[2]) 75 | 76 | elif b'element face' in line: 77 | current_element = 'face' 78 | line = line.split() 79 | num_faces = int(line[2]) 80 | 81 | elif b'property' in line: 82 | if current_element == 'vertex': 83 | line = line.split() 84 | vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]])) 85 | elif current_element == 'vertex': 86 | if not line.startswith('property list uchar int'): 87 | raise ValueError('Unsupported faces property : ' + line) 88 | 89 | return num_points, num_faces, vertex_properties 90 | 91 | 92 | def read_ply(filename, triangular_mesh=False): 93 | """ 94 | Read ".ply" files 95 | 96 | Parameters 97 | ---------- 98 | filename : string 99 | the name of the file to read. 
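triangular_mesh : bool
    if True, parse the file as a triangular mesh and return [vertex_data, faces]
    instead of a single structured array.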
100 | 101 | Returns 102 | ------- 103 | result : array 104 | data stored in the file 105 | 106 | Examples 107 | -------- 108 | Store data in file 109 | 110 | >>> points = np.random.rand(5, 3) 111 | >>> values = np.random.randint(2, size=10) 112 | >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values']) 113 | 114 | Read the file 115 | 116 | >>> data = read_ply('example.ply') 117 | >>> values = data['values'] 118 | array([0, 0, 1, 1, 0]) 119 | 120 | >>> points = np.vstack((data['x'], data['y'], data['z'])).T 121 | array([[ 0.466 0.595 0.324] 122 | [ 0.538 0.407 0.654] 123 | [ 0.850 0.018 0.988] 124 | [ 0.395 0.394 0.363] 125 | [ 0.873 0.996 0.092]]) 126 | 127 | """ 128 | 129 | with open(filename, 'rb') as plyfile: 130 | 131 | 132 | # Check if the file start with ply 133 | if b'ply' not in plyfile.readline(): 134 | raise ValueError('The file does not start whith the word ply') 135 | 136 | # get binary_little/big or ascii 137 | fmt = plyfile.readline().split()[1].decode() 138 | if fmt == "ascii": 139 | raise ValueError('The file is not binary') 140 | 141 | # get extension for building the numpy dtypes 142 | ext = valid_formats[fmt] 143 | 144 | # PointCloud reader vs mesh reader 145 | if triangular_mesh: 146 | 147 | # Parse header 148 | num_points, num_faces, properties = parse_mesh_header(plyfile, ext) 149 | 150 | # Get point data 151 | vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points) 152 | 153 | # Get face data 154 | face_properties = [('k', ext + 'u1'), 155 | ('v1', ext + 'i4'), 156 | ('v2', ext + 'i4'), 157 | ('v3', ext + 'i4')] 158 | faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces) 159 | 160 | # Return vertex data and concatenated faces 161 | faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T 162 | data = [vertex_data, faces] 163 | 164 | else: 165 | 166 | # Parse header 167 | num_points, properties = parse_header(plyfile, ext) 168 | 169 | # Get data 170 | data = np.fromfile(plyfile, dtype=properties, count=num_points) 171 | 172 | return data 173 | 174 | 175 | def header_properties(field_list, field_names): 176 | 177 | # List of lines to write 178 | lines = [] 179 | 180 | # First line describing element vertex 181 | lines.append('element vertex %d' % field_list[0].shape[0]) 182 | 183 | # Properties lines 184 | i = 0 185 | for fields in field_list: 186 | for field in fields.T: 187 | lines.append('property %s %s' % (field.dtype.name, field_names[i])) 188 | i += 1 189 | 190 | return lines 191 | 192 | 193 | def write_ply(filename, field_list, field_names, triangular_faces=None): 194 | """ 195 | Write ".ply" files 196 | 197 | Parameters 198 | ---------- 199 | filename : string 200 | the name of the file to which the data is saved. A '.ply' extension will be appended to the 201 | file name if it does no already have one. 202 | 203 | field_list : list, tuple, numpy array 204 | the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a 205 | tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered 206 | as one field. 207 | 208 | field_names : list 209 | the name of each fields as a list of strings. Has to be the same length as the number of 210 | fields. 
211 | 
212 |     Examples
213 |     --------
214 |     >>> points = np.random.rand(10, 3)
215 |     >>> write_ply('example1.ply', points, ['x', 'y', 'z'])
216 | 
217 |     >>> values = np.random.randint(2, size=10)
218 |     >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])
219 | 
220 |     >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)
221 |     >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
222 |     >>> write_ply('example3.ply', [points, colors, values], field_names)
223 | 
224 |     """
225 | 
226 |     # Format list input to the right form
227 |     field_list = list(field_list) if (type(field_list) == list or type(field_list) == tuple) else list((field_list,))
228 |     for i, field in enumerate(field_list):
229 |         if field.ndim < 2:
230 |             field_list[i] = field.reshape(-1, 1)
231 |         if field.ndim > 2:
232 |             print('fields have more than 2 dimensions')
233 |             return False
234 | 
235 |     # check all fields have the same number of rows
236 |     n_points = [field.shape[0] for field in field_list]
237 |     if not np.all(np.equal(n_points, n_points[0])):
238 |         print('wrong field dimensions')
239 |         return False
240 | 
241 |     # Check if field_names and field_list have the same number of columns
242 |     n_fields = np.sum([field.shape[1] for field in field_list])
243 |     if (n_fields != len(field_names)):
244 |         print('wrong number of field names')
245 |         return False
246 | 
247 |     # Add extension if not there
248 |     if not filename.endswith('.ply'):
249 |         filename += '.ply'
250 | 
251 |     # open in text mode to write the header
252 |     with open(filename, 'w') as plyfile:
253 | 
254 |         # First magical word
255 |         header = ['ply']
256 | 
257 |         # Encoding format
258 |         header.append('format binary_' + sys.byteorder + '_endian 1.0')
259 | 
260 |         # Points properties description
261 |         header.extend(header_properties(field_list, field_names))
262 | 
263 |         # Add faces if needed
264 |         if triangular_faces is not None:
265 |             header.append('element face {:d}'.format(triangular_faces.shape[0]))
266 |             header.append('property list uchar int vertex_indices')
267 | 
268 |         # End of header
269 |         header.append('end_header')
270 | 
271 |         # Write all lines
272 |         for line in header:
273 |             plyfile.write("%s\n" % line)
274 | 
275 |     # open in binary/append to use tofile
276 |     with open(filename, 'ab') as plyfile:
277 | 
278 |         # Create a structured array
279 |         i = 0
280 |         type_list = []
281 |         for fields in field_list:
282 |             for field in fields.T:
283 |                 type_list += [(field_names[i], field.dtype.str)]
284 |                 i += 1
285 |         data = np.empty(field_list[0].shape[0], dtype=type_list)
286 |         i = 0
287 |         for fields in field_list:
288 |             for field in fields.T:
289 |                 data[field_names[i]] = field
290 |                 i += 1
291 | 
292 |         data.tofile(plyfile)
293 | 
294 |         if triangular_faces is not None:
295 |             triangular_faces = triangular_faces.astype(np.int32)
296 |             type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]
297 |             data = np.empty(triangular_faces.shape[0], dtype=type_list)
298 |             data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)
299 |             data['0'] = triangular_faces[:, 0]
300 |             data['1'] = triangular_faces[:, 1]
301 |             data['2'] = triangular_faces[:, 2]
302 |             data.tofile(plyfile)
303 | 
304 |     return True
305 | 
306 | 
307 | def describe_element(name, df):
308 |     """ Takes the columns of the dataframe and builds a ply-like description
309 | 
310 |     Parameters
311 |     ----------
312 |     name: str
313 |     df: pandas DataFrame
314 | 
315 |     Returns
316 |     -------
317 |     element: list[str]
318 |     """
319 |     property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}
320
| element = ['element ' + name + ' ' + str(len(df))] 321 | 322 | if name == 'face': 323 | element.append("property list uchar int points_indices") 324 | 325 | else: 326 | for i in range(len(df.columns)): 327 | # get first letter of dtype to infer format 328 | f = property_formats[str(df.dtypes[i])[0]] 329 | element.append('property ' + f + ' ' + df.columns.values[i]) 330 | 331 | return element -------------------------------------------------------------------------------- /util/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | class Compose(object): 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, data, label): 11 | for t in self.transforms: 12 | data, label = t(data, label) 13 | return data, label 14 | 15 | 16 | class ToTensor(object): 17 | def __call__(self, data, label): 18 | data = torch.from_numpy(data) 19 | if not isinstance(data, torch.FloatTensor): 20 | data = data.float() 21 | label = torch.from_numpy(label) 22 | if not isinstance(label, torch.LongTensor): 23 | label = label.long() 24 | return data, label 25 | 26 | 27 | class RandomRotate(object): 28 | def __init__(self, rotate_angle=None, along_z=False): 29 | self.rotate_angle = rotate_angle 30 | self.along_z = along_z 31 | 32 | def __call__(self, data, label): 33 | if self.rotate_angle is None: 34 | rotate_angle = np.random.uniform() * 2 * np.pi 35 | else: 36 | rotate_angle = self.rotate_angle 37 | cosval, sinval = np.cos(rotate_angle), np.sin(rotate_angle) 38 | if self.along_z: 39 | rotation_matrix = np.array([[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]) 40 | else: 41 | rotation_matrix = np.array([[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]) 42 | data[:, 0:3] = np.dot(data[:, 0:3], rotation_matrix) 43 | if data.shape[1] > 3: # use normal 44 | data[:, 3:6] = np.dot(data[:, 3:6], rotation_matrix) 45 | return data, label 46 | 47 | 48 | class RandomRotatePerturbation(object): 49 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 50 | self.angle_sigma = angle_sigma 51 | self.angle_clip = angle_clip 52 | 53 | def __call__(self, data, label): 54 | angles = np.clip(self.angle_sigma*np.random.randn(3), -self.angle_clip, self.angle_clip) 55 | Rx = np.array([[1, 0, 0], 56 | [0, np.cos(angles[0]), -np.sin(angles[0])], 57 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 58 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 59 | [0, 1, 0], 60 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 61 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 62 | [np.sin(angles[2]), np.cos(angles[2]), 0], 63 | [0, 0, 1]]) 64 | R = np.dot(Rz, np.dot(Ry, Rx)) 65 | data[:, 0:3] = np.dot(data[:, 0:3], R) 66 | if data.shape[1] > 3: # use normal 67 | data[:, 3:6] = np.dot(data[:, 3:6], R) 68 | return data, label 69 | 70 | 71 | class RandomScale(object): 72 | def __init__(self, scale_low=0.8, scale_high=1.25): 73 | self.scale_low = scale_low 74 | self.scale_high = scale_high 75 | 76 | def __call__(self, data, label): 77 | scale = np.random.uniform(self.scale_low, self.scale_high) 78 | data[:, 0:3] *= scale 79 | return data, label 80 | 81 | 82 | class RandomShift(object): 83 | def __init__(self, shift_range=0.1): 84 | self.shift_range = shift_range 85 | 86 | def __call__(self, data, label): 87 | shift = np.random.uniform(-self.shift_range, self.shift_range, 3) 88 | data[:, 0:3] += shift 89 | return data, label 90 | 91 | 92 | class RandomJitter(object): 93 | def __init__(self, 
sigma=0.01, clip=0.05):
94 |         self.sigma = sigma
95 |         self.clip = clip
96 | 
97 |     def __call__(self, data, label):
98 |         assert (self.clip > 0)
99 |         jitter = np.clip(self.sigma * np.random.randn(data.shape[0], 3), -1 * self.clip, self.clip)
100 |         data[:, 0:3] += jitter
101 |         return data, label
102 | 
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | 
5 | import torch
6 | from torch import nn
7 | from torch.nn.modules.conv import _ConvNd
8 | from torch.nn.modules.batchnorm import _BatchNorm
9 | import torch.nn.init as initer
10 | 
11 | 
12 | class AverageMeter(object):
13 |     """Computes and stores the average and current value"""
14 |     def __init__(self):
15 |         self.reset()
16 | 
17 |     def reset(self):
18 |         self.val = 0
19 |         self.avg = 0
20 |         self.sum = 0
21 |         self.count = 0
22 | 
23 |     def update(self, val, n=1):
24 |         self.val = val
25 |         self.sum += val * n
26 |         self.count += n
27 |         self.avg = self.sum / self.count
28 | 
29 | 
30 | def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier=0.1, clip=1e-6):
31 |     """Sets the learning rate to the base LR decayed by `multiplier` every `step_epoch` epochs, clipped at `clip`"""
32 |     lr = max(base_lr * (multiplier ** (epoch // step_epoch)), clip)
33 |     for param_group in optimizer.param_groups:
34 |         param_group['lr'] = lr
35 | 
36 | 
37 | def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9):
38 |     """poly learning rate policy"""
39 |     lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
40 |     for param_group in optimizer.param_groups:
41 |         param_group['lr'] = lr
42 | 
43 | 
44 | def intersectionAndUnion(output, target, K, ignore_index=255):
45 |     # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
46 |     assert (output.ndim in [1, 2, 3])
47 |     assert output.shape == target.shape
48 |     output = output.reshape(output.size).copy()
49 |     target = target.reshape(target.size)
50 |     output[np.where(target == ignore_index)[0]] = ignore_index
51 |     intersection = output[np.where(output == target)[0]]
52 |     area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1))
53 |     area_output, _ = np.histogram(output, bins=np.arange(K+1))
54 |     area_target, _ = np.histogram(target, bins=np.arange(K+1))
55 |     area_union = area_output + area_target - area_intersection
56 |     return area_intersection, area_union, area_target
57 | 
58 | 
59 | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
60 |     # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
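    # Same contract as intersectionAndUnion above, but on torch tensors kept on
    # the GPU. A minimal usage sketch (the names `pred` and `tgt` below are
    # illustrative only):
    #   inter, union, _ = intersectionAndUnionGPU(pred, tgt, K=13)
    #   mIoU = (inter / (union + 1e-10)).mean().item()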
61 |     assert (output.dim() in [1, 2, 3])
62 |     assert output.shape == target.shape
63 |     output = output.view(-1)
64 |     target = target.view(-1)
65 |     output[target == ignore_index] = ignore_index
66 |     intersection = output[output == target]
67 |     # https://github.com/pytorch/pytorch/issues/1382
68 |     area_intersection = torch.histc(intersection.float().cpu(), bins=K, min=0, max=K-1)
69 |     area_output = torch.histc(output.float().cpu(), bins=K, min=0, max=K-1)
70 |     area_target = torch.histc(target.float().cpu(), bins=K, min=0, max=K-1)
71 |     area_union = area_output + area_target - area_intersection
72 |     return area_intersection.cuda(), area_union.cuda(), area_target.cuda()
73 | 
74 | 
75 | def check_mkdir(dir_name):
76 |     if not os.path.exists(dir_name):
77 |         os.mkdir(dir_name)
78 | 
79 | 
80 | def check_makedirs(dir_name):
81 |     if not os.path.exists(dir_name):
82 |         os.makedirs(dir_name)
83 | 
84 | 
85 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
86 |     """
87 |     :param model: PyTorch model (an nn.Module)
88 |     :param conv: 'kaiming' or 'xavier'
89 |     :param batchnorm: 'normal' or 'constant'
90 |     :param linear: 'kaiming' or 'xavier'
91 |     :param lstm: 'kaiming' or 'xavier'
92 |     """
93 |     for m in model.modules():
94 |         if isinstance(m, _ConvNd):
95 |             if conv == 'kaiming':
96 |                 initer.kaiming_normal_(m.weight)
97 |             elif conv == 'xavier':
98 |                 initer.xavier_normal_(m.weight)
99 |             else:
100 |                 raise ValueError("init type of conv error.\n")
101 |             if m.bias is not None:
102 |                 initer.constant_(m.bias, 0)
103 | 
104 |         elif isinstance(m, _BatchNorm):
105 |             if batchnorm == 'normal':
106 |                 initer.normal_(m.weight, 1.0, 0.02)
107 |             elif batchnorm == 'constant':
108 |                 initer.constant_(m.weight, 1.0)
109 |             else:
110 |                 raise ValueError("init type of batchnorm error.\n")
111 |             initer.constant_(m.bias, 0.0)
112 | 
113 |         elif isinstance(m, nn.Linear):
114 |             if linear == 'kaiming':
115 |                 initer.kaiming_normal_(m.weight)
116 |             elif linear == 'xavier':
117 |                 initer.xavier_normal_(m.weight)
118 |             else:
119 |                 raise ValueError("init type of linear error.\n")
120 |             if m.bias is not None:
121 |                 initer.constant_(m.bias, 0)
122 | 
123 |         elif isinstance(m, nn.LSTM):
124 |             for name, param in m.named_parameters():
125 |                 if 'weight' in name:
126 |                     if lstm == 'kaiming':
127 |                         initer.kaiming_normal_(param)
128 |                     elif lstm == 'xavier':
129 |                         initer.xavier_normal_(param)
130 |                     else:
131 |                         raise ValueError("init type of lstm error.\n")
132 |                 elif 'bias' in name:
133 |                     initer.constant_(param, 0)
134 | 
135 | 
136 | def convert_to_syncbn(model):
137 |     def recursive_set(cur_module, name, module):
138 |         if len(name.split('.')) > 1:
139 |             recursive_set(getattr(cur_module, name[:name.find('.')]), name[name.find('.')+1:], module)
140 |         else:
141 |             setattr(cur_module, name, module)
142 |     from lib.sync_bn import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
143 |     for name, m in model.named_modules():
144 |         if isinstance(m, nn.BatchNorm1d):
145 |             recursive_set(model, name, SynchronizedBatchNorm1d(m.num_features, m.eps, m.momentum, m.affine))
146 |         elif isinstance(m, nn.BatchNorm2d):
147 |             recursive_set(model, name, SynchronizedBatchNorm2d(m.num_features, m.eps, m.momentum, m.affine))
148 |         elif isinstance(m, nn.BatchNorm3d):
149 |             recursive_set(model, name, SynchronizedBatchNorm3d(m.num_features, m.eps, m.momentum, m.affine))
150 | 
151 | 
152 | def colorize(gray, palette):
153 |     # gray: 2-D numpy array of labels; palette: flat list of 3*N RGB values
154 |     color = 
Image.fromarray(gray.astype(np.uint8)).convert('P') 155 | color.putpalette(palette) 156 | return color 157 | --------------------------------------------------------------------------------
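A minimal end-to-end sketch of how `check_makedirs`, `write_ply`, and `colorize` compose when exporting predictions (the array names, sizes, and output paths below are illustrative only, and the imports assume the repo root is on `PYTHONPATH`):

```
import numpy as np
from util.ply import write_ply
from util.util import check_makedirs, colorize

# Hypothetical per-point predictions for one S3DIS block: xyz plus a label.
points = np.random.rand(4096, 3).astype(np.float32)
preds = np.random.randint(13, size=4096).astype(np.int32)

check_makedirs('output')
# Each 1-D array / 2-D column becomes one named PLY field; '.ply' is appended.
write_ply('output/room_pred', [points, preds], ['x', 'y', 'z', 'pred'])

# Colorize a 2-D label map with a flat 3*13-entry RGB palette.
palette = np.random.randint(0, 256, 3 * 13).tolist()
label_map = np.random.randint(13, size=(32, 32))
colorize(label_map, palette).save('output/labels.png')
```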