├── README.md
├── model_modelnet.py
├── model_scannet.py
├── pointconv
│   ├── cpp_modules.py
│   ├── layers.py
│   └── utils.py
├── tf_ops
│   ├── 3d_interpolation
│   │   ├── interpolate.cpp
│   │   └── tf_interpolate.cpp
│   ├── compile_ops.sh
│   ├── grouping
│   │   ├── tf_grouping.cpp
│   │   ├── tf_grouping_g.cu
│   │   └── tf_grouping_g.cu.o
│   └── sampling
│       ├── tf_sampling.cpp
│       ├── tf_sampling_g.cu
│       └── tf_sampling_g.cu.o
├── train_modelnet.py
└── train_scannet.py
/README.md:
--------------------------------------------------------------------------------
# PointConv tensorflow 2.0 layers

This repository contains implementations of the PointConv (Wu et al., 2019) feature encoder and feature decoder layers as `tf.keras.layers` classes. This allows PointConv layers to be used as part of the standard `tf.keras` API. The repository does not aim to be an exact re-implementation of the original repository, but rather a useful tool for building custom models or simple backbone encoders for unordered point sets. For the technical details, check out the [original paper](https://arxiv.org/abs/1811.07246) and [github page](https://github.com/DylanWusee/pointconv). The implementation also matches the style of the [PointNet++ keras layers](https://github.com/dgriffiths3/pointnet2-tensorflow2).

> Note: Both the feature encoding layer (`PointConvSA`) and the feature decoding layer (`PointConvFP`) are implemented, and a ScanNet per-point segmentation model example is provided in `model_scannet.py`.

## Setup

Requirements:

```
python >= 3.6
tensorflow >= 2.2
cuda == 10.1
```
> Note: This repository uses the `train_step` model override, which was introduced in `tensorflow 2.2.0`, so if you wish to use the provided training scripts make sure your tensorflow is not an older version. The layers themselves work with tensorflow 2.0+.


To compile the C++ tensorflow ops, first ensure the `CUDA_ROOT` path in `tf_ops/compile_ops.sh` points to your cuda installation, and then compile the ops with:

```
chmod u+x tf_ops/compile_ops.sh
tf_ops/compile_ops.sh
```

## Usage

The layers follow the standard `tf.keras.layers` API. To import them in your own project, copy the `pointconv` and `tf_ops` folders and set a relative path to find the layers. Here is an example of how a simple PointConv set-abstraction model can be built by subclassing `tf.keras.Model`.

```
import tensorflow as tf
from tensorflow import keras
from pointconv.layers import PointConvSA

class MyModel(keras.Model):

    def __init__(self, batch_size, n_classes):
        super(MyModel, self).__init__()

        self.batch_size = batch_size

        self.layer1 = PointConvSA(npoint=512, radius=0.1, sigma=0.1, K=32, mlp=[64, 64, 128], bn=True)
        self.layer2 = PointConvSA(npoint=128, radius=0.2, sigma=0.2, K=32, mlp=[128, 128, 256], bn=True)
        self.layer3 = PointConvSA(npoint=1, radius=0.8, sigma=0.4, K=32, mlp=[256, 512, 1024], group_all=True, bn=True)

        # To make a classifier, just add some fully-connected layers

        self.dense1 = keras.layers.Dense(512)
        self.dense2 = keras.layers.Dense(256)
        self.dense3 = keras.layers.Dense(n_classes, activation=tf.nn.softmax)

    def call(self, input, training=False):

        xyz, points = self.layer1(input, None, training=training)
        xyz, points = self.layer2(xyz, points, training=training)
        xyz, points = self.layer3(xyz, points, training=training)

        net = tf.reshape(points, (self.batch_size, -1))

        net = self.dense1(net)
        net = self.dense2(net)
        pred = self.dense3(net)

        return pred
```

A full working example of an implemented model for classification and point-wise semantic segmentation can be found in `model_modelnet.py` and `model_scannet.py` respectively. To run, first download the training data from [here](https://drive.google.com/drive/folders/1v5B68RHgDI95KM4EhDrRJxLacJAHcoxz) and place it in a folder called `data`. Configure the `config` dictionary to point to where you have saved it. Once the `config` is set, start the training with:

```
python train_modelnet.py
```

or:

```
python train_scannet.py
```

If the config is left at the default you can view the training logs with:

```
tensorboard --logdir=logs --port=6006
```
and navigate to `localhost:6006` in a web browser.
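
Once the ops are compiled, a quick smoke test (not part of the repository's scripts, just a sketch using the `PointConvSA` signature shown above) is to push a random batch of points through a single layer. Note this needs a GPU, since the sampling and grouping kernels are CUDA-only:

```
import tensorflow as tf
from pointconv.layers import PointConvSA

# A random batch of 8 point clouds with 1024 points each.
xyz = tf.random.uniform((8, 1024, 3))

layer = PointConvSA(npoint=512, radius=0.1, sigma=0.1, K=32, mlp=[64, 64, 128], bn=True)

# With no features, the layer uses the coordinates themselves as features.
new_xyz, new_points = layer(xyz, None, training=False)

print(new_xyz.shape)     # (8, 512, 3)   - sub-sampled coordinates
print(new_points.shape)  # (8, 512, 128) - features, channels = mlp[-1]
```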
83 | 84 | ## Note 85 | 86 | If you use these layers in your project remember to cite the original authors: 87 | 88 | ``` 89 | @article{wu2018pointconv, 90 | title={PointConv: Deep Convolutional Networks on 3D Point Clouds}, 91 | author={Wu, Wenxuan and Qi, Zhongang and Fuxin, Li}, 92 | journal={arXiv preprint arXiv:1811.07246}, 93 | year={2018} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /model_modelnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | 4 | from pointconv.layers import PointConvSA 5 | 6 | 7 | class PointConvModel(keras.Model): 8 | 9 | def __init__(self, batch_size, bn=False, num_classes=40): 10 | super(PointConvModel, self).__init__() 11 | 12 | self.batch_size = batch_size 13 | self.num_classes = num_classes 14 | self.activation = tf.nn.relu 15 | self.kernel_initializer = 'glorot_normal' 16 | self.sigma = 0.1 17 | self.K = 32 18 | self.bn = bn 19 | 20 | self.init_network() 21 | 22 | 23 | def init_network(self): 24 | 25 | self.layer1 = PointConvSA( 26 | npoint = 512, 27 | radius = 0.1, 28 | sigma = self.sigma, 29 | K = self.K, 30 | mlp = [64, 64, 128], 31 | bn = self.bn 32 | ) 33 | 34 | self.layer2 = PointConvSA( 35 | npoint = 128, 36 | radius = 0.2, 37 | sigma = 2*self.sigma, 38 | K = self.K, 39 | mlp = [128, 128, 256], 40 | bn = self.bn 41 | ) 42 | 43 | self.layer3 = PointConvSA( 44 | npoint = 1, 45 | radius = 0.8, 46 | sigma = 4*self.sigma, 47 | K = self.K, 48 | mlp = [256, 512, 1024], 49 | group_all=True, 50 | bn = self.bn 51 | ) 52 | 53 | self.dense1 = keras.layers.Dense(512, activation=self.activation) 54 | self.dropout1 = keras.layers.Dropout(0.4) 55 | 56 | self.dense2 = keras.layers.Dense(256, activation=self.activation) 57 | self.dropout2 = keras.layers.Dropout(0.4) 58 | 59 | self.dense3 = keras.layers.Dense(self.num_classes, activation=tf.nn.softmax) 60 | 61 | 62 | def forward_pass(self, input, training): 63 | 64 | xyz, points = self.layer1(input, None, training=training) 65 | xyz, points = self.layer2(xyz, points, training=training) 66 | xyz, points = self.layer3(xyz, points, training=training) 67 | 68 | net = tf.reshape(points, (self.batch_size, -1)) 69 | 70 | net = self.dense1(net) 71 | net = self.dropout1(net) 72 | 73 | net = self.dense2(net) 74 | net = self.dropout2(net) 75 | 76 | pred = self.dense3(net) 77 | 78 | return pred 79 | 80 | 81 | def train_step(self, input): 82 | 83 | with tf.GradientTape() as tape: 84 | 85 | pred = self.forward_pass(input[0], True) 86 | loss = self.compiled_loss(input[1], pred) 87 | 88 | gradients = tape.gradient(loss, self.trainable_variables) 89 | self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 90 | 91 | self.compiled_metrics.update_state(input[1], pred) 92 | 93 | return {m.name: m.result() for m in self.metrics} 94 | 95 | 96 | def test_step(self, input): 97 | 98 | pred = self.forward_pass(input[0], False) 99 | loss = self.compiled_loss(input[1], pred) 100 | 101 | self.compiled_metrics.update_state(input[1], pred) 102 | 103 | return {m.name: m.result() for m in self.metrics} 104 | 105 | 106 | def call(self, input, training=False): 107 | 108 | return self.forward_pass(input, training) 109 | -------------------------------------------------------------------------------- /model_scannet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | 4 | from 
pointconv.layers import PointConvSA, PointConvFP


class PointConvModel(keras.Model):

    def __init__(self, batch_size, bn=False, num_classes=21):
        super(PointConvModel, self).__init__()

        self.batch_size = batch_size
        self.num_classes = num_classes
        self.activation = tf.nn.relu
        self.kernel_initializer = 'glorot_normal'
        self.sigma = 0.1
        self.K = 32
        self.bn = bn

        self.init_network()

    def init_network(self):

        out_ch = 512

        self.sa_layer1 = PointConvSA(
            npoint=1024, radius=0.1, sigma=self.sigma, K=self.K, mlp=[32, 32, 64], bn=self.bn)
        self.sa_layer2 = PointConvSA(
            npoint=256, radius=0.2, sigma=2*self.sigma, K=self.K, mlp=[64, 64, 128], bn=self.bn)
        self.sa_layer3 = PointConvSA(
            npoint=64, radius=0.4, sigma=4*self.sigma, K=self.K, mlp=[128, 128, 256], bn=self.bn)
        self.sa_layer4 = PointConvSA(
            npoint=36, radius=0.8, sigma=8*self.sigma, K=self.K, mlp=[256, 256, 512], bn=self.bn)

        self.fp_layer1 = PointConvFP(
            radius=0.8, sigma=8*self.sigma, K=16, mlp=[out_ch, 512], out_ch=out_ch, bn=self.bn)
        self.fp_layer2 = PointConvFP(
            radius=0.4, sigma=4*self.sigma, K=16, mlp=[256, 256], out_ch=out_ch, bn=self.bn)
        self.fp_layer3 = PointConvFP(
            radius=0.2, sigma=2*self.sigma, K=16, mlp=[256, 128], out_ch=out_ch, bn=self.bn)
        self.fp_layer4 = PointConvFP(
            radius=0.1, sigma=self.sigma, K=16, mlp=[128, 128, 128], out_ch=out_ch, bn=self.bn)

        self.dense1 = keras.layers.Dense(128, activation=self.activation)
        self.dropout1 = keras.layers.Dropout(0.4)
        self.dense2 = keras.layers.Dense(self.num_classes, activation=tf.nn.softmax)

    def forward_pass(self, input, training):

        l0_xyz = input
        l0_points = None

        l1_xyz, l1_points = self.sa_layer1(l0_xyz, l0_points, training=training)
        l2_xyz, l2_points = self.sa_layer2(l1_xyz, l1_points, training=training)
        l3_xyz, l3_points = self.sa_layer3(l2_xyz, l2_points, training=training)
        l4_xyz, l4_points = self.sa_layer4(l3_xyz, l3_points, training=training)

        l3_points = self.fp_layer1(l3_xyz, l4_xyz, l3_points, l4_points, training=training)
        l2_points = self.fp_layer2(l2_xyz, l3_xyz, l2_points, l3_points, training=training)
        l1_points = self.fp_layer3(l1_xyz, l2_xyz, l1_points, l2_points, training=training)
        points = self.fp_layer4(l0_xyz, l1_xyz, l0_points, l1_points, training=training)

        net = self.dense1(points)
        net = self.dropout1(net)

        pred = self.dense2(net)

        return pred

    def train_step(self, input):

        with tf.GradientTape() as tape:

            pred = self.forward_pass(input[0], True)
            loss = self.compiled_loss(input[1], pred)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(input[1], pred)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, input):

        pred = self.forward_pass(input[0], False)
        loss = self.compiled_loss(input[1], pred)

        self.compiled_metrics.update_state(input[1], pred)

        return {m.name: m.result() for m in self.metrics}

    def call(self, input, training=False):

        return self.forward_pass(input, training)
--------------------------------------------------------------------------------
/pointconv/cpp_modules.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

from tensorflow.keras.layers import MaxPool1D, Layer

sampling_module = tf.load_op_library('./tf_ops/sampling/tf_sampling_so.so')
grouping_module = tf.load_op_library('./tf_ops/grouping/tf_grouping_so.so')
interpolate_module = tf.load_op_library('./tf_ops/3d_interpolation/tf_interpolate_so.so')


def prob_sample(inp, inpr):
    return sampling_module.prob_sample(inp, inpr)


ops.NoGradient('ProbSample')


def gather_point(inp, idx):
    return sampling_module.gather_point(inp, idx)


@tf.RegisterGradient('GatherPoint')
def _gather_point_grad(op, out_g):
    inp = op.inputs[0]
    idx = op.inputs[1]
    return [sampling_module.gather_point_grad(inp, idx, out_g), None]


def farthest_point_sample(npoint, inp):
    return sampling_module.farthest_point_sample(inp, npoint)


ops.NoGradient('FarthestPointSample')


def query_ball_point(radius, nsample, xyz1, xyz2):
    return grouping_module.query_ball_point(xyz1, xyz2, radius, nsample)


ops.NoGradient('QueryBallPoint')


def select_top_k(k, dist):
    return grouping_module.selection_sort(dist, k)


ops.NoGradient('SelectionSort')


def group_point(points, idx):
    return grouping_module.group_point(points, idx)


@tf.RegisterGradient('GroupPoint')
def _group_point_grad(op, grad_out):
    points = op.inputs[0]
    idx = op.inputs[1]
    return [grouping_module.group_point_grad(points, idx, grad_out), None]


def knn_point(k, xyz1, xyz2):
    # Brute-force k-nearest-neighbour search via the selection-sort op.
    # Static shapes are read directly from the tensors (in TF2 the shape
    # dimensions are plain ints, so no `.value` access is needed).
    b = xyz1.shape[0]
    n = xyz1.shape[1]
    c = xyz1.shape[2]
    m = xyz2.shape[1]
    xyz1 = tf.tile(tf.reshape(xyz1, (b, 1, n, c)), [1, m, 1, 1])
    xyz2 = tf.tile(tf.reshape(xyz2, (b, m, 1, c)), [1, 1, n, 1])
    dist = tf.reduce_sum((xyz1-xyz2)**2, -1)
    outi, out = select_top_k(k, dist)
    idx = tf.slice(outi, [0, 0, 0], [-1, -1, k])
    val = tf.slice(out, [0, 0, 0], [-1, -1, k])
    #val, idx = tf.nn.top_k(-dist, k=k) # ONLY SUPPORT CPU
    return val, idx


def three_nn(xyz1, xyz2):
    return interpolate_module.three_nn(xyz1, xyz2)


ops.NoGradient('ThreeNN')


def three_interpolate(points, idx, weight):
    return interpolate_module.three_interpolate(points, idx, weight)


@tf.RegisterGradient('ThreeInterpolate')
def _three_interpolate_grad(op, grad_out):
    points = op.inputs[0]
    idx = op.inputs[1]
    weight = op.inputs[2]
    return [interpolate_module.three_interpolate_grad(points, idx, weight, grad_out), None, None]
--------------------------------------------------------------------------------
/pointconv/layers.py:
--------------------------------------------------------------------------------
import sys
sys.path.insert(0, './')

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Layer, BatchNormalization

from pointconv import utils, cpp_modules

class PointConvSA(keras.layers.Layer):

    def __init__(self, npoint, radius, sigma, K, mlp,
group_all=False, activation=tf.nn.relu, bn=False): 13 | 14 | super(PointConvSA, self).__init__() 15 | 16 | self.npoint = npoint 17 | self.radius = radius 18 | self.sigma = sigma 19 | self.K = K 20 | self.mlp = mlp 21 | self.group_all = group_all 22 | self.activation = activation 23 | self.bn = bn 24 | 25 | self.mlp_list = [] 26 | self.weightnet_hidden = [] 27 | self.nonlinear_transform = [] 28 | 29 | 30 | def build(self, input_shape): 31 | 32 | for i, n_filters in enumerate(self.mlp): 33 | self.mlp_list.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 34 | 35 | for i, n_filters in enumerate([32]): 36 | self.weightnet_hidden.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 37 | 38 | for i, n_filters in enumerate([16, 1]): 39 | self.nonlinear_transform.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 40 | 41 | self.np_conv = utils.Conv2d(self.mlp[-1], strides=[1, self.mlp[-1]], activation=self.activation, bn=self.bn) 42 | 43 | super(PointConvSA, self).build(input_shape) 44 | 45 | 46 | def call(self, xyz, feature, training=True): 47 | 48 | num_points = xyz.get_shape()[1] 49 | 50 | if feature is None: 51 | feature = tf.identity(xyz) 52 | 53 | if num_points == self.npoint: 54 | new_xyz = xyz 55 | else: 56 | new_xyz = utils.sampling(self.npoint, xyz) 57 | 58 | if self.group_all == True: 59 | grouped_xyz, grouped_feature, idx = utils.grouping_all(feature, xyz) 60 | else: 61 | grouped_xyz, grouped_feature, idx = utils.grouping(feature, self.K, xyz, new_xyz) 62 | 63 | density = utils.kernel_density_estimation_ball(xyz, self.radius, self.sigma) 64 | inverse_density = tf.math.divide(1.0, density) 65 | grouped_density = tf.gather_nd(inverse_density, idx) # (batch_size, npoint, nsample, 1) 66 | inverse_max_density = tf.reduce_max(grouped_density, axis = 2, keepdims = True) 67 | density_scale = tf.math.divide(grouped_density, inverse_max_density) 68 | 69 | for i, mlp_layer in enumerate(self.mlp_list): 70 | grouped_feature = mlp_layer(grouped_feature, training=training) 71 | 72 | for i, mlp_layer in enumerate(self.weightnet_hidden): 73 | weight = mlp_layer(grouped_xyz, training=training) 74 | 75 | for i, mlp_layer in enumerate(self.nonlinear_transform): 76 | density_scale = mlp_layer(density_scale, training=training) 77 | 78 | new_points = tf.math.multiply(grouped_feature, density_scale) 79 | new_points = tf.transpose(new_points, [0, 1, 3, 2]) 80 | new_points = tf.linalg.matmul(new_points, weight) 81 | 82 | new_points = self.np_conv(new_points, training=training) 83 | new_points = tf.squeeze(new_points, [2]) 84 | 85 | return new_xyz, new_points 86 | 87 | 88 | class PointConvFP(keras.layers.Layer): 89 | 90 | def __init__(self, radius, sigma, K, mlp, out_ch=512, activation=tf.nn.relu, bn=False): 91 | super(PointConvFP, self).__init__() 92 | 93 | self.radius = radius 94 | self.sigma = sigma 95 | self.K = K 96 | self.mlp = mlp 97 | self.activation = activation 98 | self.bn = bn 99 | self.out_ch = out_ch 100 | 101 | self.mlp_list = [] 102 | self.weightnet_hidden = [] 103 | self.nonlinear_transform = [] 104 | 105 | def build(self, input_shape): 106 | 107 | for i, n_filters in enumerate(self.mlp): 108 | self.mlp_list.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 109 | 110 | for i, n_filters in enumerate([32]): 111 | self.weightnet_hidden.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 112 | 113 | for i, n_filters in enumerate([16, 1]): 114 | 
self.nonlinear_transform.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn))

        self.np_conv = utils.Conv2d(self.mlp[0], strides=[1, self.out_ch+3], activation=self.activation, bn=self.bn)

        super(PointConvFP, self).build(input_shape)

    def call(self, xyz1, xyz2, points1, points2, training=True):

        dist, idx = cpp_modules.three_nn(xyz1, xyz2)
        # Clamp distances away from zero to avoid division by zero below.
        dist = tf.math.maximum(dist, 1e-10)
        norm = tf.math.reduce_sum((1.0/dist), axis=2, keepdims=True)
        norm = tf.tile(norm, [1, 1, 3])
        weight = (1.0/dist) / norm
        interpolated_points = cpp_modules.three_interpolate(points2, idx, weight)

        grouped_xyz, grouped_feature, idx = utils.grouping(interpolated_points, self.K, xyz1, xyz1)

        density = utils.kernel_density_estimation_ball(xyz1, self.radius, self.sigma)
        inverse_density = tf.math.divide(1.0, density)
        grouped_density = tf.gather_nd(inverse_density, idx)
        inverse_max_density = tf.reduce_max(grouped_density, axis=2, keepdims=True)
        density_scale = tf.math.divide(grouped_density, inverse_max_density)

        for i, mlp_layer in enumerate(self.weightnet_hidden):
            weight = mlp_layer(grouped_xyz, training=training)

        for i, mlp_layer in enumerate(self.nonlinear_transform):
            density_scale = mlp_layer(density_scale, training=training)

        new_points = tf.math.multiply(grouped_feature, density_scale)
        new_points = tf.transpose(new_points, [0, 1, 3, 2])
        new_points = tf.linalg.matmul(new_points, weight)

        new_points = self.np_conv(new_points, training=training)

        if points1 is not None:
            new_points1 = tf.concat([new_points, tf.expand_dims(points1, axis=2)], -1)
        else:
            new_points1 = new_points

        # mlp[0] is consumed by np_conv above, so the pointwise MLP starts
        # from index 1.
        for i, mlp_layer in enumerate(self.mlp_list):
            if i != 0:
                new_points1 = mlp_layer(new_points1, training=training)

        new_points1 = tf.squeeze(new_points1, [2])

        return new_points1
--------------------------------------------------------------------------------
/pointconv/utils.py:
--------------------------------------------------------------------------------
"""
Helper functions for PointConv
Author: Wenxuan Wu
Date: July 2018
"""

import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.neighbors import KDTree

from .cpp_modules import (
    farthest_point_sample,
    gather_point,
    query_ball_point,
    group_point,
    three_nn
)

def knn_kdtree(nsample, xyz, new_xyz):
    batch_size = xyz.shape[0]
    n_points = new_xyz.shape[1]

    indices = np.zeros((batch_size, n_points, nsample), dtype=np.int32)
    for batch_idx in range(batch_size):
        X = xyz.numpy()[batch_idx, ...]
        q_X = new_xyz.numpy()[batch_idx, ...]
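        # scikit-learn's KDTree performs the exact k-NN search on the CPU;
        # query() returns (distances, indices) and only the indices are kept.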
28 | kdt = KDTree(X, leaf_size=30) 29 | _, indices[batch_idx] = kdt.query(q_X, k=nsample) 30 | 31 | return indices 32 | 33 | 34 | def kernel_density_estimation_ball(pts, radius, sigma, N_points=128, is_norm=False): 35 | 36 | idx, pts_cnt = query_ball_point(radius, N_points, pts, pts) 37 | g_pts = group_point(pts, idx) 38 | g_pts -= tf.tile(tf.expand_dims(pts, 2), [1, 1, N_points, 1]) 39 | 40 | R = tf.sqrt(sigma) 41 | xRinv = tf.math.divide(g_pts, R) 42 | quadform = tf.reduce_sum(tf.square(xRinv), axis=-1) 43 | logsqrtdetSigma = tf.math.log(R) * 3 44 | mvnpdf = tf.exp(-0.5 * quadform - logsqrtdetSigma - 45 | 3 * tf.math.log(2 * 3.1415926) / 2) 46 | 47 | first_val, _ = tf.split(mvnpdf, [1, N_points - 1], axis=2) 48 | 49 | mvnpdf = tf.reduce_sum(mvnpdf, axis=2, keepdims=True) 50 | 51 | num_val_to_sub = tf.expand_dims( 52 | tf.cast(tf.subtract(N_points, pts_cnt), dtype=tf.float32), axis=-1) 53 | 54 | val_to_sub = tf.multiply(first_val, num_val_to_sub) 55 | 56 | mvnpdf = tf.subtract(mvnpdf, val_to_sub) 57 | 58 | scale = tf.math.divide(1.0, tf.expand_dims( 59 | tf.cast(pts_cnt, dtype=tf.float32), axis=-1)) 60 | density = tf.multiply(mvnpdf, scale) 61 | 62 | if is_norm: 63 | density_max = tf.reduce_max(density, axis=1, keepdims=True) 64 | density = tf.math.divide(density, density_max) 65 | 66 | return density 67 | 68 | 69 | def kernel_density_estimation(pts, sigma, kpoint=32, is_norm=False): 70 | with tf.variable_scope("ComputeDensity") as sc: 71 | batch_size = pts.get_shape()[0] 72 | num_points = pts.get_shape()[1] 73 | if num_points < kpoint: 74 | kpoint = num_points.value - 1 75 | with tf.device('/cpu:0'): 76 | point_indices = tf.py_function( 77 | knn_kdtree, [kpoint, pts, pts], tf.int32) 78 | batch_indices = tf.tile(tf.reshape( 79 | tf.range(batch_size), (-1, 1, 1, 1)), (1, num_points, kpoint, 1)) 80 | idx = tf.concat( 81 | [batch_indices, tf.expand_dims(point_indices, axis=3)], axis=3) 82 | idx.set_shape([batch_size, num_points, kpoint, 2]) 83 | 84 | grouped_pts = tf.gather_nd(pts, idx) 85 | # translation normalization 86 | grouped_pts -= tf.tile(tf.expand_dims(pts, 2), [1, 1, kpoint, 1]) 87 | 88 | R = tf.sqrt(sigma) 89 | xRinv = tf.div(grouped_pts, R) 90 | quadform = tf.reduce_sum(tf.square(xRinv), axis=-1) 91 | logsqrtdetSigma = tf.log(R) * 3 92 | mvnpdf = tf.exp(-0.5 * quadform - logsqrtdetSigma - 93 | 3 * tf.log(2 * 3.1415926) / 2) 94 | mvnpdf = tf.reduce_sum(mvnpdf, axis=2, keepdims=True) 95 | 96 | scale = 1.0 / kpoint 97 | density = tf.multiply(mvnpdf, scale) 98 | 99 | if is_norm: 100 | density_max = tf.reduce_max(density, axis=1, keepdims=True) 101 | density = tf.div(density, density_max) 102 | 103 | return density 104 | 105 | 106 | def sampling(npoint, pts): 107 | ''' 108 | inputs: 109 | npoint: scalar, number of points to sample 110 | pointcloud: B * N * 3, input point cloud 111 | output: 112 | sub_pts: B * npoint * 3, sub-sampled point cloud 113 | ''' 114 | 115 | sub_pts = gather_point( 116 | pts, farthest_point_sample(npoint, pts)) 117 | return sub_pts 118 | 119 | 120 | def grouping(feature, K, src_xyz, q_xyz, use_xyz=True): 121 | ''' 122 | K: neighbor size 123 | src_xyz: original point xyz (batch_size, ndataset, 3) 124 | q_xyz: query point xyz (batch_size, npoint, 3) 125 | ''' 126 | 127 | batch_size = src_xyz.get_shape()[0] 128 | npoint = q_xyz.get_shape()[1] 129 | 130 | point_indices = tf.py_function(knn_kdtree, [K, src_xyz, q_xyz], tf.int32) 131 | batch_indices = tf.tile(tf.reshape( 132 | tf.range(batch_size), (-1, 1, 1, 1)), (1, npoint, K, 1)) 133 | idx = tf.concat( 134 | 
[batch_indices, tf.expand_dims(point_indices, axis=3)], axis=3) 135 | idx.set_shape([batch_size, npoint, K, 2]) 136 | 137 | grouped_xyz = tf.gather_nd(src_xyz, idx) 138 | # translation normalization 139 | grouped_xyz -= tf.tile(tf.expand_dims(q_xyz, 2), [1, 1, K, 1]) 140 | 141 | grouped_feature = tf.gather_nd(feature, idx) 142 | if use_xyz: 143 | new_points = tf.concat([grouped_xyz, grouped_feature], axis=-1) 144 | else: 145 | new_points = grouped_feature 146 | 147 | return grouped_xyz, new_points, idx 148 | 149 | 150 | def grouping_all(feature, src_xyz, use_xyz=True): 151 | 152 | batch_size = src_xyz.get_shape()[0] 153 | npoint = src_xyz.get_shape()[1] 154 | 155 | new_xyz = tf.reduce_mean(src_xyz, axis=1, keepdims=True) 156 | new_xyz = tf.reshape(src_xyz, (batch_size, 1, src_xyz.shape[1], 3)) - tf.reshape(new_xyz, (batch_size, 1, 1, 3)) 157 | 158 | idx = tf.constant(np.tile(np.array(range(npoint)).reshape((1, 1, npoint, 1)), (batch_size, 1, 1, 1)), tf.int32) 159 | idx = tf.concat([tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1, 1)), [1, 1, new_xyz.shape[2], 1]), idx], -1) 160 | 161 | grouped_xyz = tf.reshape(src_xyz, (batch_size, 1, npoint, 3)) 162 | 163 | if feature is not None: 164 | if use_xyz: 165 | new_points = tf.concat([src_xyz, feature], axis=2) 166 | else: 167 | new_points = feature 168 | new_points = tf.expand_dims(new_points, 1) 169 | else: 170 | new_points = grouped_xyz 171 | 172 | return grouped_xyz, new_points, idx 173 | 174 | 175 | class Conv2d(keras.layers.Layer): 176 | 177 | def __init__(self, filters, strides=[1, 1], activation=tf.nn.relu, padding='VALID', initializer='glorot_normal', bn=False): 178 | super(Conv2d, self).__init__() 179 | 180 | self.filters = filters 181 | self.strides = strides 182 | self.activation = activation 183 | self.padding = padding 184 | self.initializer = initializer 185 | self.bn = bn 186 | 187 | def build(self, input_shape): 188 | 189 | self.w = self.add_weight( 190 | shape=(1, 1, input_shape[-1], self.filters), 191 | initializer=self.initializer, 192 | trainable=True, 193 | name='pnet_conv' 194 | ) 195 | 196 | if self.bn: 197 | self.bn_layer = keras.layers.BatchNormalization() 198 | 199 | super(Conv2d, self).build(input_shape) 200 | 201 | def call(self, inputs, training=True): 202 | 203 | points = tf.nn.conv2d(inputs, filters=self.w, strides=self.strides, padding=self.padding) 204 | if self.bn:points = self.bn_layer(points, training=training) 205 | if self.activation:points = self.activation(points) 206 | 207 | return points 208 | -------------------------------------------------------------------------------- /tf_ops/3d_interpolation/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // Find three nearest neigbors with square distance 19 | // input: xyz1 (b,n,3), xyz2(b,m,3) 20 | // output: dist (b,n,3), idx (b,n,3) 21 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) { 22 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | 
#include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | using namespace tensorflow; 11 | 12 | REGISTER_OP("ThreeNN") 13 | .Input("xyz1: float32") 14 | .Input("xyz2: float32") 15 | .Output("dist: float32") 16 | .Output("idx: int32") 17 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 18 | c->set_output(0, c->input(0)); 19 | c->set_output(1, c->input(0)); 20 | return Status::OK(); 21 | }); 22 | REGISTER_OP("ThreeInterpolate") 23 | .Input("points: float32") 24 | .Input("idx: int32") 25 | .Input("weight: float32") 26 | .Output("out: float32") 27 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 28 | ::tensorflow::shape_inference::ShapeHandle dims1; // (b,m,c) 29 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 30 | ::tensorflow::shape_inference::ShapeHandle dims2; // (b,n,3) 31 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 32 | // (b,n,c) 33 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)}); 34 | c->set_output(0, output); 35 | return Status::OK(); 36 | }); 37 | REGISTER_OP("ThreeInterpolateGrad") 38 | .Input("points: float32") 39 | .Input("idx: int32") 40 | .Input("weight: float32") 41 | .Input("grad_out: float32") 42 | .Output("grad_points: float32") 43 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 44 | c->set_output(0, c->input(0)); 45 | return Status::OK(); 46 | }); 47 | 48 | float randomf(){ 49 | return (rand()+0.5)/(RAND_MAX+1.0); 50 | } 51 | static double get_time(){ 52 | timespec tp; 53 | clock_gettime(CLOCK_MONOTONIC,&tp); 54 | return tp.tv_sec+tp.tv_nsec*1e-9; 55 | } 56 | 57 | // Find three nearest neigbors with square distance 58 | // input: xyz1 (b,n,3), xyz2(b,m,3) 59 | // output: dist (b,n,3), idx (b,n,3) 60 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) { 61 | for (int i=0;iinput(0); 163 | OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,n,3) xyz1 shape.")); 164 | int b = xyz1_tensor.shape().dim_size(0); 165 | int n = xyz1_tensor.shape().dim_size(1); 166 | 167 | const Tensor& xyz2_tensor = context->input(1); 168 | OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,m,3) xyz2 shape.")); 169 | int m = xyz2_tensor.shape().dim_size(1); 170 | 171 | Tensor *dist_tensor = nullptr; 172 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,n,3}, &dist_tensor)); 173 | Tensor *idx_tensor = nullptr; 174 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,n,3}, &idx_tensor)); 175 | 176 | auto xyz1_flat = xyz1_tensor.flat(); 177 | const float *xyz1 = &(xyz1_flat(0)); 178 | auto xyz2_flat = xyz2_tensor.flat(); 179 | const float *xyz2 = &(xyz2_flat(0)); 180 | auto dist_flat = dist_tensor->flat(); 181 | float *dist = &(dist_flat(0)); 182 | auto idx_flat = idx_tensor->flat(); 183 | int *idx = &(idx_flat(0)); 184 | threenn_cpu(b,n,m,xyz1,xyz2,dist,idx); 185 | } 186 | }; 187 | REGISTER_KERNEL_BUILDER(Name("ThreeNN").Device(DEVICE_CPU), ThreeNNOp); 188 | 189 | 190 | 191 | class ThreeInterpolateOp: public OpKernel{ 192 | public: 193 | explicit ThreeInterpolateOp(OpKernelConstruction * context):OpKernel(context){} 194 | 195 | void Compute(OpKernelContext * context) override { 196 | const Tensor& points_tensor=context->input(0); 197 | 
OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolate expects (b,m,c) points shape")); 198 | int b = points_tensor.shape().dim_size(0); 199 | int m = points_tensor.shape().dim_size(1); 200 | int c = points_tensor.shape().dim_size(2); 201 | 202 | const Tensor& idx_tensor=context->input(1); 203 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b && idx_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) idx shape")); 204 | int n = idx_tensor.shape().dim_size(1); 205 | const Tensor& weight_tensor=context->input(2); 206 | OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) weight shape")); 207 | 208 | Tensor * out_tensor = nullptr; 209 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,n,c}, &out_tensor)); 210 | 211 | auto points_flat = points_tensor.flat(); 212 | const float *points = &(points_flat(0)); 213 | auto idx_flat = idx_tensor.flat(); 214 | const int *idx = &(idx_flat(0)); 215 | auto weight_flat = weight_tensor.flat(); 216 | const float *weight = &(weight_flat(0)); 217 | auto out_flat = out_tensor->flat(); 218 | float *out = &(out_flat(0)); 219 | threeinterpolate_cpu(b,m,c,n,points,idx,weight,out); 220 | } 221 | }; 222 | REGISTER_KERNEL_BUILDER(Name("ThreeInterpolate").Device(DEVICE_CPU),ThreeInterpolateOp); 223 | 224 | 225 | class ThreeInterpolateGradOp: public OpKernel{ 226 | public: 227 | explicit ThreeInterpolateGradOp(OpKernelConstruction * context):OpKernel(context){} 228 | 229 | void Compute(OpKernelContext * context) override { 230 | const Tensor& points_tensor=context->input(0); 231 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,m,c) points shape")); 232 | int b = points_tensor.shape().dim_size(0); 233 | int m = points_tensor.shape().dim_size(1); 234 | int c = points_tensor.shape().dim_size(2); 235 | 236 | const Tensor& idx_tensor=context->input(1); 237 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) idx shape")); 238 | int n = idx_tensor.shape().dim_size(1); 239 | const Tensor& weight_tensor=context->input(2); 240 | OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) weight shape")); 241 | 242 | const Tensor& grad_out_tensor=context->input(3); 243 | OP_REQUIRES(context,grad_out_tensor.dims()==3 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==n && grad_out_tensor.shape().dim_size(2)==c, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,c) grad_out shape")); 244 | 245 | Tensor * grad_points_tensor = nullptr; 246 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,c}, &grad_points_tensor)); 247 | 248 | auto points_flat = points_tensor.flat(); 249 | const float *points = &(points_flat(0)); 250 | auto idx_flat = idx_tensor.flat(); 251 | const int *idx = &(idx_flat(0)); 252 | auto weight_flat = weight_tensor.flat(); 253 | const float *weight = &(weight_flat(0)); 254 | auto grad_out_flat = grad_out_tensor.flat(); 255 | const float *grad_out = &(grad_out_flat(0)); 256 | auto grad_points_flat = grad_points_tensor->flat(); 
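    // grad_points must start from zero: the CPU gradient kernel below
    // accumulates the three weighted gradient contributions per output point
    // into the buffer rather than overwriting them (hence the memset).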
float *grad_points = &(grad_points_flat(0));
        memset(grad_points, 0, sizeof(float)*b*m*c);
        threeinterpolate_grad_cpu(b,n,c,m,grad_out,idx,weight,grad_points);
    }
};
REGISTER_KERNEL_BUILDER(Name("ThreeInterpolateGrad").Device(DEVICE_CPU),ThreeInterpolateGradOp);
--------------------------------------------------------------------------------
/tf_ops/compile_ops.sh:
--------------------------------------------------------------------------------
#!/bin/bash

TF_CFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
CUDA_ROOT=/usr/local/cuda-10.1

cd tf_ops

echo "Compiling GPU Ops..."
g++ -std=c++11 -shared ./3d_interpolation/tf_interpolate.cpp -o ./3d_interpolation/tf_interpolate_so.so -I $CUDA_ROOT/include -lcudart -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2
echo "Interpolate op compiled."

$CUDA_ROOT/bin/nvcc ./grouping/tf_grouping_g.cu -o ./grouping/tf_grouping_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
g++ -std=c++11 -shared ./grouping/tf_grouping.cpp ./grouping/tf_grouping_g.cu.o -o ./grouping/tf_grouping_so.so -I $CUDA_ROOT/include -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2
echo "Grouping op compiled."

$CUDA_ROOT/bin/nvcc ./sampling/tf_sampling_g.cu -o ./sampling/tf_sampling_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
g++ -std=c++11 -shared ./sampling/tf_sampling.cpp ./sampling/tf_sampling_g.cu.o -o ./sampling/tf_sampling_so.so -I $CUDA_ROOT/include -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2
echo "Sampling op compiled."

echo "All ops compiled successfully."
22 | cd ../ 23 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | using namespace tensorflow; 12 | 13 | REGISTER_OP("QueryBallPoint") 14 | .Attr("radius: float") 15 | .Attr("nsample: int") 16 | .Input("xyz1: float32") 17 | .Input("xyz2: float32") 18 | .Output("idx: int32") 19 | .Output("pts_cnt: int32") 20 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoint * 3 22 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 23 | int nsample; 24 | TF_RETURN_IF_ERROR(c->GetAttr("nsample", &nsample)); 25 | ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), nsample}); 26 | c->set_output(0, output1); 27 | ::tensorflow::shape_inference::ShapeHandle output2 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 28 | c->set_output(1, output2); 29 | return Status::OK(); 30 | }); 31 | REGISTER_OP("SelectionSort") 32 | .Attr("k: int") 33 | .Input("dist: float32") 34 | .Output("outi: int32") 35 | .Output("out: float32") 36 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 37 | c->set_output(0, c->input(0)); 38 | c->set_output(1, c->input(0)); 39 | return Status::OK(); 40 | }); 41 | REGISTER_OP("GroupPoint") 42 | .Input("points: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * channels 47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints * nsample 49 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 50 | // batch_size * npoints * nsample * channels 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), c->Dim(dims2, 2), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GroupPointGrad") 56 | .Input("points: float32") 57 | .Input("idx: int32") 58 | .Input("grad_out: float32") 59 | .Output("grad_points: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | 66 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 67 | class QueryBallPointGpuOp : public OpKernel { 68 | public: 69 | explicit QueryBallPointGpuOp(OpKernelConstruction* context) : OpKernel(context) { 70 | OP_REQUIRES_OK(context, context->GetAttr("radius", &radius_)); 71 | OP_REQUIRES(context, radius_ > 0, errors::InvalidArgument("QueryBallPoint expects positive radius")); 72 | 73 | OP_REQUIRES_OK(context, context->GetAttr("nsample", &nsample_)); 74 | OP_REQUIRES(context, nsample_ > 0, errors::InvalidArgument("QueryBallPoint expects positive nsample")); 75 | } 76 | 77 | void Compute(OpKernelContext* context) override { 78 | const Tensor& 
xyz1_tensor = context->input(0); 79 | OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint expects (batch_size, ndataset, 3) xyz1 shape.")); 80 | int b = xyz1_tensor.shape().dim_size(0); 81 | int n = xyz1_tensor.shape().dim_size(1); 82 | 83 | const Tensor& xyz2_tensor = context->input(1); 84 | OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint expects (batch_size, npoint, 3) xyz2 shape.")); 85 | int m = xyz2_tensor.shape().dim_size(1); 86 | 87 | Tensor *idx_tensor = nullptr; 88 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,nsample_}, &idx_tensor)); 89 | Tensor *pts_cnt_tensor = nullptr; 90 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m}, &pts_cnt_tensor)); 91 | 92 | auto xyz1_flat = xyz1_tensor.flat(); 93 | const float *xyz1 = &(xyz1_flat(0)); 94 | auto xyz2_flat = xyz2_tensor.flat(); 95 | const float *xyz2 = &(xyz2_flat(0)); 96 | auto idx_flat = idx_tensor->flat(); 97 | int *idx = &(idx_flat(0)); 98 | auto pts_cnt_flat = pts_cnt_tensor->flat(); 99 | int *pts_cnt = &(pts_cnt_flat(0)); 100 | queryBallPointLauncher(b,n,m,radius_,nsample_,xyz1,xyz2,idx,pts_cnt); 101 | } 102 | private: 103 | float radius_; 104 | int nsample_; 105 | }; 106 | REGISTER_KERNEL_BUILDER(Name("QueryBallPoint").Device(DEVICE_GPU), QueryBallPointGpuOp); 107 | 108 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out); 109 | class SelectionSortGpuOp : public OpKernel { 110 | public: 111 | explicit SelectionSortGpuOp(OpKernelConstruction* context) : OpKernel(context) { 112 | OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); 113 | OP_REQUIRES(context, k_ > 0, errors::InvalidArgument("SelectionSort expects positive k")); 114 | } 115 | 116 | void Compute(OpKernelContext* context) override { 117 | const Tensor& dist_tensor = context->input(0); 118 | OP_REQUIRES(context, dist_tensor.dims()==3, errors::InvalidArgument("SelectionSort expects (b,m,n) dist shape.")); 119 | int b = dist_tensor.shape().dim_size(0); 120 | int m = dist_tensor.shape().dim_size(1); 121 | int n = dist_tensor.shape().dim_size(2); 122 | 123 | Tensor *outi_tensor = nullptr; 124 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,n}, &outi_tensor)); 125 | Tensor *out_tensor = nullptr; 126 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m,n}, &out_tensor)); 127 | 128 | auto dist_flat = dist_tensor.flat(); 129 | const float *dist = &(dist_flat(0)); 130 | auto outi_flat = outi_tensor->flat(); 131 | int *outi = &(outi_flat(0)); 132 | auto out_flat = out_tensor->flat(); 133 | float *out = &(out_flat(0)); 134 | selectionSortLauncher(b,n,m,k_,dist,outi,out); 135 | } 136 | private: 137 | int k_; 138 | }; 139 | REGISTER_KERNEL_BUILDER(Name("SelectionSort").Device(DEVICE_GPU), SelectionSortGpuOp); 140 | 141 | 142 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out); 143 | class GroupPointGpuOp: public OpKernel{ 144 | public: 145 | explicit GroupPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 146 | 147 | void Compute(OpKernelContext * context) override { 148 | const Tensor& points_tensor=context->input(0); 149 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPoint expects (batch_size, num_points, channel) points shape")); 150 | int b = points_tensor.shape().dim_size(0); 151 | int n = 
points_tensor.shape().dim_size(1); 152 | int c = points_tensor.shape().dim_size(2); 153 | 154 | const Tensor& idx_tensor=context->input(1); 155 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPoint expects (batch_size, npoints, nsample) idx shape")); 156 | int m = idx_tensor.shape().dim_size(1); 157 | int nsample = idx_tensor.shape().dim_size(2); 158 | 159 | Tensor * out_tensor = nullptr; 160 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,nsample,c}, &out_tensor)); 161 | 162 | auto points_flat = points_tensor.flat(); 163 | const float *points = &(points_flat(0)); 164 | auto idx_flat = idx_tensor.flat(); 165 | const int *idx = &(idx_flat(0)); 166 | auto out_flat = out_tensor->flat(); 167 | float *out = &(out_flat(0)); 168 | groupPointLauncher(b,n,c,m,nsample,points,idx,out); 169 | } 170 | }; 171 | REGISTER_KERNEL_BUILDER(Name("GroupPoint").Device(DEVICE_GPU),GroupPointGpuOp); 172 | 173 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points); 174 | class GroupPointGradGpuOp: public OpKernel{ 175 | public: 176 | explicit GroupPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){} 177 | 178 | void Compute(OpKernelContext * context) override { 179 | const Tensor& points_tensor=context->input(0); 180 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPointGrad expects (batch_size, num_points, channel) points shape")); 181 | int b = points_tensor.shape().dim_size(0); 182 | int n = points_tensor.shape().dim_size(1); 183 | int c = points_tensor.shape().dim_size(2); 184 | 185 | const Tensor& idx_tensor=context->input(1); 186 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample) idx shape")); 187 | int m = idx_tensor.shape().dim_size(1); 188 | int nsample = idx_tensor.shape().dim_size(2); 189 | 190 | const Tensor& grad_out_tensor=context->input(2); 191 | OP_REQUIRES(context,grad_out_tensor.dims()==4 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==m && grad_out_tensor.shape().dim_size(2)==nsample && grad_out_tensor.shape().dim_size(3)==c, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample, channel) grad_out shape")); 192 | 193 | Tensor * grad_points_tensor = nullptr; 194 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,n,c}, &grad_points_tensor)); 195 | 196 | auto points_flat = points_tensor.flat(); 197 | const float *points = &(points_flat(0)); 198 | auto idx_flat = idx_tensor.flat(); 199 | const int *idx = &(idx_flat(0)); 200 | auto grad_out_flat = grad_out_tensor.flat(); 201 | const float *grad_out = &(grad_out_flat(0)); 202 | auto grad_points_flat = grad_points_tensor->flat(); 203 | float *grad_points = &(grad_points_flat(0)); 204 | cudaMemset(grad_points, 0, sizeof(float)*b*n*c); 205 | groupPointGradLauncher(b,n,c,m,nsample,grad_out,idx,grad_points); 206 | } 207 | }; 208 | REGISTER_KERNEL_BUILDER(Name("GroupPointGrad").Device(DEVICE_GPU),GroupPointGradGpuOp); 209 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_g.cu: -------------------------------------------------------------------------------- 1 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 2 | // output: idx (b,m,nsample), pts_cnt (b,m) 3 | __global__ void query_ball_point_gpu(int b, int 
n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 4 | int batch_index = blockIdx.x; 5 | xyz1 += n*3*batch_index; 6 | xyz2 += m*3*batch_index; 7 | idx += m*nsample*batch_index; 8 | pts_cnt += m*batch_index; // counting how many unique points selected in local region 9 | 10 | int index = threadIdx.x; 11 | int stride = blockDim.x; 12 | 13 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 127 | //cudaDeviceSynchronize(); 128 | } 129 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 130 | selection_sort_gpu<<>>(b,n,m,k,dist,outi,out); 131 | //cudaDeviceSynchronize(); 132 | } 133 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){ 134 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 135 | //cudaDeviceSynchronize(); 136 | } 137 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points){ 138 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 139 | //group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 140 | //cudaDeviceSynchronize(); 141 | } 142 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_g.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgriffiths3/pointconv-tensorflow2/8125dd2a9c8e021ee1585662d077fb489866458c/tf_ops/grouping/tf_grouping_g.cu.o -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling.cpp: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 
5 | */ 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | 12 | using namespace tensorflow; 13 | 14 | REGISTER_OP("ProbSample") 15 | .Input("inp: float32") 16 | .Input("inpr: float32") 17 | .Output("out: int32") 18 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 19 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ncategory 20 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &dims1)); 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 22 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &dims2)); 23 | // batch_size * npoints 24 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 25 | c->set_output(0, output); 26 | return Status::OK(); 27 | }); 28 | REGISTER_OP("FarthestPointSample") 29 | .Attr("npoint: int") 30 | .Input("inp: float32") 31 | .Output("out: int32") 32 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 33 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * npoint * 3 34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 35 | int npoint; 36 | TF_RETURN_IF_ERROR(c->GetAttr("npoint", &npoint)); 37 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), npoint}); 38 | c->set_output(0, output); 39 | return Status::OK(); 40 | }); 41 | REGISTER_OP("GatherPoint") 42 | .Input("inp: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * 3 47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 49 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &dims2)); 50 | // batch_size * npoints * 3 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GatherPointGrad") 56 | .Input("inp: float32") 57 | .Input("idx: int32") 58 | .Input("out_g: float32") 59 | .Output("inp_g: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out); 66 | class ProbSampleGpuOp: public OpKernel{ 67 | public: 68 | explicit ProbSampleGpuOp(OpKernelConstruction* context):OpKernel(context){} 69 | void Compute(OpKernelContext * context)override{ 70 | const Tensor& inp_tensor=context->input(0); 71 | const Tensor& inpr_tensor=context->input(1); 72 | auto inp_flat=inp_tensor.flat(); 73 | auto inpr_flat=inpr_tensor.flat(); 74 | const float * inp=&(inp_flat(0)); 75 | const float * inpr=&(inpr_flat(0)); 76 | OP_REQUIRES(context,inp_tensor.dims()==2,errors::InvalidArgument("ProbSample expects (batch_size,num_choices) inp shape")); 77 | int b=inp_tensor.shape().dim_size(0); 78 | int n=inp_tensor.shape().dim_size(1); 79 | OP_REQUIRES(context,inpr_tensor.dims()==2 && inpr_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ProbSample expects (batch_size,num_points) inpr shape")); 80 | int m=inpr_tensor.shape().dim_size(1); 81 | Tensor * 
out_tensor=NULL; 82 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 83 | auto out_flat=out_tensor->flat(); 84 | int * out=&(out_flat(0)); 85 | Tensor temp_tensor; 86 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{b,n},&temp_tensor)); 87 | auto temp_flat=temp_tensor.flat(); 88 | float * temp=&(temp_flat(0)); 89 | probsampleLauncher(b,n,m,inp,inpr,temp,out); 90 | } 91 | }; 92 | REGISTER_KERNEL_BUILDER(Name("ProbSample").Device(DEVICE_GPU), ProbSampleGpuOp); 93 | 94 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out); 95 | class FarthestPointSampleGpuOp: public OpKernel{ 96 | public: 97 | explicit FarthestPointSampleGpuOp(OpKernelConstruction* context):OpKernel(context) { 98 | OP_REQUIRES_OK(context, context->GetAttr("npoint", &npoint_)); 99 | OP_REQUIRES(context, npoint_ > 0, errors::InvalidArgument("FarthestPointSample expects positive npoint")); 100 | } 101 | void Compute(OpKernelContext * context)override{ 102 | int m = npoint_; 103 | 104 | const Tensor& inp_tensor=context->input(0); 105 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("FarthestPointSample expects (batch_size,num_points,3) inp shape")); 106 | int b=inp_tensor.shape().dim_size(0); 107 | int n=inp_tensor.shape().dim_size(1); 108 | auto inp_flat=inp_tensor.flat(); 109 | const float * inp=&(inp_flat(0)); 110 | Tensor * out_tensor; 111 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 112 | auto out_flat=out_tensor->flat(); 113 | int * out=&(out_flat(0)); 114 | Tensor temp_tensor; 115 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{32,n},&temp_tensor)); 116 | auto temp_flat=temp_tensor.flat(); 117 | float * temp=&(temp_flat(0)); 118 | farthestpointsamplingLauncher(b,n,m,inp,temp,out); 119 | } 120 | private: 121 | int npoint_; 122 | }; 123 | REGISTER_KERNEL_BUILDER(Name("FarthestPointSample").Device(DEVICE_GPU),FarthestPointSampleGpuOp); 124 | 125 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out); 126 | class GatherPointGpuOp: public OpKernel{ 127 | public: 128 | explicit GatherPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 129 | void Compute(OpKernelContext * context)override{ 130 | const Tensor& inp_tensor=context->input(0); 131 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPoint expects (batch_size,num_points,3) inp shape")); 132 | int b=inp_tensor.shape().dim_size(0); 133 | int n=inp_tensor.shape().dim_size(1); 134 | const Tensor& idx_tensor=context->input(1); 135 | OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPoint expects (batch_size,num_result) idx shape")); 136 | int m=idx_tensor.shape().dim_size(1); 137 | auto inp_flat=inp_tensor.flat(); 138 | const float * inp=&(inp_flat(0)); 139 | auto idx_flat=idx_tensor.flat(); 140 | const int * idx=&(idx_flat(0)); 141 | Tensor * out_tensor=NULL; 142 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,3},&out_tensor)); 143 | auto out_flat=out_tensor->flat(); 144 | float * out=&(out_flat(0)); 145 | gatherpointLauncher(b,n,m,inp,idx,out); 146 | } 147 | }; 148 | REGISTER_KERNEL_BUILDER(Name("GatherPoint").Device(DEVICE_GPU),GatherPointGpuOp); 149 | 150 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g); 
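// The gradient of GatherPoint is a scatter-add: each row of out_g is added
// back into the input row it was gathered from, which is why inp_g is zeroed
// with cudaMemset before the launcher runs.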
153 | class GatherPointGradGpuOp: public OpKernel{
154 |   public:
155 |     explicit GatherPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){}
156 |     void Compute(OpKernelContext * context)override{
157 |       const Tensor& inp_tensor=context->input(0);
158 |       OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_points,3) inp shape"));
159 |       int b=inp_tensor.shape().dim_size(0);
160 |       int n=inp_tensor.shape().dim_size(1);
161 |       const Tensor& idx_tensor=context->input(1);
162 |       OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result) idx shape"));
163 |       int m=idx_tensor.shape().dim_size(1);
164 |       auto inp_flat=inp_tensor.flat<float>();
165 |       const float * inp=&(inp_flat(0));
166 |       auto idx_flat=idx_tensor.flat<int>();
167 |       const int * idx=&(idx_flat(0));
168 |       const Tensor& out_g_tensor=context->input(2);
169 |       OP_REQUIRES(context,out_g_tensor.dims()==3 && out_g_tensor.shape().dim_size(0)==b && out_g_tensor.shape().dim_size(1)==m && out_g_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result,3) out_g shape"));
170 |       auto out_g_flat=out_g_tensor.flat<float>();
171 |       const float * out_g=&(out_g_flat(0));
172 |       Tensor * inp_g_tensor=NULL;
173 |       OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&inp_g_tensor));
174 |       auto inp_g_flat=inp_g_tensor->flat<float>();
175 |       float * inp_g=&(inp_g_flat(0));
176 |       cudaMemset(inp_g,0,b*n*3*sizeof(float));
177 |       scatteraddpointLauncher(b,n,m,out_g,idx,inp_g);
178 |     }
179 | };
180 | REGISTER_KERNEL_BUILDER(Name("GatherPointGrad").Device(DEVICE_GPU),GatherPointGradGpuOp);
181 | 
--------------------------------------------------------------------------------
/tf_ops/sampling/tf_sampling_g.cu:
--------------------------------------------------------------------------------
1 | /* Furthest point sampling GPU implementation
2 |  * Original author: Haoqiang Fan
3 |  * Modified by Charles R. Qi
4 |  * All Rights Reserved. 2017.
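5 |  *
6 |  * Kernels below: a block-wise cumulative sum plus binary search (used by ProbSample),
7 |  * iterative farthest point sampling, point gathering, and the scatter-add used for its gradient.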
8 |  */
9 | 
10 | __global__ void cumsumKernel(int b,int n,const float * __restrict__ inp,float * __restrict__ out){
11 |   const int BlockSize=2048;
12 |   const int paddingLevel=5;
13 |   __shared__ float buffer4[BlockSize*4];
14 |   __shared__ float buffer[BlockSize+(BlockSize>>paddingLevel)];
15 |   for (int i=blockIdx.x;i<b;i+=gridDim.x){
16 |     float runningsum=0,runningsum2=0;
17 |     for (int j=0;j<n;j+=BlockSize*4){
18 |       int n24_i=min(n-j,BlockSize*4);
19 |       int n24=(n24_i+3)&~3;
20 |       int n2=n24>>2;
21 |       for (int k=threadIdx.x*4;k<n24;k+=blockDim.x*4){
22 |         if (k+3<n24_i){
23 |           float v1=inp[i*n+j+k];
24 |           float v2=inp[i*n+j+k+1];
25 |           v2+=v1;
26 |           float v3=inp[i*n+j+k+2];
27 |           float v4=inp[i*n+j+k+3];
28 |           v4+=v3;
29 |           v3+=v2;
30 |           v4+=v2;
31 |           buffer4[k]=v1;
32 |           buffer4[k+1]=v2;
33 |           buffer4[k+2]=v3;
34 |           buffer4[k+3]=v4;
35 |           buffer[(k>>2)+(k>>(2+paddingLevel))]=v4;
36 |         }else{
37 |           float v=0;
38 |           for (int k2=k;k2<k+4;k2++){
39 |             if (k2<n24_i)
40 |               v+=inp[i*n+j+k2];
41 |             buffer4[k2]=v;
42 |           }
43 |           buffer[(k>>2)+(k>>(2+paddingLevel))]=v;
44 |         }
45 |       }
46 |       int u=0;
47 |       for (;(2<<u)<=n2;u++){
48 |         __syncthreads();
49 |         for (int k=threadIdx.x;k<int(n2>>(u+1));k+=blockDim.x){
50 |           int i1=(((k<<1)+2)<<u)-1;
51 |           int i2=(((k<<1)+1)<<u)-1;
52 |           i1+=i1>>paddingLevel;
53 |           i2+=i2>>paddingLevel;
54 |           buffer[i1]+=buffer[i2];
55 |         }
56 |       }
57 |       u--;
58 |       for (;u>=0;u--){
59 |         __syncthreads();
60 |         for (int k=threadIdx.x;k<int((n2-(1<<u))>>(u+1));k+=blockDim.x){
61 |           int i1=(((k<<1)+3)<<u)-1;
62 |           int i2=(((k<<1)+2)<<u)-1;
63 |           i1+=i1>>paddingLevel;
64 |           i2+=i2>>paddingLevel;
65 |           buffer[i1]+=buffer[i2];
66 |         }
67 |       }
68 |       __syncthreads();
69 |       for (int k=threadIdx.x*4;k<n24;k+=blockDim.x*4){
70 |         if (k!=0){
71 |           int k2=((k>>2)-1)+(((k>>2)-1)>>paddingLevel);
72 |           buffer4[k]+=buffer[k2];
73 |           buffer4[k+1]+=buffer[k2];
74 |           buffer4[k+2]+=buffer[k2];
75 |           buffer4[k+3]+=buffer[k2];
76 |         }
77 |       }
78 |       __syncthreads();
79 |       for (int k=threadIdx.x;k<n24_i;k+=blockDim.x){
80 |         out[i*n+j+k]=buffer4[k]+runningsum;
81 |       }
82 |       float t=buffer[(n24>>2)+(n24>>(2+paddingLevel))-1]+runningsum2;
83 |       float r2=runningsum+t;
84 |       runningsum2=t-(r2-runningsum);
85 |       runningsum=r2;
86 |       __syncthreads();
87 |     }
88 |   }
89 | }
90 | 
91 | __global__ void binarysearchKernel(int b,int n,int m,const float * __restrict__ dataset,const float * __restrict__ query, int * __restrict__ result){
92 |   int base=1;
93 |   while (base<n)
94 |     base<<=1;
95 |   for (int i=blockIdx.x;i<b;i+=gridDim.x){
96 |     for (int j=blockIdx.y*blockDim.x+threadIdx.x;j<m;j+=blockDim.x*gridDim.y){
97 |       float q=query[i*m+j]*dataset[i*n+n-1];
98 |       int r=n-1;
99 |       for (int k=base;k>=1;k>>=1)
100 |         if (r>=k && dataset[i*n+r-k]>=q)
101 |           r-=k;
102 |       result[i*m+j]=r;
103 |     }
104 |   }
105 | }
106 | __global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){
107 |   if (m<=0)
108 |     return;
109 |   const int BlockSize=512;
110 |   __shared__ float dists[BlockSize];
111 |   __shared__ int dists_i[BlockSize];
112 |   const int BufferSize=3072;
113 |   __shared__ float buf[BufferSize*3];
114 |   for (int i=blockIdx.x;i<b;i+=gridDim.x){
115 |     int old=0;
116 |     if (threadIdx.x==0)
117 |       idxs[i*m+0]=old;
118 |     for (int j=threadIdx.x;j<n;j+=blockDim.x){
119 |       temp[blockIdx.x*n+j]=1e38;
120 |     }
121 |     for (int j=threadIdx.x;j<min(BufferSize,n)*3;j+=blockDim.x){
122 |       buf[j]=dataset[i*n*3+j];
123 |     }
124 |     __syncthreads();
125 |     for (int j=1;j<m;j++){
126 |       int besti=0;
127 |       float best=-1;
128 |       float x1=dataset[i*n*3+old*3+0];
129 |       float y1=dataset[i*n*3+old*3+1];
130 |       float z1=dataset[i*n*3+old*3+2];
131 |       for (int k=threadIdx.x;k<n;k+=blockDim.x){
132 |         float td=temp[blockIdx.x*n+k];
133 |         float x2,y2,z2;
134 |         if (k<BufferSize){
135 |           x2=buf[k*3+0];
136 |           y2=buf[k*3+1];
137 |           z2=buf[k*3+2];
138 |         }else{
139 |           x2=dataset[i*n*3+k*3+0];
140 |           y2=dataset[i*n*3+k*3+1];
141 |           z2=dataset[i*n*3+k*3+2];
142 |         }
143 |         float d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
144 |         float d2=min(d,td);
145 |         if (d2!=td)
146 |           temp[blockIdx.x*n+k]=d2;
147 |         if (d2>best){
148 |           best=d2;
149 |           besti=k;
150 |         }
151 |       }
152 |       dists[threadIdx.x]=best;
153 |       dists_i[threadIdx.x]=besti;
154 |       for (int u=0;(1<<u)<blockDim.x;u++){
155 |         __syncthreads();
156 |         if (threadIdx.x<(blockDim.x>>(u+1))){
157 |           int i1=(threadIdx.x*2)<<u;
158 |           int i2=(threadIdx.x*2+1)<<u;
159 |           if (dists[i1]<dists[i2]){
160 |             dists[i1]=dists[i2];
161 |             dists_i[i1]=dists_i[i2];
162 |           }
163 |         }
164 |       }
165 |       __syncthreads();
166 |       old=dists_i[0];
167 |       if (threadIdx.x==0)
168 |         idxs[i*m+j]=old;
169 |     }
170 |   }
171 | }
172 | __global__ void gatherpointKernel(int b,int n,int m,const float * __restrict__ inp,const int * __restrict__ idx,float * __restrict__ out){
173 |   for (int i=blockIdx.x;i<b;i+=gridDim.x){
174 |     for (int j=blockIdx.y*blockDim.x+threadIdx.x;j<m;j+=blockDim.x*gridDim.y){
175 |       int a=idx[i*m+j];
176 |       out[(i*m+j)*3+0]=inp[(i*n+a)*3+0];
177 |       out[(i*m+j)*3+1]=inp[(i*n+a)*3+1];
178 |       out[(i*m+j)*3+2]=inp[(i*n+a)*3+2];
179 |     }
180 |   }
181 | }
182 | __global__ void scatteraddpointKernel(int b,int n,int m,const float * __restrict__ out_g,const int * __restrict__ idx,float * __restrict__ inp_g){
183 |   for (int i=blockIdx.x;i<b;i+=gridDim.x){
184 |     for (int j=blockIdx.y*blockDim.x+threadIdx.x;j<m;j+=blockDim.x*gridDim.y){
185 |       int a=idx[i*m+j];
186 |       atomicAdd(&inp_g[(i*n+a)*3+0],out_g[(i*m+j)*3+0]);
187 |       atomicAdd(&inp_g[(i*n+a)*3+1],out_g[(i*m+j)*3+1]);
188 |       atomicAdd(&inp_g[(i*n+a)*3+2],out_g[(i*m+j)*3+2]);
189 |     }
190 |   }
191 | }
192 | 
193 | void cumsumLauncher(int b,int n,const float * inp,float * out){
194 |   cumsumKernel<<<32,512>>>(b,n,inp,out);
195 | }
196 | //require b*n working space
197 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out){
198 |   cumsumKernel<<<32,512>>>(b,n,inp_p,temp);
199 |   binarysearchKernel<<<dim3(32,8,1),512>>>(b,n,m,temp,inp_r,out);
200 | }
201 | //require 32*n working space
202 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){
203 |   farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out);
204 | }
205 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out){
206 |   gatherpointKernel<<<dim3(2,8,1),512>>>(b,n,m,inp,idx,out);
207 | }
208 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){
209 |   scatteraddpointKernel<<<dim3(2,8,1),512>>>(b,n,m,out_g,idx,inp_g);
210 | }
211 | 
--------------------------------------------------------------------------------
/tf_ops/sampling/tf_sampling_g.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dgriffiths3/pointconv-tensorflow2/8125dd2a9c8e021ee1585662d077fb489866458c/tf_ops/sampling/tf_sampling_g.cu.o
--------------------------------------------------------------------------------
/train_modelnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
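4 | # Training script for the PointConv ModelNet classifier (model_modelnet.py).
5 | # Edit the `config` dict at the bottom to point at the TFRecords described in the README.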
6 | sys.path.insert(0, './')
7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
8 | 
9 | import tensorflow as tf
10 | from tensorflow import keras
11 | 
12 | from model_modelnet import PointConvModel
13 | 
14 | tf.random.set_seed(1234)
15 | 
16 | 
17 | def load_dataset(in_file, batch_size):
18 | 
19 |     assert os.path.isfile(in_file), '[error] dataset path not found'
20 | 
21 |     n_points = 8192
22 |     shuffle_buffer = 1000
23 | 
24 |     def _extract_fn(data_record):
25 | 
26 |         in_features = {
27 |             'points': tf.io.FixedLenFeature([n_points * 3], tf.float32),
28 |             'label': tf.io.FixedLenFeature([1], tf.int64)
29 |         }
30 | 
31 |         return tf.io.parse_single_example(data_record, in_features)
32 | 
33 |     def _preprocess_fn(sample):
34 | 
35 |         points = sample['points']
36 |         label = sample['label']
37 | 
38 |         points = tf.reshape(points, (n_points, 3))
39 |         points = tf.random.shuffle(points)
40 | 
41 |         return points, label
42 | 
43 |     dataset = tf.data.TFRecordDataset(in_file)
44 |     dataset = dataset.shuffle(shuffle_buffer)
45 |     dataset = dataset.map(_extract_fn)
46 |     dataset = dataset.map(_preprocess_fn)
47 |     dataset = dataset.batch(batch_size, drop_remainder=True)
48 | 
49 |     return dataset
50 | 
51 | 
52 | def train():
53 | 
54 |     model = PointConvModel(config['batch_size'], config['bn'])
55 | 
56 |     train_ds = load_dataset(config['train_ds'], config['batch_size'])
57 |     val_ds = load_dataset(config['val_ds'], config['batch_size'])
58 | 
59 |     callbacks = [
60 |         keras.callbacks.EarlyStopping(
61 |             'val_sparse_categorical_accuracy', min_delta=0.1, patience=3),
62 |         keras.callbacks.TensorBoard(
63 |             './logs/{}'.format(config['log_dir']), update_freq=50),
64 |         keras.callbacks.ModelCheckpoint(
65 |             './logs/{}/model/weights'.format(config['log_dir']), 'val_sparse_categorical_accuracy', save_best_only=True)
66 |     ]
67 | 
68 |     model.build((config['batch_size'], 8192, 3))
69 |     model.summary()
70 | 
71 |     model.compile(
72 |         optimizer=keras.optimizers.Adam(config['lr']),
73 |         loss=keras.losses.SparseCategoricalCrossentropy(),
74 |         metrics=[keras.metrics.SparseCategoricalAccuracy()]
75 |     )
76 | 
77 |     model.fit(
78 |         train_ds,
79 |         validation_data=val_ds,
80 |         validation_steps=10,
81 |         validation_freq=1,
82 |         callbacks=callbacks,
83 |         epochs=100,
84 |         verbose=1
85 |     )
86 | 
87 | 
88 | if __name__ == '__main__':
89 | 
90 |     config = {
91 |         'train_ds': './data/modelnet_train.tfrecord',
92 |         'val_ds': './data/modelnet_val.tfrecord',
93 |         'batch_size': 8,
94 |         'lr': 1e-3,
95 |         'bn': False,
96 |         'log_dir': 'modelnet_1'
97 |     }
98 | 
99 |     train()
--------------------------------------------------------------------------------
/train_scannet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | sys.path.insert(0, './')
5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6 | 
7 | from model_scannet import PointConvModel
8 | from tensorflow import keras
9 | import tensorflow as tf
10 | 
11 | tf.random.set_seed(1234)
12 | 
13 | 
14 | def load_dataset(in_file, batch_size):
15 | 
16 |     assert os.path.isfile(in_file), '[error] dataset path not found'
17 | 
18 |     n_points = 8192
19 |     shuffle_buffer = 1000
20 | 
21 |     def _extract_fn(data_record):
22 | 
23 |         in_features = {
24 |             'points': tf.io.FixedLenFeature([n_points * 3], tf.float32),
25 |             'labels': tf.io.FixedLenFeature([n_points], tf.int64)
26 |         }
27 | 
28 |         return tf.io.parse_single_example(data_record, in_features)
29 | 
30 |     def _preprocess_fn(sample):
31 | 
32 |         points = sample['points']
33 |         labels = sample['labels']
34 | 
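35 |         # Shuffle points and labels with the same random permutation so each
36 |         # point stays aligned with its per-point label.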
37 |         points = tf.reshape(points, (n_points, 3))
38 |         labels = tf.reshape(labels, (n_points, 1))
39 | 
40 |         shuffle_idx = tf.random.shuffle(tf.range(points.shape[0]))
41 |         points = tf.gather(points, shuffle_idx)
42 |         labels = tf.gather(labels, shuffle_idx)
43 | 
44 |         return points, labels
45 | 
46 |     dataset = tf.data.TFRecordDataset(in_file)
47 |     dataset = dataset.shuffle(shuffle_buffer)
48 |     dataset = dataset.map(_extract_fn)
49 |     dataset = dataset.map(_preprocess_fn)
50 |     dataset = dataset.batch(batch_size, drop_remainder=True)
51 | 
52 |     return dataset
53 | 
54 | 
55 | def train():
56 | 
57 |     model = PointConvModel(config['batch_size'], config['bn'])
58 | 
59 |     train_ds = load_dataset(config['train_ds'], config['batch_size'])
60 |     val_ds = load_dataset(config['val_ds'], config['batch_size'])
61 | 
62 |     callbacks = [
63 |         keras.callbacks.EarlyStopping(
64 |             'val_sparse_categorical_accuracy', min_delta=0.1, patience=3),
65 |         keras.callbacks.TensorBoard(
66 |             './logs/{}'.format(config['log_dir']), update_freq=50),
67 |         keras.callbacks.ModelCheckpoint(
68 |             './logs/{}/model/weights'.format(config['log_dir']), 'val_sparse_categorical_accuracy', save_best_only=True)
69 |     ]
70 | 
71 |     model.build((config['batch_size'], 8192, 3))
72 |     model.summary()
73 | 
74 |     model.compile(
75 |         optimizer=keras.optimizers.Adam(config['lr']),
76 |         loss=keras.losses.SparseCategoricalCrossentropy(),
77 |         metrics=[keras.metrics.SparseCategoricalAccuracy()]
78 |     )
79 | 
80 |     model.fit(
81 |         train_ds,
82 |         validation_data=val_ds,
83 |         validation_steps=10,
84 |         validation_freq=1,
85 |         callbacks=callbacks,
86 |         epochs=100,
87 |         verbose=1
88 |     )
89 | 
90 | 
91 | if __name__ == '__main__':
92 | 
93 |     config = {
94 |         'train_ds': './data/scannet_train.tfrecord',
95 |         'val_ds': './data/scannet_val.tfrecord',
96 |         'batch_size': 4,
97 |         'lr': 1e-3,
98 |         'bn': False,
99 |         'log_dir': 'scannet_2'
100 |     }
101 | 
102 |     train()
--------------------------------------------------------------------------------