├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── lib ├── pointnet2_modules.py ├── pointnet2_utils.py ├── pytorch_utils.py ├── setup.py └── src │ ├── ball_query.cpp │ ├── ball_query_gpu.cu │ ├── ball_query_gpu.h │ ├── cuda_utils.h │ ├── group_points.cpp │ ├── group_points_gpu.cu │ ├── group_points_gpu.h │ ├── interpolate.cpp │ ├── interpolate_gpu.cu │ ├── interpolate_gpu.h │ ├── pointnet2_api.cpp │ ├── sampling.cpp │ ├── sampling_gpu.cu │ └── sampling_gpu.h ├── main.py ├── model.py ├── pretrained_model └── model.best.t7 └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | checkpoints/ 4 | build/ 5 | dist/ 6 | pointnet2.egg-info/ 7 | experiment/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hang Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # flownet3d_pytorch 2 | The pytorch implementation of [flownet3d](https://github.com/xingyul/flownet3d) based on [WangYueFt/dcp](https://github.com/WangYueFt/dcp), [sshaoshuai/Pointnet2.PyTorch](https://github.com/sshaoshuai/Pointnet2.PyTorch) and [yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch) 3 | 4 | ## Installation 5 | 6 | ### Requirements 7 | PyTorch>=1.0: https://pytorch.org 8 | 9 | scipy>=1.2 10 | 11 | numpy 12 | 13 | h5py 14 | 15 | tqdm 16 | 17 | ### Install 18 | Install this library by running the following command: 19 | ```bash 20 | cd lib 21 | python setup.py install 22 | cd ../ 23 | ``` 24 | ## Training 25 | 26 | The processed Flyingthings3d data is provided [here](https://drive.google.com/file/d/1CMaxdt-Tg1Wct8v8eGNwuT7qRSIyJPY-/view?usp=sharing) for download (total size ~11GB). 27 | 28 | Then run the following command to train: 29 | ```bash 30 | python main.py --exp_name=flownet3d --dataset_path=xx/yy 31 | ``` 32 | where xx/yy is the dataset path 33 | 34 | ## Performance comparison 35 | All of the following experiments were tested on a TITAN RTX 36 | 37 | 1. GPU memory usage: 38 | 39 | batch size|flownet3d(GB)|flownet3d_pytorch(GB) 40 | ---|---|--- 41 | 16|16.9|9.1 42 | 43 | 2. 
Training time per epoch on Flyingthings3d dataset: 44 | 45 | batch size|flownet3d(min)|flownet3d_pytorch(min) 46 | ---|---|--- 47 | 16|6.7|3.4 48 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import os 6 | import sys 7 | import glob 8 | import h5py 9 | import numpy as np 10 | from scipy.spatial.transform import Rotation 11 | from torch.utils.data import Dataset 12 | from sklearn.neighbors import NearestNeighbors 13 | from scipy.spatial.distance import minkowski 14 | 15 | 16 | # Parts of the code are adapted from: https://github.com/charlesq34/pointnet 17 | 18 | def download(): 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | DATA_DIR = os.path.join(BASE_DIR, 'data') 21 | if not os.path.exists(DATA_DIR): 22 | os.mkdir(DATA_DIR) 23 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 24 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 25 | zipfile = os.path.basename(www) 26 | os.system('wget %s; unzip %s' % (www, zipfile)) 27 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 28 | os.system('rm %s' % (zipfile)) 29 | 30 | 31 | def load_data(partition): 32 | # download() 33 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 34 | DATA_DIR = os.path.join(BASE_DIR, '../../datasets') 35 | all_data = [] 36 | all_label = [] 37 | for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)): 38 | f = h5py.File(h5_name, 'r') 39 | data = f['data'][:].astype('float32') 40 | label = f['label'][:].astype('int64') 41 | f.close() 42 | all_data.append(data) 43 | all_label.append(label) 44 | all_data = np.concatenate(all_data, axis=0) 45 | all_label = np.concatenate(all_label, axis=0) 46 | return all_data, all_label 47 | 48 |
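For orientation, a minimal sketch of what `load_data` returns, assuming the `modelnet40_ply_hdf5_2048` archive sits under the `DATA_DIR` resolved above (note that it points outside the repo); the shapes are the usual ones for that release and are listed as expectations, not guarantees:

```python
# Hypothetical smoke test for load_data(); run from the repo root.
from data import load_data

data, label = load_data('train')
print(data.shape, data.dtype)    # expected: (9840, 2048, 3) float32
print(label.shape, label.dtype)  # expected: (9840, 1) int64 (squeezed later by ModelNet40)
```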
49 | def translate_pointcloud(pointcloud): 50 | xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3]) 51 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 52 | 53 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 54 | return translated_pointcloud 55 | 56 | 57 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.05): 58 | N, C = pointcloud.shape 59 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip) 60 | return pointcloud 61 | 62 | def farthest_subsample_points(pointcloud1, pointcloud2, num_subsampled_points=768): 63 | pointcloud1 = pointcloud1.T 64 | pointcloud2 = pointcloud2.T 65 | num_points = pointcloud1.shape[0] 66 | nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto', 67 | metric=lambda x, y: minkowski(x, y), n_jobs=1).fit(pointcloud1) 68 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 69 | idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,)) 70 | nbrs2 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto', 71 | metric=lambda x, y: minkowski(x, y), n_jobs=1).fit(pointcloud2) 72 | random_p2 = random_p1 #np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 2, -2]) 73 | idx2 = nbrs2.kneighbors(random_p2, return_distance=False).reshape((num_subsampled_points,)) 74 | return pointcloud1[idx1, :].T, pointcloud2[idx2, :].T 75 | 76 |
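`ModelNet40.__getitem__` below builds the same rotation twice: explicitly as `R_ab = Rx·Ry·Rz` and via scipy's `Rotation.from_euler('zyx', [anglez, angley, anglex])` (lowercase axes denote extrinsic rotations, which compose to `Rx(anglex) @ Ry(angley) @ Rz(anglez)`). A quick self-check that the two agree, assuming scipy >= 1.4 for `as_matrix()` (older releases call it `as_dcm()`):

```python
import numpy as np
from scipy.spatial.transform import Rotation

anglex, angley, anglez = 0.1, 0.2, 0.3
cosx, sinx = np.cos(anglex), np.sin(anglex)
cosy, siny = np.cos(angley), np.sin(angley)
cosz, sinz = np.cos(anglez), np.sin(anglez)
Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]])
Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]])
Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]])

R_ab = Rx.dot(Ry).dot(Rz)
R_scipy = Rotation.from_euler('zyx', [anglez, angley, anglex]).as_matrix()
assert np.allclose(R_ab, R_scipy)  # extrinsic z-y-x == Rx @ Ry @ Rz
```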
77 | class ModelNet40(Dataset): 78 | def __init__(self, num_points, num_subsampled_points=768, partition='train', gaussian_noise=False, unseen=False, factor=4): 79 | self.data, self.label = load_data(partition) 80 | self.num_points = num_points 81 | self.partition = partition 82 | self.gaussian_noise = gaussian_noise 83 | self.unseen = unseen 84 | self.label = self.label.squeeze() 85 | self.factor = factor 86 | self.num_subsampled_points = num_subsampled_points 87 | if num_points != num_subsampled_points: 88 | self.subsampled = True 89 | else: 90 | self.subsampled = False 91 | if self.unseen: 92 | ######## simulate testing on unseen categories: train on the first 20 categories, test on the last 20 93 | if self.partition == 'test': 94 | self.data = self.data[self.label>=20] 95 | self.label = self.label[self.label>=20] 96 | elif self.partition == 'train': 97 | self.data = self.data[self.label<20] 98 | self.label = self.label[self.label<20] 99 | 100 | def __getitem__(self, item): 101 | pointcloud = self.data[item][:self.num_points] 102 | # if self.gaussian_noise: 103 | # pointcloud = jitter_pointcloud(pointcloud) 104 | if self.partition != 'train': 105 | np.random.seed(item) 106 | anglex = np.random.uniform() * np.pi / self.factor 107 | angley = np.random.uniform() * np.pi / self.factor 108 | anglez = np.random.uniform() * np.pi / self.factor 109 | 110 | cosx = np.cos(anglex) 111 | cosy = np.cos(angley) 112 | cosz = np.cos(anglez) 113 | sinx = np.sin(anglex) 114 | siny = np.sin(angley) 115 | sinz = np.sin(anglez) 116 | Rx = np.array([[1, 0, 0], 117 | [0, cosx, -sinx], 118 | [0, sinx, cosx]]) 119 | Ry = np.array([[cosy, 0, siny], 120 | [0, 1, 0], 121 | [-siny, 0, cosy]]) 122 | Rz = np.array([[cosz, -sinz, 0], 123 | [sinz, cosz, 0], 124 | [0, 0, 1]]) 125 | # build the rotation matrix 126 | R_ab = Rx.dot(Ry).dot(Rz) 127 | R_ba = R_ab.T 128 | # build the translation vector 129 | translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5), 130 | np.random.uniform(-0.5, 0.5)]) 131 | translation_ba = -R_ba.dot(translation_ab) 132 | 133 | pointcloud1 = pointcloud.T 134 | rotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex]) 135 | pointcloud2 = rotation_ab.apply(pointcloud1.T).T + np.expand_dims(translation_ab, axis=1) 136 | 137 | euler_ab = np.asarray([anglez, angley, anglex]) 138 | euler_ba = -euler_ab[::-1] 139 | 140 | pointcloud1 = np.random.permutation(pointcloud1.T).T 141 | pointcloud2 = np.random.permutation(pointcloud2.T).T 142 | 143 | if self.gaussian_noise: 144 | pointcloud1 = jitter_pointcloud(pointcloud1) 145 | pointcloud2 = jitter_pointcloud(pointcloud2) 146 | 147 | if self.subsampled: 148 | pointcloud1, pointcloud2 = farthest_subsample_points(pointcloud1, pointcloud2, 149 | num_subsampled_points=self.num_subsampled_points) 150 | 151 | return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \ 152 | translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \ 153 | euler_ab.astype('float32'), euler_ba.astype('float32') 154 | 155 | def __len__(self): 156 | return self.data.shape[0] 157 | 158 | 159 | class SceneflowDataset(Dataset): 160 | def __init__(self, npoints=2048, root='data_preprocessing/data_processed_maxcut_35_both_mask_20k_2k', partition='train'): 161 | self.npoints = npoints 162 | self.partition = partition 163 | self.root = root 164 | if self.partition=='train': 165 | self.datapath = glob.glob(os.path.join(self.root, 'TRAIN*.npz')) 166 | else: 167 | self.datapath = glob.glob(os.path.join(self.root, 'TEST*.npz')) 168 | self.cache = {} 169 | self.cache_size = 30000 170 | 171 | ###### deal with one bad datapoint with nan value 172 | self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d] 173 | ###### 174 | print(self.partition, ':', len(self.datapath)) 175 | 176 | def __getitem__(self, index): 177 | if index in self.cache: 178 | pos1, pos2, color1, color2, flow, mask1 = self.cache[index] 179 | else: 180 | fn = self.datapath[index] 181 | with open(fn, 'rb') as fp: 182 | data = np.load(fp) 183 | pos1 = data['points1'].astype('float32') 184 | pos2 = data['points2'].astype('float32') 185 | color1 = data['color1'].astype('float32') 186 | color2 = data['color2'].astype('float32') 187 | flow = data['flow'].astype('float32') 188 | mask1 = data['valid_mask1'] 189 | 190 | if len(self.cache) < self.cache_size: 191 | self.cache[index] = (pos1, pos2, color1, color2, flow, mask1) 192 | 193 | if self.partition == 'train': 194 | n1 = pos1.shape[0] 195 | sample_idx1 = np.random.choice(n1, self.npoints, replace=False) 196 | n2 = pos2.shape[0] 197 | sample_idx2 = np.random.choice(n2, self.npoints, replace=False) 198 | 199 | pos1 = pos1[sample_idx1, :] 200 | pos2 = pos2[sample_idx2, :] 201 | color1 = color1[sample_idx1, :] 202 | color2 = color2[sample_idx2, :] 203 | flow = flow[sample_idx1, :] 204 | mask1 = mask1[sample_idx1] 205 | else: 206 | pos1 = pos1[:self.npoints, :] 207 | pos2 = pos2[:self.npoints, :] 208 | color1 = color1[:self.npoints, :] 209 | color2 = color2[:self.npoints, :] 210 | flow = flow[:self.npoints, :] 211 | mask1 = mask1[:self.npoints] 212 | 213 | pos1_center = np.mean(pos1, 0) 214 | pos1 -= pos1_center 215 | pos2 -= pos1_center 216 | 217 | return pos1, pos2, color1, color2, flow, mask1 218 | 219 | def __len__(self): 220 | return len(self.datapath) 221 | 222 | 223 | 224 | 225 | if __name__ == '__main__': 226 | train = ModelNet40(1024) 227 | test = ModelNet40(1024, partition='test') 228 | for data in train: 229 | print(data[0].shape) 230 | break 231 | --------------------------------------------------------------------------------
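Before the CUDA-backed modules, a hedged sketch of how `SceneflowDataset` is typically consumed for the training run described in the README; `'xx/yy'` is the same placeholder path the README uses for the unpacked FlyingThings3D archive:

```python
# Usage sketch only; replace 'xx/yy' with the real dataset path.
from torch.utils.data import DataLoader
from data import SceneflowDataset

train_set = SceneflowDataset(npoints=2048, root='xx/yy', partition='train')
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, drop_last=True)
for pos1, pos2, color1, color2, flow, mask1 in train_loader:
    # pos1, pos2, color1, color2, flow: (B, 2048, 3); mask1: (B, 2048)
    break
```

/lib/pointnet2_modules.py: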
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from . import pointnet2_utils 6 | from . import pytorch_utils as pt_utils 7 | from typing import List 8 | 9 | 10 | class _PointnetSAModuleBase(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.npoint = None 15 | self.groupers = None 16 | self.mlps = None 17 | self.pool_method = 'max_pool' 18 | 19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 20 | """ 21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 22 | :param features: (B, N, C) tensor of the descriptors of the features 23 | :param new_xyz: 24 | :return: 25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 26 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 27 | """ 28 | new_features_list = [] 29 | 30 | xyz_flipped = xyz.transpose(1, 2).contiguous() 31 | if new_xyz is None: 32 | new_xyz = pointnet2_utils.gather_operation( 33 | xyz_flipped, 34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint) 35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 36 | 37 | for i in range(len(self.groupers)): 38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 39 | 40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 41 | if self.pool_method == 'max_pool': 42 | new_features = F.max_pool2d( 43 | new_features, kernel_size=[1, new_features.size(3)] 44 | ) # (B, mlp[-1], npoint, 1) 45 | elif self.pool_method == 'avg_pool': 46 | new_features = F.avg_pool2d( 47 | new_features, kernel_size=[1, new_features.size(3)] 48 | ) # (B, mlp[-1], npoint, 1) 49 | else: 50 | raise NotImplementedError 51 | 52 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 53 | new_features_list.append(new_features) 54 | 55 | return new_xyz, torch.cat(new_features_list, dim=1) 56 | 57 | 58 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 59 | """Pointnet set abstraction layer with multiscale grouping""" 60 | 61 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, 62 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 63 | """ 64 | :param npoint: int 65 | :param radii: list of float, list of radii to group with 66 | :param nsamples: list of int, number of samples in each ball query 67 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale 68 | :param bn: whether to use batchnorm 69 | :param use_xyz: 70 | :param pool_method: max_pool / avg_pool 71 | :param instance_norm: whether to use instance_norm 72 | """ 73 | super().__init__() 74 | 75 | assert len(radii) == len(nsamples) == len(mlps) 76 | 77 | self.npoint = npoint 78 | self.groupers = nn.ModuleList() 79 | self.mlps = nn.ModuleList() 80 | for i in range(len(radii)): 81 | radius = radii[i] 82 | nsample = nsamples[i] 83 | self.groupers.append( 84 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 85 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 86 | ) 87 | mlp_spec = mlps[i] 88 | if use_xyz: 89 | mlp_spec[0] += 3 90 | 91 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) 92 | self.pool_method = pool_method 93 | 94 |
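The multiscale constructor above maps directly onto usage. A sketch, assuming the compiled `pointnet2_cuda` extension and a CUDA device are available and the script runs from the repo root so `lib` resolves as a package; with `use_xyz=True` each `mlps[i][0]` is the per-point feature width before the 3 xyz channels are added, so xyz-only input uses a leading 0:

```python
# Hedged usage sketch for PointnetSAModuleMSG (requires the built extension + CUDA).
import torch
from lib.pointnet2_modules import PointnetSAModuleMSG

sa = PointnetSAModuleMSG(npoint=512, radii=[0.1, 0.2], nsamples=[16, 32],
                         mlps=[[0, 32, 32], [0, 64, 64]]).cuda()
xyz = torch.rand(2, 1024, 3).cuda()
new_xyz, new_features = sa(xyz)          # features=None -> groups xyz only
# new_xyz: (2, 512, 3); new_features: (2, 32 + 64, 512), one block per radius
```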
95 | class PointnetSAModule(PointnetSAModuleMSG): 96 | """Pointnet set abstraction layer""" 97 | 98 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, 99 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 100 | """ 101 | :param mlp: list of int, spec of the pointnet before the global max_pool 102 | :param npoint: int, number of features 103 | :param radius: float, radius of ball 104 | :param nsample: int, number of samples in the ball query 105 | :param bn: whether to use batchnorm 106 | :param use_xyz: 107 | :param pool_method: max_pool / avg_pool 108 | :param instance_norm: whether to use instance_norm 109 | """ 110 | super().__init__( 111 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, 112 | pool_method=pool_method, instance_norm=instance_norm 113 | ) 114 | 115 | 116 | class PointnetFPModule(nn.Module): 117 | r"""Propagates the features of one set to another""" 118 | 119 | def __init__(self, *, mlp: List[int], bn: bool = True): 120 | """ 121 | :param mlp: list of int 122 | :param bn: whether to use batchnorm 123 | """ 124 | super().__init__() 125 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 126 | 127 | def forward( 128 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 129 | ) -> torch.Tensor: 130 | """ 131 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 132 | :param known: (B, m, 3) tensor of the xyz positions of the known features 133 | :param unknow_feats: (B, C1, n) tensor of the features to be propagated to 134 | :param known_feats: (B, C2, m) tensor of features to be propagated 135 | :return: 136 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 137 | """ 138 | if known is not None: 139 | dist, idx = pointnet2_utils.three_nn(unknown, known) 140 | dist_recip = 1.0 / (dist + 1e-8) 141 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 142 | weight = dist_recip / norm 143 | 144 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) 145 | else: 146 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 147 | 148 | if unknow_feats is not None: 149 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 150 | else: 151 | new_features = interpolated_feats 152 | 153 | new_features = new_features.unsqueeze(-1) 154 | new_features = self.mlp(new_features) 155 | 156 | return new_features.squeeze(-1) 157 | 158 | 159 | if __name__ == "__main__": 160 | pass 161 | -------------------------------------------------------------------------------- /lib/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.autograd import Function 4 | import torch.nn as nn 5 | from typing import Tuple 6 | 7 | import pointnet2_cuda as pointnet2 8 | 9 | 10 | class FurthestPointSampling(Function): 11 | @staticmethod 12 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 13 | """ 14 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 15 | minimum distance 16 | :param ctx: 17 | :param xyz: (B, N, 3) where N > npoint 18 | :param npoint: int, number of features in the sampled set 19 | :return: 20 | output: (B, npoint) tensor containing the set 21 | """ 22 | assert xyz.is_contiguous() 23 | 24 | B, N, _ = xyz.size() 25 | output = torch.cuda.IntTensor(B, npoint) 26 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 27 |
28 | pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 29 | return output 30 | 31 | @staticmethod 32 | def backward(xyz, a=None): 33 | return None, None 34 | 35 | 36 | furthest_point_sample = FurthestPointSampling.apply 37 | 38 | 39 | class GatherOperation(Function): 40 | 41 | @staticmethod 42 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 43 | """ 44 | :param ctx: 45 | :param features: (B, C, N) 46 | :param idx: (B, npoint) index tensor of the features to gather 47 | :return: 48 | output: (B, C, npoint) 49 | """ 50 | assert features.is_contiguous() 51 | assert idx.is_contiguous() 52 | 53 | B, npoint = idx.size() 54 | _, C, N = features.size() 55 | output = torch.cuda.FloatTensor(B, C, npoint) 56 | 57 | pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output) 58 | 59 | ctx.for_backwards = (idx, C, N) 60 | return output 61 | 62 | @staticmethod 63 | def backward(ctx, grad_out): 64 | idx, C, N = ctx.for_backwards 65 | B, npoint = idx.size() 66 | 67 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 68 | grad_out_data = grad_out.data.contiguous() 69 | pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 70 | return grad_features, None 71 | 72 | 73 | gather_operation = GatherOperation.apply 74 | 75 | class KNN(Function): 76 | 77 | @staticmethod 78 | def forward(ctx, k: int, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Find the k nearest neighbors of unknown in known 81 | :param ctx: 82 | :param unknown: (B, N, 3) 83 | :param known: (B, M, 3) 84 | :return: 85 | dist: (B, N, k) l2 distance to the k nearest neighbors 86 | idx: (B, N, k) index of the k nearest neighbors 87 | """ 88 | assert unknown.is_contiguous() 89 | assert known.is_contiguous() 90 | 91 | B, N, _ = unknown.size() 92 | m = known.size(1) 93 | dist2 = torch.cuda.FloatTensor(B, N, k) 94 | idx = torch.cuda.IntTensor(B, N, k) 95 | 96 | pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx) 97 | return torch.sqrt(dist2), idx 98 | 99 | @staticmethod 100 | def backward(ctx, a=None, b=None): 101 | return None, None, None 102 | knn = KNN.apply 103 | 104 | class ThreeNN(Function): 105 | 106 | @staticmethod 107 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 108 | """ 109 | Find the three nearest neighbors of unknown in known 110 | :param ctx: 111 | :param unknown: (B, N, 3) 112 | :param known: (B, M, 3) 113 | :return: 114 | dist: (B, N, 3) l2 distance to the three nearest neighbors 115 | idx: (B, N, 3) index of 3 nearest neighbors 116 | """ 117 | assert unknown.is_contiguous() 118 | assert known.is_contiguous() 119 | 120 | B, N, _ = unknown.size() 121 | m = known.size(1) 122 | dist2 = torch.cuda.FloatTensor(B, N, 3) 123 | idx = torch.cuda.IntTensor(B, N, 3) 124 | 125 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) 126 | return torch.sqrt(dist2), idx 127 | 128 | @staticmethod 129 | def backward(ctx, a=None, b=None): 130 | return None, None 131 | 132 | 133 | three_nn = ThreeNN.apply 134 | 135 | 136 | class ThreeInterpolate(Function): 137 | 138 | @staticmethod 139 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 140 | """ 141 | Performs weighted linear interpolation on 3 features 142 | :param ctx: 143 | :param features: (B, C, M) Feature descriptors to be interpolated from 144 | :param idx: (B, n, 3) three nearest neighbors of the target 
features in features 145 | :param weight: (B, n, 3) weights 146 | :return: 147 | output: (B, C, N) tensor of the interpolated features 148 | """ 149 | assert features.is_contiguous() 150 | assert idx.is_contiguous() 151 | assert weight.is_contiguous() 152 | 153 | B, c, m = features.size() 154 | n = idx.size(1) 155 | ctx.three_interpolate_for_backward = (idx, weight, m) 156 | output = torch.cuda.FloatTensor(B, c, n) 157 | 158 | pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) 159 | return output 160 | 161 | @staticmethod 162 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 163 | """ 164 | :param ctx: 165 | :param grad_out: (B, C, N) tensor with gradients of outputs 166 | :return: 167 | grad_features: (B, C, M) tensor with gradients of features 168 | None: 169 | None: 170 | """ 171 | idx, weight, m = ctx.three_interpolate_for_backward 172 | B, c, n = grad_out.size() 173 | 174 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) 175 | grad_out_data = grad_out.data.contiguous() 176 | 177 | pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) 178 | return grad_features, None, None 179 | 180 | 181 | three_interpolate = ThreeInterpolate.apply 182 | 183 | 184 | class GroupingOperation(Function): 185 | 186 | @staticmethod 187 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 188 | """ 189 | :param ctx: 190 | :param features: (B, C, N) tensor of features to group 191 | :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with 192 | :return: 193 | output: (B, C, npoint, nsample) tensor 194 | """ 195 | assert features.is_contiguous() 196 | assert idx.is_contiguous() 197 | idx = idx.int() 198 | B, nfeatures, nsample = idx.size() 199 | _, C, N = features.size() 200 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) 201 | 202 | pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) 203 | 204 | ctx.for_backwards = (idx, N) 205 | return output 206 | 207 | @staticmethod 208 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 209 | """ 210 | :param ctx: 211 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward 212 | :return: 213 | grad_features: (B, C, N) gradient of the features 214 | """ 215 | idx, N = ctx.for_backwards 216 | 217 | B, C, npoint, nsample = grad_out.size() 218 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 219 | 220 | grad_out_data = grad_out.data.contiguous() 221 | pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) 222 | return grad_features, None 223 | 224 | 225 | grouping_operation = GroupingOperation.apply 226 | 227 | 228 | class BallQuery(Function): 229 | 230 | @staticmethod 231 | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: 232 | """ 233 | :param ctx: 234 | :param radius: float, radius of the balls 235 | :param nsample: int, maximum number of features in the balls 236 | :param xyz: (B, N, 3) xyz coordinates of the features 237 | :param new_xyz: (B, npoint, 3) centers of the ball query 238 | :return: 239 | idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls 240 | """ 241 | assert new_xyz.is_contiguous() 242 | assert xyz.is_contiguous() 243 | 244 | B, N, _ = xyz.size() 245 | npoint = new_xyz.size(1) 246 | idx = 
torch.cuda.IntTensor(B, npoint, nsample).zero_() 247 | 248 | pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) 249 | return idx 250 | 251 | @staticmethod 252 | def backward(ctx, a=None): 253 | return None, None, None, None 254 | 255 | 256 | ball_query = BallQuery.apply 257 | 258 | 259 | class QueryAndGroup(nn.Module): 260 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True): 261 | """ 262 | :param radius: float, radius of ball 263 | :param nsample: int, maximum number of features to gather in the ball 264 | :param use_xyz: 265 | """ 266 | super().__init__() 267 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 268 | 269 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]: 270 | """ 271 | :param xyz: (B, N, 3) xyz coordinates of the features 272 | :param new_xyz: (B, npoint, 3) centroids 273 | :param features: (B, C, N) descriptors of the features 274 | :return: 275 | new_features: (B, 3 + C, npoint, nsample) 276 | """ 277 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 278 | xyz_trans = xyz.transpose(1, 2).contiguous() 279 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 280 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 281 | 282 | if features is not None: 283 | grouped_features = grouping_operation(features, idx) 284 | if self.use_xyz: 285 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample) 286 | else: 287 | new_features = grouped_features 288 | else: 289 | assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 290 | new_features = grouped_xyz 291 | 292 | return new_features 293 | 294 | 295 | class GroupAll(nn.Module): 296 | def __init__(self, use_xyz: bool = True): 297 | super().__init__() 298 | self.use_xyz = use_xyz 299 | 300 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None): 301 | """ 302 | :param xyz: (B, N, 3) xyz coordinates of the features 303 | :param new_xyz: ignored 304 | :param features: (B, C, N) descriptors of the features 305 | :return: 306 | new_features: (B, C + 3, 1, N) 307 | """ 308 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 309 | if features is not None: 310 | grouped_features = features.unsqueeze(2) 311 | if self.use_xyz: 312 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N) 313 | else: 314 | new_features = grouped_features 315 | else: 316 | new_features = grouped_xyz 317 | 318 | return new_features 319 | -------------------------------------------------------------------------------- /lib/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, 17 | ): 18 | super().__init__() 19 | 20 | for i in range(len(args) - 1): 21 | self.add_module( 22 | name + 'layer{}'.format(i), 23 | Conv2d( 24 | args[i], 25 | args[i + 1], 26 | bn=(not first or not preact or (i != 0)) and bn, 27 | activation=activation 28 | if (not first or not preact or (i != 0)) else None, 29 | preact=preact, 30 | instance_norm=instance_norm 31 | ) 32 | ) 33 | 34 | 35 | class _ConvBase(nn.Sequential): 36 | 37 
| def __init__( 38 | self, 39 | in_size, 40 | out_size, 41 | kernel_size, 42 | stride, 43 | padding, 44 | activation, 45 | bn, 46 | init, 47 | conv=None, 48 | batch_norm=None, 49 | bias=True, 50 | preact=False, 51 | name="", 52 | instance_norm=False, 53 | instance_norm_func=None 54 | ): 55 | super().__init__() 56 | 57 | bias = bias and (not bn) 58 | conv_unit = conv( 59 | in_size, 60 | out_size, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | bias=bias 65 | ) 66 | init(conv_unit.weight) 67 | if bias: 68 | nn.init.constant_(conv_unit.bias, 0) 69 | 70 | if bn: 71 | if not preact: 72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name + "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 148 | stride, 149 | padding, 150 | activation, 151 | bn, 152 | init, 153 | conv=nn.Conv1d, 154 | batch_norm=BatchNorm1d, 155 | bias=bias, 156 | preact=preact, 157 | name=name, 158 | instance_norm=instance_norm, 159 | instance_norm_func=nn.InstanceNorm1d 160 | ) 161 | 162 | 163 | class Conv2d(_ConvBase): 164 | 165 | def __init__( 166 | self, 167 | in_size: int, 168 | out_size: int, 169 | *, 170 | kernel_size: Tuple[int, int] = (1, 1), 171 | stride: Tuple[int, int] = (1, 1), 172 | padding: Tuple[int, int] = (0, 0), 173 | activation=nn.ReLU(inplace=True), 174 | bn: bool = False, 175 | init=nn.init.kaiming_normal_, 176 | bias: bool = True, 177 | preact: bool = False, 178 | name: str = "", 179 | instance_norm=False 180 | ): 181 | super().__init__( 182 | in_size, 183 | out_size, 184 | kernel_size, 185 | stride, 186 | padding, 187 | activation, 188 | bn, 189 | init, 190 | conv=nn.Conv2d, 191 | batch_norm=BatchNorm2d, 192 | bias=bias, 
193 | preact=preact, 194 | name=name, 195 | instance_norm=instance_norm, 196 | instance_norm_func=nn.InstanceNorm2d 197 | ) 198 | 199 | 200 | class FC(nn.Sequential): 201 | 202 | def __init__( 203 | self, 204 | in_size: int, 205 | out_size: int, 206 | *, 207 | activation=nn.ReLU(inplace=True), 208 | bn: bool = False, 209 | init=None, 210 | preact: bool = False, 211 | name: str = "" 212 | ): 213 | super().__init__() 214 | 215 | fc = nn.Linear(in_size, out_size, bias=not bn) 216 | if init is not None: 217 | init(fc.weight) 218 | if not bn: 219 | nn.init.constant(fc.bias, 0) 220 | 221 | if preact: 222 | if bn: 223 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 224 | 225 | if activation is not None: 226 | self.add_module(name + 'activation', activation) 227 | 228 | self.add_module(name + 'fc', fc) 229 | 230 | if not preact: 231 | if bn: 232 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 233 | 234 | if activation is not None: 235 | self.add_module(name + 'activation', activation) 236 | 237 | -------------------------------------------------------------------------------- /lib/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pointnet2', 6 | ext_modules=[ 7 | CUDAExtension('pointnet2_cuda', [ 8 | 'src/pointnet2_api.cpp', 9 | 10 | 'src/ball_query.cpp', 11 | 'src/ball_query_gpu.cu', 12 | 'src/group_points.cpp', 13 | 'src/group_points_gpu.cu', 14 | 'src/interpolate.cpp', 15 | 'src/interpolate_gpu.cu', 16 | 'src/sampling.cpp', 17 | 'src/sampling_gpu.cu', 18 | ], 19 | extra_compile_args={'cxx': ['-g'], 20 | 'nvcc': ['-O2']}) 21 | ], 22 | cmdclass={'build_ext': BuildExtension} 23 | ) 24 | -------------------------------------------------------------------------------- /lib/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(new_xyz_tensor); 17 | CHECK_INPUT(xyz_tensor); 18 | const float *new_xyz = new_xyz_tensor.data(); 19 | const float *xyz = xyz_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t stream = THCState_getCurrentStream(state); 23 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 24 | return 1; 25 | } -------------------------------------------------------------------------------- /lib/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 
17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /lib/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /lib/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /lib/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "group_points_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 13 | 14 | float *grad_points = grad_points_tensor.data(); 15 | const int *idx = idx_tensor.data(); 16 | const float *grad_out = grad_out_tensor.data(); 17 | 18 | cudaStream_t stream = THCState_getCurrentStream(state); 19 | 20 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, 
nsample, grad_out, idx, grad_points, stream); 21 | return 1; 22 | } 23 | 24 | 25 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 26 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 27 | 28 | const float *points = points_tensor.data(); 29 | const int *idx = idx_tensor.data(); 30 | float *out = out_tensor.data(); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | 34 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 35 | return 1; 36 | } -------------------------------------------------------------------------------- /lib/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | // idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, 
nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /lib/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /lib/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "interpolate_gpu.h" 10 | 11 | extern THCState *state; 12 | 13 | 14 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 15 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 16 | const float *unknown = unknown_tensor.data(); 17 | const float *known = known_tensor.data(); 18 | float *dist2 = dist2_tensor.data(); 19 | int *idx = idx_tensor.data(); 20 | 21 | cudaStream_t stream = THCState_getCurrentStream(state); 22 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 23 | } 24 | 25 | void knn_wrapper_fast(int b, int n, int m, int k, at::Tensor unknown_tensor, 26 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 27 | const float *unknown = unknown_tensor.data(); 28 | const float *known = known_tensor.data(); 29 | float *dist2 = dist2_tensor.data(); 30 | int *idx = idx_tensor.data(); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | knn_kernel_launcher_fast(b, n, m, k, unknown, known, dist2, idx, stream); 34 | } 35 | 36 | 37 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 38 | at::Tensor points_tensor, 39 | at::Tensor idx_tensor, 40 | at::Tensor weight_tensor, 41 | at::Tensor out_tensor) { 42 | 43 | const float *points = points_tensor.data(); 44 | const float *weight = weight_tensor.data(); 45 | float *out = out_tensor.data(); 46 | const int *idx = idx_tensor.data(); 47 | 48 | cudaStream_t stream = THCState_getCurrentStream(state); 49 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 50 | } 51 | 52 | void three_interpolate_grad_wrapper_fast(int b, int 
c, int n, int m, 53 | at::Tensor grad_out_tensor, 54 | at::Tensor idx_tensor, 55 | at::Tensor weight_tensor, 56 | at::Tensor grad_points_tensor) { 57 | 58 | const float *grad_out = grad_out_tensor.data(); 59 | const float *weight = weight_tensor.data(); 60 | float *grad_points = grad_points_tensor.data(); 61 | const int *idx = idx_tensor.data(); 62 | 63 | cudaStream_t stream = THCState_getCurrentStream(state); 64 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 65 | } -------------------------------------------------------------------------------- /lib/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void knn_kernel_fast(int b, int n, int m, int k, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, k) 15 | // idx: (B, N, k) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * k + pt_idx * k; 24 | idx += bs_idx * n * k + pt_idx * k; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best[200]; 31 | int besti[200]; 32 | for(int i = 0; i < k; i++){ 33 | best[i] = 1e40; 34 | besti[i] = 0; 35 | } 36 | for (int i = 0; i < m; ++i) { 37 | float x = known[i * 3 + 0]; 38 | float y = known[i * 3 + 1]; 39 | float z = known[i * 3 + 2]; 40 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 41 | for(int j = 0; j < k; j++){ 42 | if(d < best[j]){ 43 | for(int l = k - 1; l > j; l--){ 44 | best[l] = best[l - 1]; 45 | besti[l] = besti[l - 1]; 46 | } 47 | best[j] = d; 48 | besti[j] = i; 49 | break; 50 | } 51 | } 52 | } 53 | for(int i = 0; i < k; i++){ 54 | idx[i] = besti[i]; 55 | dist2[i] = best[i]; 56 | } 57 | } 58 | 59 | 60 | void knn_kernel_launcher_fast(int b, int n, int m, int k, const float *unknown, 61 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 62 | // unknown: (B, N, 3) 63 | // known: (B, M, 3) 64 | // output: 65 | // dist2: (B, N, k) 66 | // idx: (B, N, k) 67 | 68 | cudaError_t err; 69 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 70 | dim3 threads(THREADS_PER_BLOCK); 71 | 72 | knn_kernel_fast<<>>(b, n, m, k, unknown, known, dist2, idx); 73 | 74 | err = cudaGetLastError(); 75 | if (cudaSuccess != err) { 76 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 77 | exit(-1); 78 | } 79 | } 80 | 81 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 82 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 83 | // unknown: (B, N, 3) 84 | // known: (B, M, 3) 85 | // output: 86 | // dist2: (B, N, 3) 87 | // idx: (B, N, 3) 88 | 89 | int bs_idx = blockIdx.y; 90 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 91 | if (bs_idx >= b || pt_idx >= n) return; 92 | 93 | unknown += bs_idx * n * 3 + pt_idx * 3; 94 | known += bs_idx * m * 3; 95 | dist2 += bs_idx * n * 3 + pt_idx * 3; 96 | idx += bs_idx * n * 3 + pt_idx * 3; 97 | 98 | float ux = unknown[0]; 99 | float uy = unknown[1]; 100 | float uz 
= unknown[2]; 101 | 102 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 103 | int besti1 = 0, besti2 = 0, besti3 = 0; 104 | for (int k = 0; k < m; ++k) { 105 | float x = known[k * 3 + 0]; 106 | float y = known[k * 3 + 1]; 107 | float z = known[k * 3 + 2]; 108 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 109 | if (d < best1) { 110 | best3 = best2; besti3 = besti2; 111 | best2 = best1; besti2 = besti1; 112 | best1 = d; besti1 = k; 113 | } 114 | else if (d < best2) { 115 | best3 = best2; besti3 = besti2; 116 | best2 = d; besti2 = k; 117 | } 118 | else if (d < best3) { 119 | best3 = d; besti3 = k; 120 | } 121 | } 122 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 123 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 124 | } 125 | 126 | 127 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 128 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 129 | // unknown: (B, N, 3) 130 | // known: (B, M, 3) 131 | // output: 132 | // dist2: (B, N, 3) 133 | // idx: (B, N, 3) 134 | 135 | cudaError_t err; 136 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 137 | dim3 threads(THREADS_PER_BLOCK); 138 | 139 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 140 | 141 | err = cudaGetLastError(); 142 | if (cudaSuccess != err) { 143 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 144 | exit(-1); 145 | } 146 | } 147 | 148 | 149 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 150 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 151 | // points: (B, C, M) 152 | // idx: (B, N, 3) 153 | // weight: (B, N, 3) 154 | // output: 155 | // out: (B, C, N) 156 | 157 | int bs_idx = blockIdx.z; 158 | int c_idx = blockIdx.y; 159 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 160 | 161 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 162 | 163 | weight += bs_idx * n * 3 + pt_idx * 3; 164 | points += bs_idx * c * m + c_idx * m; 165 | idx += bs_idx * n * 3 + pt_idx * 3; 166 | out += bs_idx * c * n + c_idx * n; 167 | 168 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 169 | } 170 | 171 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 172 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 173 | // points: (B, C, M) 174 | // idx: (B, N, 3) 175 | // weight: (B, N, 3) 176 | // output: 177 | // out: (B, C, N) 178 | 179 | cudaError_t err; 180 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 181 | dim3 threads(THREADS_PER_BLOCK); 182 | three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); 183 | 184 | err = cudaGetLastError(); 185 | if (cudaSuccess != err) { 186 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 187 | exit(-1); 188 | } 189 | } 190 | 191 | 192 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 193 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { 194 | // grad_out: (B, C, N) 195 | // weight: (B, N, 3) 196 | // output: 197 | // grad_points: (B, C, M) 198 | 199 | int bs_idx = blockIdx.z; 200 | int c_idx = blockIdx.y; 201 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 202 | 203 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) 
return; 204 | 205 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 206 | weight += bs_idx * n * 3 + pt_idx * 3; 207 | grad_points += bs_idx * c * m + c_idx * m; 208 | idx += bs_idx * n * 3 + pt_idx * 3; 209 | 210 | 211 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 212 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 213 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 214 | } 215 | 216 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 217 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { 218 | // grad_out: (B, C, N) 219 | // weight: (B, N, 3) 220 | // output: 221 | // grad_points: (B, C, M) 222 | 223 | cudaError_t err; 224 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 225 | dim3 threads(THREADS_PER_BLOCK); 226 | three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); 227 | 228 | err = cudaGetLastError(); 229 | if (cudaSuccess != err) { 230 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 231 | exit(-1); 232 | } 233 | } -------------------------------------------------------------------------------- /lib/src/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 12 | 13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 14 | const float *known, float *dist2, int *idx, cudaStream_t stream); 15 | 16 | void knn_wrapper_fast(int b, int n, int m, int k, at::Tensor unknown_tensor, 17 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 18 | 19 | void knn_kernel_launcher_fast(int b, int n, int m, int k, const float *unknown, 20 | const float *known, float *dist2, int *idx, cudaStream_t stream); 21 | 22 | 23 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 24 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 25 | 26 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 27 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); 28 | 29 | 30 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 31 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); 32 | 33 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 34 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream); 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /lib/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, 
"group_points_grad_wrapper_fast"); 15 | 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("knn_wrapper", &knn_wrapper_fast, "knn_wrapper_fast"); 22 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 23 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 24 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 25 | } 26 | -------------------------------------------------------------------------------- /lib/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "sampling_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 12 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 13 | const float *points = points_tensor.data(); 14 | const int *idx = idx_tensor.data(); 15 | float *out = out_tensor.data(); 16 | 17 | cudaStream_t stream = THCState_getCurrentStream(state); 18 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 19 | return 1; 20 | } 21 | 22 | 23 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 24 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 25 | 26 | const float *grad_out = grad_out_tensor.data(); 27 | const int *idx = idx_tensor.data(); 28 | float *grad_points = grad_points_tensor.data(); 29 | 30 | cudaStream_t stream = THCState_getCurrentStream(state); 31 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 32 | return 1; 33 | } 34 | 35 | 36 | int furthest_point_sampling_wrapper(int b, int n, int m, 37 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 38 | 39 | const float *points = points_tensor.data(); 40 | float *temp = temp_tensor.data(); 41 | int *idx = idx_tensor.data(); 42 | 43 | cudaStream_t stream = THCState_getCurrentStream(state); 44 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /lib/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | 8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 10 | // points: (B, C, N) 11 | // idx: (B, M) 12 | // output: 13 | // out: (B, C, M) 14 | 15 | int bs_idx = blockIdx.z; 16 | int c_idx = blockIdx.y; 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 19 | 20 | out += bs_idx * c * m + c_idx * m + pt_idx; 21 | idx += bs_idx * m + pt_idx; 22 | points += bs_idx * c * n + c_idx * n; 23 | out[0] = points[idx[0]]; 24 | } 25 | 26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 27 | const float *points, const int *idx, float *out, cudaStream_t stream) { 28 | // points: 
32 | 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 47 | const int *__restrict__ idx, float *__restrict__ grad_points) { 48 | // grad_out: (B, C, M) 49 | // idx: (B, M) 50 | // output: 51 | // grad_points: (B, C, N) 52 | 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 57 | 58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 59 | idx += bs_idx * m + pt_idx; 60 | grad_points += bs_idx * c * n + c_idx * n; 61 | 62 | atomicAdd(grad_points + idx[0], grad_out[0]); 63 | } 64 | 65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 67 | // grad_out: (B, C, npoints) 68 | // idx: (B, npoints) 69 | // output: 70 | // grad_points: (B, C, N) 71 | 72 | cudaError_t err; 73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points); 77 | 78 | err = cudaGetLastError(); 79 | if (cudaSuccess != err) { 80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 81 | exit(-1); 82 | } 83 | } 84 | 85 | 86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ 87 | const float v1 = dists[idx1], v2 = dists[idx2]; 88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 89 | dists[idx1] = max(v1, v2); 90 | dists_i[idx1] = v2 > v1 ? i2 : i1; 91 | } 92 | 93 | template <unsigned int block_size> 94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 96 | // dataset: (B, N, 3) 97 | // tmp: (B, N) 98 | // output: 99 | // idx: (B, M) 100 | 101 | if (m <= 0) return; 102 | __shared__ float dists[block_size]; 103 | __shared__ int dists_i[block_size]; 104 | 105 | int batch_index = blockIdx.x; 106 | dataset += batch_index * n * 3; 107 | temp += batch_index * n; 108 | idxs += batch_index * m; 109 | 110 | int tid = threadIdx.x; 111 | const int stride = block_size; 112 | 113 | int old = 0; 114 | if (threadIdx.x == 0) 115 | idxs[0] = old; 116 | 117 | __syncthreads(); 118 | for (int j = 1; j < m; j++) { 119 | int besti = 0; 120 | float best = -1; 121 | float x1 = dataset[old * 3 + 0]; 122 | float y1 = dataset[old * 3 + 1]; 123 | float z1 = dataset[old * 3 + 2]; 124 | for (int k = tid; k < n; k += stride) { 125 | float x2, y2, z2; 126 | x2 = dataset[k * 3 + 0]; 127 | y2 = dataset[k * 3 + 1]; 128 | z2 = dataset[k * 3 + 2]; 129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 130 | // if (mag <= 1e-3) 131 | // continue; 132 | 133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 134 | float d2 = min(d, temp[k]); 135 | temp[k] = d2; 136 | besti = d2 > best ? k : besti; 137 | best = d2 > best ? d2 : best;
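// temp[k] caches point k's running min squared distance to the already-selected
// set, so each iteration is a strided O(n) scan per thread; dists/dists_i then hold
// each thread's farthest candidate and the tree reduction below picks the global argmax.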
138 | } 139 | dists[tid] = best; 140 | dists_i[tid] = besti; 141 | __syncthreads(); 142 | 143 | if (block_size >= 1024) { 144 | if (tid < 512) { 145 | __update(dists, dists_i, tid, tid + 512); 146 | } 147 | __syncthreads(); 148 | } 149 | 150 | if (block_size >= 512) { 151 | if (tid < 256) { 152 | __update(dists, dists_i, tid, tid + 256); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 256) { 157 | if (tid < 128) { 158 | __update(dists, dists_i, tid, tid + 128); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 128) { 163 | if (tid < 64) { 164 | __update(dists, dists_i, tid, tid + 64); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 64) { 169 | if (tid < 32) { 170 | __update(dists, dists_i, tid, tid + 32); 171 | } 172 | __syncthreads(); 173 | } 174 | if (block_size >= 32) { 175 | if (tid < 16) { 176 | __update(dists, dists_i, tid, tid + 16); 177 | } 178 | __syncthreads(); 179 | } 180 | if (block_size >= 16) { 181 | if (tid < 8) { 182 | __update(dists, dists_i, tid, tid + 8); 183 | } 184 | __syncthreads(); 185 | } 186 | if (block_size >= 8) { 187 | if (tid < 4) { 188 | __update(dists, dists_i, tid, tid + 4); 189 | } 190 | __syncthreads(); 191 | } 192 | if (block_size >= 4) { 193 | if (tid < 2) { 194 | __update(dists, dists_i, tid, tid + 2); 195 | } 196 | __syncthreads(); 197 | } 198 | if (block_size >= 2) { 199 | if (tid < 1) { 200 | __update(dists, dists_i, tid, tid + 1); 201 | } 202 | __syncthreads(); 203 | } 204 | 205 | old = dists_i[0]; 206 | if (tid == 0) 207 | idxs[j] = old; 208 | } 209 | } 210 | 211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 213 | // dataset: (B, N, 3) 214 | // tmp: (B, N) 215 | // output: 216 | // idx: (B, M) 217 | 218 | cudaError_t err; 219 | unsigned int n_threads = opt_n_threads(n); 220 | 221 | switch (n_threads) { 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 | furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 246 | } 247 | 248 | err = cudaGetLastError(); 249 | if (cudaSuccess != err) { 250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 251 | exit(-1); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /lib/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include <torch/serialize/tensor.h> 5 |
#include <ATen/cuda/CUDAContext.h> 6 | #include <vector> 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from __future__ import print_function 6 | import os 7 | import gc 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.optim.lr_scheduler import MultiStepLR, StepLR 14 | from data import ModelNet40, SceneflowDataset 15 | from model import FlowNet3D 16 | import numpy as np 17 | from torch.utils.data import DataLoader 18 | # from tensorboardX import SummaryWriter 19 | from tqdm import tqdm 20 | 21 | 22 | class IOStream: 23 | def __init__(self, path): 24 | self.f = open(path, 'a') 25 | 26 | def cprint(self, text): 27 | print(text) 28 | self.f.write(text + '\n') 29 | self.f.flush() 30 | 31 | def close(self): 32 | self.f.close() 33 | 34 | 35 | def _init_(args): 36 | if not os.path.exists('checkpoints'): 37 | os.makedirs('checkpoints') 38 | if not os.path.exists('checkpoints/' + args.exp_name): 39 | os.makedirs('checkpoints/' + args.exp_name) 40 | if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'): 41 | os.makedirs('checkpoints/' + args.exp_name + '/' + 'models') 42 | os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup') 43 | os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup') 44 | os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup') 45 | 46 | def weights_init(m): 47 | classname=m.__class__.__name__ 48 | if classname.find('Conv2d') != -1: 49 | nn.init.kaiming_normal_(m.weight.data) 50 | if classname.find('Conv1d') != -1: 51 | nn.init.kaiming_normal_(m.weight.data) 52 | 53 | def scene_flow_EPE_np(pred, labels, mask): 54 | error = np.sqrt(np.sum((pred - labels)**2, 2) + 1e-20) 55 | 56 | gtflow_len = np.sqrt(np.sum(labels*labels, 2) + 1e-20) # B,N 57 | acc1 = np.sum(np.logical_or((error <= 0.05)*mask, (error/gtflow_len <= 0.05)*mask), axis=1) 58 | acc2 = np.sum(np.logical_or((error <= 0.1)*mask, (error/gtflow_len <= 0.1)*mask), axis=1) 59 | 60 | mask_sum = np.sum(mask, 1) 61 | acc1 = acc1[mask_sum > 0] / mask_sum[mask_sum > 0] 62 | acc1 = np.mean(acc1) 63 | acc2 = acc2[mask_sum > 0] / mask_sum[mask_sum > 0] 64 | acc2 = np.mean(acc2) 65 | 66 | EPE = np.sum(error * mask, 1)[mask_sum > 0] / mask_sum[mask_sum > 0] 67 | EPE = np.mean(EPE) 68 | return EPE, acc1, acc2 69 | 70 |
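# scene_flow_EPE_np above implements the usual scene flow metrics: EPE is the masked
# mean per-point L2 error, while acc1/acc2 count points whose error is below 0.05/0.1
# in absolute terms or below 5%/10% of the ground-truth flow length. For example, a
# point with GT flow (1, 0, 0) and prediction (1.04, 0, 0) has error 0.04 and counts
# toward both accuracies.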
def test_one_epoch(args, net, test_loader): 71 | net.eval() 72 | 73 | total_loss = 0 74 | total_epe = 0 75 | total_acc3d = 0 76 | total_acc3d_2 = 0 77 | num_examples = 0 78 | for i, data in tqdm(enumerate(test_loader), total = len(test_loader)): 79 | pc1, pc2, color1, color2, flow, mask1 = data 80 | pc1 = pc1.cuda().transpose(2,1).contiguous() 81 | pc2 = pc2.cuda().transpose(2,1).contiguous() 82 | color1 = color1.cuda().transpose(2,1).contiguous() 83 | color2 = color2.cuda().transpose(2,1).contiguous() 84 | flow = flow.cuda() 85 | mask1 = mask1.cuda().float() 86 | 87 | batch_size = pc1.size(0) 88 | num_examples += batch_size 89 | flow_pred = net(pc1, pc2, color1, color2).permute(0,2,1) 90 | loss = torch.mean(mask1 * torch.sum((flow_pred - flow) * (flow_pred - flow), -1) / 2.0) 91 | epe_3d, acc_3d, acc_3d_2 = scene_flow_EPE_np(flow_pred.detach().cpu().numpy(), flow.detach().cpu().numpy(), mask1.detach().cpu().numpy()) 92 | total_epe += epe_3d * batch_size 93 | total_acc3d += acc_3d * batch_size 94 | total_acc3d_2+=acc_3d_2*batch_size 95 | # print('batch EPE 3D: %f\tACC 3D: %f\tACC 3D 2: %f' % (epe_3d, acc_3d, acc_3d_2)) 96 | 97 | total_loss += loss.item() * batch_size 98 | 99 | 100 | return total_loss * 1.0 / num_examples, total_epe * 1.0 / num_examples, total_acc3d * 1.0 / num_examples, total_acc3d_2 * 1.0 / num_examples 101 | 102 | 103 | def train_one_epoch(args, net, train_loader, opt): 104 | net.train() 105 | num_examples = 0 106 | total_loss = 0 107 | for i, data in tqdm(enumerate(train_loader), total = len(train_loader)): 108 | pc1, pc2, color1, color2, flow, mask1 = data 109 | pc1 = pc1.cuda().transpose(2,1).contiguous() 110 | pc2 = pc2.cuda().transpose(2,1).contiguous() 111 | color1 = color1.cuda().transpose(2,1).contiguous() 112 | color2 = color2.cuda().transpose(2,1).contiguous() 113 | flow = flow.cuda().transpose(2,1).contiguous() 114 | mask1 = mask1.cuda().float() 115 | 116 | batch_size = pc1.size(0) 117 | opt.zero_grad() 118 | num_examples += batch_size 119 | flow_pred = net(pc1, pc2, color1, color2) 120 | loss = torch.mean(mask1 * torch.sum((flow_pred - flow) ** 2, 1) / 2.0) 121 | loss.backward() 122 | 123 | opt.step() 124 | total_loss += loss.item() * batch_size 125 | 126 | # if (i+1) % 100 == 0: 127 | # print("batch: %d, mean loss: %f" % (i, total_loss / 100 / batch_size)) 128 | # total_loss = 0 129 | return total_loss * 1.0 / num_examples 130 | 131 | 132 | def test(args, net, test_loader, boardio, textio): 133 | 134 | test_loss, epe, acc, acc_2 = test_one_epoch(args, net, test_loader) 135 | 136 | textio.cprint('==FINAL TEST==') 137 | textio.cprint('mean test loss: %f\tEPE 3D: %f\tACC 3D: %f\tACC 3D 2: %f'%(test_loss, epe, acc, acc_2)) 138 | 139 | 140 | def train(args, net, train_loader, test_loader, boardio, textio): 141 | if args.use_sgd: 142 | print("Use SGD") 143 | opt = optim.SGD(net.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=1e-4) 144 | else: 145 | print("Use Adam") 146 | opt = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4) 147 | # scheduler = MultiStepLR(opt, milestones=[75, 150, 200], gamma=0.1) 148 | scheduler = StepLR(opt, 10, gamma = 0.7) 149 | 150 | best_test_loss = np.inf 151 | for epoch in range(args.epochs): 152 | textio.cprint('==epoch: %d, learning rate: %f=='%(epoch, opt.param_groups[0]['lr'])) 153 | train_loss = train_one_epoch(args, net, train_loader, opt) 154 | textio.cprint('mean train EPE loss: %f'%train_loss) 155 | 156 | test_loss, epe, acc, acc_2 = test_one_epoch(args, net, test_loader) 157 | textio.cprint('mean 
test loss: %f\tEPE 3D: %f\tACC 3D: %f\tACC 3D 2: %f'%(test_loss, epe, acc, acc_2)) 158 | if best_test_loss >= test_loss: 159 | best_test_loss = test_loss 160 | textio.cprint('best test loss till now: %f'%test_loss) 161 | if torch.cuda.device_count() > 1: 162 | torch.save(net.module.state_dict(), 'checkpoints/%s/models/model.best.t7' % args.exp_name) 163 | else: 164 | torch.save(net.state_dict(), 'checkpoints/%s/models/model.best.t7' % args.exp_name) 165 | 166 | scheduler.step() 167 | # if torch.cuda.device_count() > 1: 168 | # torch.save(net.module.state_dict(), 'checkpoints/%s/models/model.%d.t7' % (args.exp_name, epoch)) 169 | # else: 170 | # torch.save(net.state_dict(), 'checkpoints/%s/models/model.%d.t7' % (args.exp_name, epoch)) 171 | # gc.collect() 172 | 173 | 174 | def main(): 175 | parser = argparse.ArgumentParser(description='FlowNet3D Scene Flow Estimation') 176 | parser.add_argument('--exp_name', type=str, default='exp', metavar='N', 177 | help='Name of the experiment') 178 | parser.add_argument('--model', type=str, default='flownet', metavar='N', 179 | choices=['flownet'], 180 | help='Model to use, [flownet]') 181 | parser.add_argument('--emb_dims', type=int, default=512, metavar='N', 182 | help='Dimension of embeddings') 183 | parser.add_argument('--num_points', type=int, default=2048, 184 | help='Point Number [default: 2048]') 185 | parser.add_argument('--dropout', type=float, default=0.5, metavar='N', 186 | help='Dropout ratio in transformer') 187 | parser.add_argument('--batch_size', type=int, default=64, metavar='batch_size', 188 | help='Size of batch') 189 | parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size', 190 | help='Size of batch') 191 | parser.add_argument('--epochs', type=int, default=250, metavar='N', 192 | help='number of epochs to train') 193 | parser.add_argument('--use_sgd', action='store_true', default=False, 194 | help='Use SGD') 195 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 196 | help='learning rate (default: 0.001, 0.1 if using sgd)') 197 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 198 | help='SGD momentum (default: 0.9)') 199 | parser.add_argument('--no_cuda', action='store_true', default=False, 200 | help='disables CUDA training') 201 | parser.add_argument('--seed', type=int, default=1234, metavar='S', 202 | help='random seed (default: 1234)') 203 | parser.add_argument('--eval', action='store_true', default=False, 204 | help='evaluate the model') 205 | parser.add_argument('--cycle', type=bool, default=False, metavar='N', 206 | help='Whether to use cycle consistency') 207 | parser.add_argument('--gaussian_noise', type=bool, default=False, metavar='N', 208 | help='Whether to add gaussian noise') 209 | parser.add_argument('--unseen', type=bool, default=False, metavar='N', 210 | help='Whether to test on unseen category') 211 | parser.add_argument('--dataset', type=str, default='SceneflowDataset', 212 | choices=['SceneflowDataset'], metavar='N', 213 | help='dataset to use') 214 | parser.add_argument('--dataset_path', type=str, default='../../datasets/data_processed_maxcut_35_20k_2k_8192', metavar='N', 215 | help='dataset path') 216 | parser.add_argument('--model_path', type=str, default='', metavar='N', 217 | help='Pretrained model path') 218 | 219 | args = parser.parse_args() 220 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 221 | # CUDA settings 222 | torch.backends.cudnn.deterministic = True 223 | torch.manual_seed(args.seed) 224 | torch.cuda.manual_seed_all(args.seed)
225 | np.random.seed(args.seed) 226 | 227 | # boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name) 228 | boardio = [] 229 | _init_(args) 230 | 231 | textio = IOStream('checkpoints/' + args.exp_name + '/run.log') 232 | textio.cprint(str(args)) 233 | 234 | if args.dataset == 'modelnet40': 235 | train_loader = DataLoader( 236 | ModelNet40(num_points=args.num_points, partition='train', gaussian_noise=args.gaussian_noise, 237 | unseen=args.unseen, factor=args.factor), 238 | batch_size=args.batch_size, shuffle=True, drop_last=True) 239 | test_loader = DataLoader( 240 | ModelNet40(num_points=args.num_points, partition='test', gaussian_noise=args.gaussian_noise, 241 | unseen=args.unseen, factor=args.factor), 242 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 243 | elif args.dataset == 'SceneflowDataset': 244 | train_loader = DataLoader( 245 | SceneflowDataset(npoints=args.num_points, root = args.dataset_path, partition='train'), 246 | batch_size=args.batch_size, shuffle=True, drop_last=True) 247 | test_loader = DataLoader( 248 | SceneflowDataset(npoints=args.num_points, root = args.dataset_path, partition='test'), 249 | batch_size=args.test_batch_size, shuffle=False, drop_last=False) 250 | else: 251 | raise Exception("not implemented") 252 | 253 | if args.model == 'flownet': 254 | net = FlowNet3D(args).cuda() 255 | net.apply(weights_init) 256 | if args.eval: 257 | if args.model_path == '': 258 | model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7' 259 | else: 260 | model_path = args.model_path 261 | print(model_path) 262 | if not os.path.exists(model_path): 263 | print("can't find pretrained model") 264 | return 265 | net.load_state_dict(torch.load(model_path), strict=False) 266 | if torch.cuda.device_count() > 1: 267 | net = nn.DataParallel(net) 268 | print("Let's use", torch.cuda.device_count(), "GPUs!") 269 | else: 270 | raise Exception('Not implemented') 271 | if args.eval: 272 | test(args, net, test_loader, boardio, textio) 273 | else: 274 | train(args, net, train_loader, test_loader, boardio, textio) 275 | 276 | 277 | print('FINISH') 278 | # boardio.close() 279 | 280 | 281 | if __name__ == '__main__': 282 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from util import PointNetSetAbstraction,PointNetFeaturePropogation,FlowEmbedding,PointNetSetUpConv 6 | 7 | 8 | 9 | class FlowNet3D(nn.Module): 10 | def __init__(self,args): 11 | super(FlowNet3D,self).__init__() 12 | 13 | self.sa1 = PointNetSetAbstraction(npoint=1024, radius=0.5, nsample=16, in_channel=3, mlp=[32,32,64], group_all=False) 14 | self.sa2 = PointNetSetAbstraction(npoint=256, radius=1.0, nsample=16, in_channel=64, mlp=[64, 64, 128], group_all=False) 15 | self.sa3 = PointNetSetAbstraction(npoint=64, radius=2.0, nsample=8, in_channel=128, mlp=[128, 128, 256], group_all=False) 16 | self.sa4 = PointNetSetAbstraction(npoint=16, radius=4.0, nsample=8, in_channel=256, mlp=[256,256,512], group_all=False) 17 | 18 | self.fe_layer = FlowEmbedding(radius=10.0, nsample=64, in_channel = 128, mlp=[128, 128, 128], pooling='max', corr_func='concat') 19 | 20 | self.su1 = PointNetSetUpConv(nsample=8, radius=2.4, f1_channel = 256, f2_channel = 512, mlp=[], mlp2=[256, 256]) 21 | self.su2 = PointNetSetUpConv(nsample=8, radius=1.2, f1_channel = 128+128, f2_channel =
256, mlp=[128, 128, 256], mlp2=[256]) 22 | self.su3 = PointNetSetUpConv(nsample=8, radius=0.6, f1_channel = 64, f2_channel = 256, mlp=[128, 128, 256], mlp2=[256]) 23 | self.fp = PointNetFeaturePropogation(in_channel = 256+3, mlp = [256, 256]) 24 | 25 | self.conv1 = nn.Conv1d(256, 128, kernel_size=1, bias=False) 26 | self.bn1 = nn.BatchNorm1d(128) 27 | self.conv2=nn.Conv1d(128, 3, kernel_size=1, bias=True) 28 | 29 | def forward(self, pc1, pc2, feature1, feature2): 30 | l1_pc1, l1_feature1 = self.sa1(pc1, feature1) 31 | l2_pc1, l2_feature1 = self.sa2(l1_pc1, l1_feature1) 32 | 33 | l1_pc2, l1_feature2 = self.sa1(pc2, feature2) 34 | l2_pc2, l2_feature2 = self.sa2(l1_pc2, l1_feature2) 35 | 36 | _, l2_feature1_new = self.fe_layer(l2_pc1, l2_pc2, l2_feature1, l2_feature2) 37 | 38 | l3_pc1, l3_feature1 = self.sa3(l2_pc1, l2_feature1_new) 39 | l4_pc1, l4_feature1 = self.sa4(l3_pc1, l3_feature1) 40 | 41 | l3_fnew1 = self.su1(l3_pc1, l4_pc1, l3_feature1, l4_feature1) 42 | l2_fnew1 = self.su2(l2_pc1, l3_pc1, torch.cat([l2_feature1, l2_feature1_new], dim=1), l3_fnew1) 43 | l1_fnew1 = self.su3(l1_pc1, l2_pc1, l1_feature1, l2_fnew1) 44 | l0_fnew1 = self.fp(pc1, l1_pc1, feature1, l1_fnew1) 45 | 46 | x = F.relu(self.bn1(self.conv1(l0_fnew1))) 47 | sf = self.conv2(x) 48 | return sf 49 | 50 | if __name__ == '__main__': 51 | import os 52 | import torch 53 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 54 | input = torch.randn((8,3,2048)).cuda() 55 | label = torch.randn(8,16) 56 | model = FlowNet3D(None).cuda() 57 | output = model(input, input, input, input) 58 | print(output.size()) 59 | -------------------------------------------------------------------------------- /pretrained_model/model.best.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyangwinter/flownet3d_pytorch/ae0847d242d3582b3f6f115e64f61e637ef80355/pretrained_model/model.best.t7 -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np; from scipy.spatial.transform import Rotation 6 | from lib import pointnet2_utils as pointutils 7 | # import lib.pointnet2_utils as pointutils 8 | 9 | def quat2mat(quat): 10 | x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3] 11 | 12 | B = quat.size(0) 13 | 14 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 15 | wx, wy, wz = w*x, w*y, w*z 16 | xy, xz, yz = x*y, x*z, y*z 17 | 18 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 19 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 20 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) 21 | return rotMat 22 | 23 | def transform_point_cloud(point_cloud, rotation, translation): 24 | if len(rotation.size()) == 2: 25 | rot_mat = quat2mat(rotation) 26 | else: 27 | rot_mat = rotation 28 | return torch.matmul(rot_mat, point_cloud) + translation.unsqueeze(2) 29 | 30 | 31 | def npmat2euler(mats, seq='zyx'): 32 | eulers = [] 33 | for i in range(mats.shape[0]): 34 | r = Rotation.from_dcm(mats[i]) 35 | eulers.append(r.as_euler(seq, degrees=True)) 36 | return np.asarray(eulers, dtype='float32') 37 | 38 | def timeit(tag, t): 39 | print("{}: {}s".format(tag, time() - t)) 40 | return time() 41 | 42 | def pc_normalize(pc): 43 | l = pc.shape[0] 44 | centroid = np.mean(pc, axis=0) 45 | pc = pc - centroid 46 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 47 | pc = pc / m 48 | return pc 49 |
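# square_distance below uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2<s, d>
# to avoid materializing a (B, N, M, C) difference tensor. Quick check: s = (0,0,0),
# d = (1,2,2) gives 0 + 9 - 0 = 9, i.e. the squared distance 3**2.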
50 | def square_distance(src, dst): 51 | """ 52 | Calculate the squared Euclidean distance between each pair of points. 53 | 54 | src^T * dst = xn * xm + yn * ym + zn * zm; 55 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 56 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 57 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 58 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 59 | 60 | Input: 61 | src: source points, [B, N, C] 62 | dst: target points, [B, M, C] 63 | Output: 64 | dist: per-point square distance, [B, N, M] 65 | """ 66 | B, N, _ = src.shape 67 | _, M, _ = dst.shape 68 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 69 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 70 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 71 | return dist 72 | 73 | 74 | def index_points(points, idx): 75 | """ 76 | 77 | Input: 78 | points: input points data, [B, N, C] 79 | idx: sample index data, [B, S] 80 | Return: 81 | new_points: indexed points data, [B, S, C] 82 | """ 83 | device = points.device 84 | B = points.shape[0] 85 | view_shape = list(idx.shape) 86 | view_shape[1:] = [1] * (len(view_shape) - 1) 87 | repeat_shape = list(idx.shape) 88 | repeat_shape[0] = 1 89 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 90 | new_points = points[batch_indices, idx, :] 91 | return new_points 92 | 93 | 94 | def farthest_point_sample(xyz, npoint): 95 | """ 96 | Input: 97 | xyz: pointcloud data, [B, N, C] 98 | npoint: number of samples 99 | Return: 100 | centroids: sampled pointcloud index, [B, npoint] 101 | """ 102 | device = xyz.device 103 | B, N, C = xyz.shape 104 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 105 | distance = torch.ones(B, N).to(device) * 1e10 106 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 107 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 108 | for i in range(npoint): 109 | centroids[:, i] = farthest 110 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 111 | dist = torch.sum((xyz - centroid) ** 2, -1) 112 | mask = dist < distance 113 | distance[mask] = dist[mask] 114 | farthest = torch.max(distance, -1)[1] 115 | return centroids 116 | 117 | def knn_point(k, pos1, pos2): 118 | ''' 119 | Input: 120 | k: int32, number of k in k-nn search 121 | pos1: (batch_size, ndataset, c) float32 array, input points 122 | pos2: (batch_size, npoint, c) float32 array, query points 123 | Output: 124 | val: (batch_size, npoint, k) float32 array, L2 distances 125 | idx: (batch_size, npoint, k) int32 array, indices to input points 126 | ''' 127 | B, N, C = pos1.shape 128 | M = pos2.shape[1] 129 | pos1 = pos1.view(B,1,N,-1).repeat(1,M,1,1) 130 | pos2 = pos2.view(B,M,1,-1).repeat(1,1,N,1) 131 | dist = torch.sum(-(pos1-pos2)**2,-1) 132 | val,idx = dist.topk(k=k,dim = -1) 133 | return torch.sqrt(-val), idx 134 | 135 | 136 | def query_ball_point(radius, nsample, xyz, new_xyz): 137 | """ 138 | Input: 139 | radius: local region radius 140 | nsample: max sample number in local region 141 | xyz: all points, [B, N, C] 142 | new_xyz: query points, [B, S, C] 143 | Return: 144 | group_idx: grouped points index, [B, S, nsample] 145 | """ 146 | device = xyz.device 147 | B, N, C = xyz.shape 148 | _, S, _ = new_xyz.shape 149 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 150 | sqrdists = square_distance(new_xyz, xyz) 151 | group_idx[sqrdists > radius ** 2] = N 152 | mask = group_idx != N 153 | cnt = mask.sum(dim=-1) 154 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
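    # Out-of-radius entries were set to the sentinel N above, so after sorting the
    # in-radius indices come first; any leftover N entries are overwritten below with
    # each group's first valid index, padding groups with fewer than nsample neighbors.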
155 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 156 | mask = group_idx == N 157 | group_idx[mask] = group_first[mask] 158 | return group_idx, cnt 159 | 160 | 161 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 162 | """ 163 | Input: 164 | npoint: 165 | radius: 166 | nsample: 167 | xyz: input points position data, [B, N, C] 168 | points: input points data, [B, N, D] 169 | Return: 170 | new_xyz: sampled points position data, [B, npoint, C] 171 | new_points: sampled points data, [B, npoint, nsample, C+D] 172 | """ 173 | B, N, C = xyz.shape 174 | S = npoint 175 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 176 | new_xyz = index_points(xyz, fps_idx) 177 | idx, _ = query_ball_point(radius, nsample, xyz, new_xyz) 178 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 179 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 180 | if points is not None: 181 | grouped_points = index_points(points, idx) 182 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 183 | else: 184 | new_points = grouped_xyz_norm 185 | if returnfps: 186 | return new_xyz, new_points, grouped_xyz, fps_idx 187 | else: 188 | return new_xyz, new_points 189 | 190 | 191 | def sample_and_group_all(xyz, points): 192 | """ 193 | Input: 194 | xyz: input points position data, [B, N, C] 195 | points: input points data, [B, N, D] 196 | Return: 197 | new_xyz: sampled points position data, [B, 1, C] 198 | new_points: sampled points data, [B, 1, N, C+D] 199 | """ 200 | device = xyz.device 201 | B, N, C = xyz.shape 202 | new_xyz = torch.zeros(B, 1, C).to(device) 203 | grouped_xyz = xyz.view(B, 1, N, C) 204 | if points is not None: 205 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 206 | else: 207 | new_points = grouped_xyz 208 | return new_xyz, new_points 209 | 210 | class PointNetSetAbstraction(nn.Module): 211 | def __init__(self, npoint, radius, nsample, in_channel, mlp, mlp2 = None, group_all = False): 212 | super(PointNetSetAbstraction, self).__init__() 213 | self.npoint = npoint 214 | self.radius = radius 215 | self.nsample = nsample 216 | self.group_all = group_all 217 | self.mlp_convs = nn.ModuleList() 218 | self.mlp_bns = nn.ModuleList() 219 | self.mlp2_convs = nn.ModuleList() 220 | last_channel = in_channel+3 221 | for out_channel in mlp: 222 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias = False)) 223 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 224 | last_channel = out_channel 225 | for out_channel in (mlp2 if mlp2 is not None else []): 226 | self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False), 227 | nn.BatchNorm1d(out_channel))) 228 | last_channel = out_channel 229 | if group_all: 230 | self.queryandgroup = pointutils.GroupAll() 231 | else: 232 | self.queryandgroup = pointutils.QueryAndGroup(radius, nsample) 233 | 234 | def forward(self, xyz, points): 235 | """ 236 | Input: 237 | xyz: input points position data, [B, C, N] 238 | points: input points data, [B, D, N] 239 | Return: 240 | new_xyz: sampled points position data, [B, C, S] 241 | new_points_concat: sample points feature data, [B, D', S] 242 | """ 243 | device = xyz.device 244 | B, C, N = xyz.shape 245 | xyz_t = xyz.permute(0, 2, 1).contiguous() 246 | # if points is not None: 247 | # points = points.permute(0, 2, 1).contiguous() 248 | 249 | # select neighborhood points 250 | if self.group_all == False: 251 | fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint) # [B, npoint] 252 | new_xyz = pointutils.gather_operation(xyz, fps_idx) # [B, C, npoint]
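            # FPS spreads the npoint samples roughly uniformly over the cloud; the CUDA
            # gather kernel in sampling_gpu.cu then pulls the selected columns out of
            # the (B, C, N) coordinate tensor.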
253 | else: 254 | new_xyz = xyz 255 | new_points = self.queryandgroup(xyz_t, new_xyz.transpose(2, 1).contiguous(), points) # [B, 3+C, N, S] 256 | 257 | # new_xyz: sampled points position data, [B, C, npoint] 258 | # new_points: sampled points data, [B, C+D, npoint, nsample] 259 | for i, conv in enumerate(self.mlp_convs): 260 | bn = self.mlp_bns[i] 261 | new_points = F.relu(bn(conv(new_points))) 262 | 263 | new_points = torch.max(new_points, -1)[0] 264 | 265 | for i, conv in enumerate(self.mlp2_convs): 266 | new_points = F.relu(conv(new_points)) 267 | return new_xyz, new_points 268 | 269 | class FlowEmbedding(nn.Module): 270 | def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn = True): 271 | super(FlowEmbedding, self).__init__() 272 | self.radius = radius 273 | self.nsample = nsample 274 | self.knn = knn 275 | self.pooling = pooling 276 | self.corr_func = corr_func 277 | self.mlp_convs = nn.ModuleList() 278 | self.mlp_bns = nn.ModuleList() 279 | if corr_func == 'concat': 280 | last_channel = in_channel*2+3 281 | for out_channel in mlp: 282 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False)) 283 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 284 | last_channel = out_channel 285 | 286 | def forward(self, pos1, pos2, feature1, feature2): 287 | """ 288 | Input: 289 | xyz1: (batch_size, 3, npoint) 290 | xyz2: (batch_size, 3, npoint) 291 | feat1: (batch_size, channel, npoint) 292 | feat2: (batch_size, channel, npoint) 293 | Output: 294 | xyz1: (batch_size, 3, npoint) 295 | feat1_new: (batch_size, mlp[-1], npoint) 296 | """ 297 | pos1_t = pos1.permute(0, 2, 1).contiguous() 298 | pos2_t = pos2.permute(0, 2, 1).contiguous() 299 | B, N, C = pos1_t.shape 300 | if self.knn: 301 | _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t) 302 | else: 303 | # If the ball neighborhood points are less than nsample, 304 | # then use the knn neighborhood points 305 | idx, cnt = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t) 306 | # use knn to pick the nearest points 307 | _, idx_knn = pointutils.knn(self.nsample, pos1_t, pos2_t) 308 | cnt = cnt.view(B, -1, 1).repeat(1, 1, self.nsample) 309 | idx = idx_knn[cnt > (self.nsample-1)] 310 | 311 | pos2_grouped = pointutils.grouping_operation(pos2, idx) # [B, 3, N, S] 312 | pos_diff = pos2_grouped - pos1.view(B, -1, N, 1) # [B, 3, N, S] 313 | 314 | feat2_grouped = pointutils.grouping_operation(feature2, idx) # [B, C, N, S] 315 | if self.corr_func=='concat': 316 | feat_diff = torch.cat([feat2_grouped, feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)], dim = 1) 317 | 318 | feat1_new = torch.cat([pos_diff, feat_diff], dim = 1) # [B, 2*C+3,N,S] 319 | for i, conv in enumerate(self.mlp_convs): 320 | bn = self.mlp_bns[i] 321 | feat1_new = F.relu(bn(conv(feat1_new))) 322 | 323 | feat1_new = torch.max(feat1_new, -1)[0] # [B, mlp[-1], npoint] 324 | return pos1, feat1_new 325 | 326 | class PointNetSetUpConv(nn.Module): 327 | def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn = True): 328 | super(PointNetSetUpConv, self).__init__() 329 | self.nsample = nsample 330 | self.radius = radius 331 | self.knn = knn 332 | self.mlp1_convs = nn.ModuleList() 333 | self.mlp2_convs = nn.ModuleList() 334 | last_channel = f2_channel+3 335 | for out_channel in mlp: 336 | self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False), 337 | nn.BatchNorm2d(out_channel), 338 | nn.ReLU(inplace=False))) 339 | last_channel = out_channel
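        # Channel bookkeeping: mlp runs on the grouped coarse features plus the 3-d
        # offsets (f2_channel+3); after max pooling, the skip concat with the finer
        # level's f1_channel features sets the input width of mlp2 below.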
340 | if len(mlp) != 0: 341 | last_channel = mlp[-1] + f1_channel 342 | else: 343 | last_channel = last_channel + f1_channel 344 | for out_channel in mlp2: 345 | self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False), 346 | nn.BatchNorm1d(out_channel), 347 | nn.ReLU(inplace=False))) 348 | last_channel = out_channel 349 | 350 | def forward(self, pos1, pos2, feature1, feature2): 351 | """ 352 | Feature propagation from xyz2 (fewer points) to xyz1 (more points) 353 | 354 | Inputs: 355 | xyz1: (batch_size, 3, npoint1) 356 | xyz2: (batch_size, 3, npoint2) 357 | feat1: (batch_size, channel1, npoint1) features for xyz1 points (earlier layers, more points) 358 | feat2: (batch_size, channel2, npoint2) features for xyz2 points 359 | Output: 360 | feat1_new: (batch_size, mlp2[-1], npoint1) 361 | 362 | TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating. 363 | """ 364 | pos1_t = pos1.permute(0, 2, 1).contiguous() 365 | pos2_t = pos2.permute(0, 2, 1).contiguous() 366 | B,C,N = pos1.shape 367 | if self.knn: 368 | _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t) 369 | else: 370 | idx, _ = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t) 371 | 372 | pos2_grouped = pointutils.grouping_operation(pos2, idx) 373 | pos_diff = pos2_grouped - pos1.view(B, -1, N, 1) # [B,3,N1,S] 374 | 375 | feat2_grouped = pointutils.grouping_operation(feature2, idx) 376 | feat_new = torch.cat([feat2_grouped, pos_diff], dim = 1) # [B,C1+3,N1,S] 377 | for conv in self.mlp1_convs: 378 | feat_new = conv(feat_new) 379 | # max pooling 380 | feat_new = feat_new.max(-1)[0] # [B,mlp1[-1],N1] 381 | # concatenate feature in early layer 382 | if feature1 is not None: 383 | feat_new = torch.cat([feat_new, feature1], dim=1) 384 | # feat_new = feat_new.view(B,-1,N,1) 385 | for conv in self.mlp2_convs: 386 | feat_new = conv(feat_new) 387 | 388 | return feat_new 389 | 390 | class PointNetFeaturePropogation(nn.Module): 391 | def __init__(self, in_channel, mlp): 392 | super(PointNetFeaturePropogation, self).__init__() 393 | self.mlp_convs = nn.ModuleList() 394 | self.mlp_bns = nn.ModuleList() 395 | last_channel = in_channel 396 | for out_channel in mlp: 397 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 398 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 399 | last_channel = out_channel 400 | 401 | def forward(self, pos1, pos2, feature1, feature2): 402 | """ 403 | Input: 404 | xyz1: input points position data, [B, C, N] 405 | xyz2: sampled input points position data, [B, C, S] 406 | points1: input points data, [B, D, N] 407 | points2: input points data, [B, D, S] 408 | Return: 409 | new_points: upsampled points data, [B, D', N] 410 | """ 411 | pos1_t = pos1.permute(0, 2, 1).contiguous() 412 | pos2_t = pos2.permute(0, 2, 1).contiguous() 413 | B, C, N = pos1.shape 414 | 415 | # dists = square_distance(pos1, pos2) 416 | # dists, idx = dists.sort(dim=-1) 417 | # dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 418 | dists,idx = pointutils.three_nn(pos1_t,pos2_t) 419 | dists[dists < 1e-10] = 1e-10 420 | weight = 1.0 / dists 421 | weight = weight / torch.sum(weight, -1,keepdim = True) # [B,N,3] 422 | interpolated_feat = torch.sum(pointutils.grouping_operation(feature2, idx) * weight.view(B, 1, N, 3), dim = -1) # [B,C,N] 423 | 424 | if feature1 is not None: 425 | feat_new = torch.cat([interpolated_feat, feature1], 1) 426 | else: 427 | feat_new = interpolated_feat 428 | 429 | for i, conv in
enumerate(self.mlp_convs): 430 | bn = self.mlp_bns[i] 431 | feat_new = F.relu(bn(conv(feat_new))) 432 | return feat_new --------------------------------------------------------------------------------