├── .gitignore ├── README.md ├── models ├── cls_msg_model.py ├── cls_ssg_model.py └── sem_seg_model.py ├── pnet2_layers ├── cpp_modules.py ├── layers.py └── utils.py ├── tf_ops ├── 3d_interpolation │ ├── interpolate.cpp │ └── tf_interpolate.cpp ├── compile_ops.sh ├── grouping │ ├── tf_grouping.cpp │ └── tf_grouping_g.cu └── sampling │ ├── tf_sampling.cpp │ └── tf_sampling_g.cu ├── train_modelnet.py └── train_scannet.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.cu.o 3 | *.pyc 4 | *tfrecord 5 | logs/ 6 | .DS_Store* 7 | __pycache__/ 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pointnet++ tensorflow 2.0 layers 2 | 3 | > Note: For the newer PointConv layers in tensorflow 2.x, visit the repository [here](https://github.com/dgriffiths3/pointconv-tensorflow2). 4 | 5 | The repository contains implementations of the pointnet++ set abstraction and feature propagation layers as `tf.keras.layers` classes. The intention is not to be a full pointnet++ tensorflow 2.0 implementation, but to provide an easy way to build a pointnet++ style network architecture using the tensorflow 2.0 Keras API. For reference, here are the original [paper](https://arxiv.org/pdf/1706.02413.pdf) and [code](https://github.com/charlesq34/pointnet2). Where possible, I have tried to directly copy and paste the original code to avoid discrepancies. 6 | 7 | ## Setup 8 | 9 | Requirements: 10 | 11 | * python >= 3.0 12 | * tensorflow-gpu >= 2.2 13 | * cuda == 10.1 14 | > Note: This repository uses the `train_step` model override, which is new in `tensorflow 2.2.0`; as such, if you wish to use the provided training scripts, it is important that your tensorflow is not an older version. The layers themselves will work with tensorflow 2.0+.
15 | 16 | To compile the tensorflow ops, first ensure the `CUDA_ROOT` path in `tf_ops/compile_ops.sh` points to your CUDA folder, then compile the ops with: 17 | 18 | ``` 19 | chmod u+x tf_ops/compile_ops.sh 20 | tf_ops/compile_ops.sh 21 | ``` 22 | 23 | ## Usage 24 | 25 | The layers should work like standard `tf.keras.layers` layers (most similarly to `Conv2D`). To import them, run `from pnet2_layers.layers import Pointnet_SA, Pointnet_SA_MSG, Pointnet_FP` (or just the ones you need). To use them in your own project, copy the `pnet2_layers` folder into your project structure and import it with either relative or absolute imports. Note that `pnet2_layers/cpp_modules.py` loads the compiled ops from `./tf_ops/...` relative to the working directory, so the compiled `tf_ops` folder must be reachable from there as well. 26 | 27 | For example, a custom model mimicking the `pointnet2_cls_ssg` model from the original repository would look like: 28 | 29 | ``` 30 | import tensorflow as tf 31 | from pnet2_layers.layers import Pointnet_SA 32 | 33 | class CLS_SSG_Model(tf.keras.Model): 34 | 35 | def __init__(self, batch_size, activation=tf.nn.relu): 36 | super(CLS_SSG_Model, self).__init__() 37 | 38 | self.batch_size = batch_size 39 | 40 | self.layer1 = Pointnet_SA(npoint=512, radius=0.2, nsample=32, mlp=[64, 64, 128], group_all=False, activation=activation) 41 | self.layer2 = Pointnet_SA(npoint=128, radius=0.4, nsample=64, mlp=[128, 128, 256], group_all=False, activation=activation) 42 | self.layer3 = Pointnet_SA(npoint=None, radius=None, nsample=None, mlp=[256, 512, 1024], group_all=True, activation=activation) 43 | 44 | # The rest of the model can be implemented using standard tf.keras.layers (Dense and Dropout). 45 | 46 | def call(self, input): 47 | 48 | xyz, points = self.layer1(input, None) 49 | xyz, points = self.layer2(xyz, points) 50 | xyz, points = self.layer3(xyz, points) 51 | 52 | points = tf.reshape(points, (self.batch_size, -1)) 53 | 54 | # run points through dense / dropout layers. 55 | 56 | return points 57 | ``` 58 | 59 | Examples of a few of the models from the original repository can be found in the `models` folder.
60 | 61 | To run the ModelNet or ScanNet example, first download the `tfrecords` containing the training data from [here](https://drive.google.com/open?id=1v5B68RHgDI95KM4EhDrRJxLacJAHcoxz) and place them in a folder called `data`. To start training, run either: 62 | 63 | ``` 64 | python train_modelnet.py 65 | ``` 66 | or: 67 | ``` 68 | python train_scannet.py 69 | ``` 70 | 71 | You can view training logs with: 72 | 73 | ``` 74 | tensorboard --logdir=logs --port=6006 75 | ``` 76 | 77 | and navigate to `localhost:6006` in a web browser. 78 | 79 | By default, this runs the multi-scale grouping modules. To run the standard set abstraction layer without multi-scale grouping, set `msg` to `False` in the `params` dictionary. 80 | 81 | ## Notes 82 | 83 | I have implemented [batch normalization](https://towardsdatascience.com/batch-normalization-in-neural-networks-1ac91516821c) as an option (the `bn` argument) for all pointnet++ layers in the model files. I personally find I get better results on real-world datasets when this option is set to `False`. In the original implementation, batch normalization is enabled (`True`). 84 | 85 | If you use this repository and find any bugs, please submit an issue so I can fix them for anyone else using the code.
-------------------------------------------------------------------------------- /models/cls_msg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, './') 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras import Model 8 | from tensorflow.keras.layers import Dense, Dropout, BatchNormalization 9 | 10 | from pnet2_layers.layers import Pointnet_SA, Pointnet_SA_MSG 11 | 12 | 13 | class CLS_MSG_Model(Model): 14 | 15 | def __init__(self, batch_size, num_classes, bn=False, activation=tf.nn.relu): 16 | super(CLS_MSG_Model, self).__init__() 17 | 18 | self.activation = activation 19 | self.batch_size = batch_size 20 | self.num_classes = num_classes 21 | self.bn = bn 22 | self.keep_prob = 0.4 23 | 24 | self.kernel_initializer = 'glorot_normal' 25 | self.kernel_regularizer = None 26 | 27 | self.init_network() 28 | 29 | 30 | def init_network(self): 31 | 32 | self.layer1 = Pointnet_SA_MSG( 33 | npoint=1024, 34 | radius_list=[0.1,0.2,0.4], 35 | nsample_list=[16,32,128], 36 | mlp=[[32,32,64], [64,64,128], [64,96,128]], 37 | activation=self.activation, 38 | bn = self.bn 39 | ) 40 | 41 | self.layer2 = Pointnet_SA_MSG( 42 | npoint=512, 43 | radius_list=[0.2,0.4,0.8], 44 | nsample_list=[32,64,128], 45 | mlp=[[64,64,128], [128,128,256], [128,128,256]], 46 | activation=self.activation, 47 | bn = self.bn 48 | ) 49 | 50 | self.layer3 = Pointnet_SA( 51 | npoint=None, 52 | radius=None, 53 | nsample=None, 54 | mlp=[256, 512, 1024], 55 | group_all=True, 56 | activation=self.activation, 57 | bn = self.bn 58 | ) 59 | 60 | self.dense1 = Dense(512, activation=self.activation) 61 | self.dropout1 = Dropout(self.keep_prob) 62 | 63 | self.dense2 = Dense(128, activation=self.activation) 64 | self.dropout2 = Dropout(self.keep_prob) 65 | 66 | self.dense3 = Dense(self.num_classes, activation=tf.nn.softmax) 67 | 68 | 69 | def forward_pass(self, input, training): 70 | 71 | xyz, points = self.layer1(input, 
None, training=training) 72 | xyz, points = self.layer2(xyz, points, training=training) 73 | xyz, points = self.layer3(xyz, points, training=training) 74 | 75 | net = tf.reshape(points, (self.batch_size, -1)) 76 | 77 | net = self.dense1(net) 78 | net = self.dropout1(net) 79 | 80 | net = self.dense2(net) 81 | net = self.dropout2(net) 82 | 83 | pred = self.dense3(net) 84 | 85 | return pred 86 | 87 | 88 | def train_step(self, input): 89 | 90 | with tf.GradientTape() as tape: 91 | 92 | pred = self.forward_pass(input[0], True) 93 | loss = self.compiled_loss(input[1], pred) 94 | 95 | gradients = tape.gradient(loss, self.trainable_variables) 96 | self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 97 | 98 | self.compiled_metrics.update_state(input[1], pred) 99 | 100 | return {m.name: m.result() for m in self.metrics} 101 | 102 | 103 | def test_step(self, input): 104 | 105 | pred = self.forward_pass(input[0], False) 106 | loss = self.compiled_loss(input[1], pred) 107 | 108 | self.compiled_metrics.update_state(input[1], pred) 109 | 110 | return {m.name: m.result() for m in self.metrics} 111 | 112 | 113 | def call(self, input, training=False): 114 | 115 | return self.forward_pass(input, training) 116 | -------------------------------------------------------------------------------- /models/cls_ssg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, './') 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras import Model 8 | from tensorflow.keras.layers import Dense, Dropout, BatchNormalization 9 | 10 | from pnet2_layers.layers import Pointnet_SA, Pointnet_SA_MSG 11 | 12 | 13 | class CLS_SSG_Model(Model): 14 | 15 | def __init__(self, batch_size, num_classes, bn=False, activation=tf.nn.relu): 16 | super(CLS_SSG_Model, self).__init__() 17 | 18 | self.activation = activation 19 | self.batch_size = batch_size 20 | self.num_classes = num_classes 21 | self.bn = bn 
22 | self.keep_prob = 0.5 23 | 24 | self.kernel_initializer = 'glorot_normal' 25 | self.kernel_regularizer = None 26 | 27 | self.init_network() 28 | 29 | 30 | def init_network(self): 31 | 32 | self.layer1 = Pointnet_SA( 33 | npoint=512, radius=0.2, 34 | nsample=32, 35 | mlp=[64, 64, 128], 36 | group_all=False, 37 | activation=self.activation, 38 | bn = self.bn 39 | ) 40 | 41 | self.layer2 = Pointnet_SA( 42 | npoint=128, 43 | radius=0.4, 44 | nsample=64, 45 | mlp=[128, 128, 256], 46 | group_all=False, 47 | activation=self.activation, 48 | bn = self.bn 49 | ) 50 | 51 | self.layer3 = Pointnet_SA( 52 | npoint=None, 53 | radius=None, 54 | nsample=None, 55 | mlp=[256, 512, 1024], 56 | group_all=True, 57 | activation=self.activation, 58 | bn = self.bn 59 | ) 60 | 61 | self.dense1 = Dense(512, activation=self.activation) 62 | self.dropout1 = Dropout(self.keep_prob) 63 | 64 | self.dense2 = Dense(128, activation=self.activation) 65 | self.dropout2 = Dropout(self.keep_prob) 66 | 67 | self.dense3 = Dense(self.num_classes, activation=tf.nn.softmax) 68 | 69 | 70 | def forward_pass(self, input, training): 71 | 72 | xyz, points = self.layer1(input, None, training=training) 73 | xyz, points = self.layer2(xyz, points, training=training) 74 | xyz, points = self.layer3(xyz, points, training=training) 75 | 76 | net = tf.reshape(points, (self.batch_size, -1)) 77 | 78 | net = self.dense1(net) 79 | net = self.dropout1(net) 80 | 81 | net = self.dense2(net) 82 | net = self.dropout2(net) 83 | 84 | pred = self.dense3(net) 85 | 86 | return pred 87 | 88 | 89 | def train_step(self, input): 90 | 91 | with tf.GradientTape() as tape: 92 | 93 | pred = self.forward_pass(input[0], True) 94 | loss = self.compiled_loss(input[1], pred) 95 | 96 | gradients = tape.gradient(loss, self.trainable_variables) 97 | self.optimizer.apply_gradients( 98 | zip(gradients, self.trainable_variables)) 99 | 100 | self.compiled_metrics.update_state(input[1], pred) 101 | 102 | return {m.name: m.result() for m in 
self.metrics} 103 | 104 | 105 | def test_step(self, input): 106 | 107 | pred = self.forward_pass(input[0], False) 108 | loss = self.compiled_loss(input[1], pred) 109 | 110 | self.compiled_metrics.update_state(input[1], pred) 111 | 112 | return {m.name: m.result() for m in self.metrics} 113 | 114 | 115 | def call(self, input, training=False): 116 | 117 | return self.forward_pass(input, training) 118 | -------------------------------------------------------------------------------- /models/sem_seg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, './') 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras import Model 8 | from tensorflow.keras.layers import Dense, Dropout 9 | 10 | from pnet2_layers.layers import Pointnet_SA, Pointnet_FP 11 | 12 | 13 | class SEM_SEG_Model(Model): 14 | 15 | def __init__(self, batch_size, num_classes, bn=False, activation=tf.nn.relu): 16 | super(SEM_SEG_Model, self).__init__() 17 | 18 | self.activation = activation 19 | self.batch_size = batch_size 20 | self.keep_prob = 0.5 21 | self.num_classes = num_classes 22 | self.bn = bn 23 | 24 | self.kernel_initializer = 'glorot_normal' 25 | self.kernel_regularizer = None 26 | 27 | self.init_network() 28 | 29 | 30 | def init_network(self): 31 | 32 | self.sa_1 = Pointnet_SA( 33 | npoint=1024, 34 | radius=0.1, 35 | nsample=32, 36 | mlp=[32, 32, 64], 37 | group_all=False, 38 | activation=self.activation, 39 | bn = self.bn 40 | ) 41 | 42 | self.sa_2 = Pointnet_SA( 43 | npoint=256, 44 | radius=0.2, 45 | nsample=32, 46 | mlp=[64, 64, 128], 47 | group_all=False, 48 | activation=self.activation, 49 | bn = self.bn 50 | ) 51 | 52 | self.sa_3 = Pointnet_SA( 53 | npoint=64, 54 | radius=0.4, 55 | nsample=32, 56 | mlp=[128, 128, 256], 57 | group_all=False, 58 | activation=self.activation, 59 | bn = self.bn 60 | ) 61 | 62 | self.sa_4 = Pointnet_SA( 63 | npoint=16, 64 | radius=0.8, 65 | nsample=32, 66 | 
mlp=[256, 256, 512], 67 | group_all=False, 68 | activation=self.activation, 69 | bn = self.bn 70 | ) 71 | 72 | self.fp_1 = Pointnet_FP( 73 | mlp = [256, 256], 74 | activation = self.activation, 75 | bn = self.bn 76 | ) 77 | 78 | self.fp_2 = Pointnet_FP( 79 | mlp = [256, 256], 80 | activation = self.activation, 81 | bn = self.bn 82 | ) 83 | 84 | self.fp_3 = Pointnet_FP( 85 | mlp = [256, 128], 86 | activation = self.activation, 87 | bn = self.bn 88 | ) 89 | 90 | self.fp_4 = Pointnet_FP( 91 | mlp = [128, 128, 128], 92 | activation = self.activation, 93 | bn = self.bn 94 | ) 95 | 96 | 97 | self.dense1 = Dense(128, activation=self.activation) 98 | 99 | self.dropout1 = Dropout(self.keep_prob) 100 | 101 | self.dense2 = Dense(self.num_classes, activation=tf.nn.softmax) 102 | 103 | 104 | def forward_pass(self, input, training): 105 | 106 | l0_xyz = input 107 | l0_points = None 108 | 109 | l1_xyz, l1_points = self.sa_1(l0_xyz, l0_points, training=training) 110 | l2_xyz, l2_points = self.sa_2(l1_xyz, l1_points, training=training) 111 | l3_xyz, l3_points = self.sa_3(l2_xyz, l2_points, training=training) 112 | l4_xyz, l4_points = self.sa_4(l3_xyz, l3_points, training=training) 113 | 114 | l3_points = self.fp_1(l3_xyz, l4_xyz, l3_points, l4_points, training=training) 115 | l2_points = self.fp_2(l2_xyz, l3_xyz, l2_points, l3_points, training=training) 116 | l1_points = self.fp_3(l1_xyz, l2_xyz, l1_points, l2_points, training=training) 117 | l0_points = self.fp_4(l0_xyz, l1_xyz, l0_points, l1_points, training=training) 118 | 119 | net = self.dense1(l0_points) 120 | net = self.dropout1(net) 121 | pred = self.dense2(net) 122 | 123 | return pred 124 | 125 | 126 | def train_step(self, input): 127 | 128 | with tf.GradientTape() as tape: 129 | 130 | pred = self.forward_pass(input[0], True) 131 | loss = self.compiled_loss(input[1], pred) 132 | 133 | gradients = tape.gradient(loss, self.trainable_variables) 134 | self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 135 
| 136 | self.compiled_metrics.update_state(input[1], pred) 137 | 138 | return {m.name: m.result() for m in self.metrics} 139 | 140 | 141 | def test_step(self, input): 142 | 143 | pred = self.forward_pass(input[0], False) 144 | loss = self.compiled_loss(input[1], pred) 145 | 146 | self.compiled_metrics.update_state(input[1], pred) 147 | 148 | return {m.name: m.result() for m in self.metrics} 149 | 150 | 151 | def call(self, input, training=False): 152 | 153 | return self.forward_pass(input, training) 154 | 155 | -------------------------------------------------------------------------------- /pnet2_layers/cpp_modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.python.framework import ops 6 | 7 | from tensorflow.keras.layers import MaxPool1D, Layer 8 | 9 | sampling_module=tf.load_op_library('./tf_ops/sampling/tf_sampling_so.so') 10 | grouping_module=tf.load_op_library('./tf_ops/grouping/tf_grouping_so.so') 11 | interpolate_module=tf.load_op_library('./tf_ops/3d_interpolation/tf_interpolate_so.so') 12 | 13 | def prob_sample(inp,inpr): 14 | return sampling_module.prob_sample(inp,inpr) 15 | 16 | ops.NoGradient('ProbSample') 17 | 18 | 19 | def gather_point(inp,idx): 20 | return sampling_module.gather_point(inp,idx) 21 | 22 | 23 | @tf.RegisterGradient('GatherPoint') 24 | def _gather_point_grad(op,out_g): 25 | inp=op.inputs[0] 26 | idx=op.inputs[1] 27 | return [sampling_module.gather_point_grad(inp,idx,out_g),None] 28 | 29 | 30 | def farthest_point_sample(npoint,inp): 31 | return sampling_module.farthest_point_sample(inp, npoint) 32 | 33 | ops.NoGradient('FarthestPointSample') 34 | 35 | 36 | def query_ball_point(radius, nsample, xyz1, xyz2): 37 | return grouping_module.query_ball_point(xyz1, xyz2, radius, nsample) 38 | 39 | ops.NoGradient('QueryBallPoint') 40 | 41 | 42 | def select_top_k(k, dist): 43 | 44 | return 
grouping_module.selection_sort(dist, k) 45 | 46 | ops.NoGradient('SelectionSort') 47 | 48 | 49 | def group_point(points, idx): 50 | 51 | return grouping_module.group_point(points, idx) 52 | 53 | 54 | @tf.RegisterGradient('GroupPoint') 55 | def _group_point_grad(op, grad_out): 56 | points = op.inputs[0] 57 | idx = op.inputs[1] 58 | return [grouping_module.group_point_grad(points, idx, grad_out), None] 59 | 60 | 61 | def knn_point(k, xyz1, xyz2): 62 | 63 | b = xyz1.get_shape()[0].value 64 | n = xyz1.get_shape()[1].value 65 | c = xyz1.get_shape()[2].value 66 | m = xyz2.get_shape()[1].value 67 | print (b, n, c, m) 68 | print (xyz1, (b,1,n,c)) 69 | xyz1 = tf.tile(tf.reshape(xyz1, (b,1,n,c)), [1,m,1,1]) 70 | xyz2 = tf.tile(tf.reshape(xyz2, (b,m,1,c)), [1,1,n,1]) 71 | dist = tf.reduce_sum((xyz1-xyz2)**2, -1) 72 | print (dist, k) 73 | outi, out = select_top_k(k, dist) 74 | idx = tf.slice(outi, [0,0,0], [-1,-1,k]) 75 | val = tf.slice(out, [0,0,0], [-1,-1,k]) 76 | print (idx, val) 77 | #val, idx = tf.nn.top_k(-dist, k=k) # ONLY SUPPORT CPU 78 | return val, idx 79 | 80 | 81 | def three_nn(xyz1, xyz2): 82 | return interpolate_module.three_nn(xyz1, xyz2) 83 | 84 | ops.NoGradient('ThreeNN') 85 | 86 | 87 | def three_interpolate(points, idx, weight): 88 | return interpolate_module.three_interpolate(points, idx, weight) 89 | 90 | 91 | @tf.RegisterGradient('ThreeInterpolate') 92 | def _three_interpolate_grad(op, grad_out): 93 | points = op.inputs[0] 94 | idx = op.inputs[1] 95 | weight = op.inputs[2] 96 | return [interpolate_module.three_interpolate_grad(points, idx, weight, grad_out), None, None] 97 | -------------------------------------------------------------------------------- /pnet2_layers/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Layer, BatchNormalization 3 | 4 | from . 
import utils 5 | 6 | 7 | class Pointnet_SA(Layer): 8 | 9 | def __init__( 10 | self, npoint, radius, nsample, mlp, group_all=False, knn=False, use_xyz=True, activation=tf.nn.relu, bn=False 11 | ): 12 | 13 | super(Pointnet_SA, self).__init__() 14 | 15 | self.npoint = npoint 16 | self.radius = radius 17 | self.nsample = nsample 18 | self.mlp = mlp 19 | self.group_all = group_all 20 | self.knn = False 21 | self.use_xyz = use_xyz 22 | self.activation = activation 23 | self.bn = bn 24 | 25 | self.mlp_list = [] 26 | 27 | def build(self, input_shape): 28 | 29 | for i, n_filters in enumerate(self.mlp): 30 | self.mlp_list.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 31 | 32 | super(Pointnet_SA, self).build(input_shape) 33 | 34 | def call(self, xyz, points, training=True): 35 | 36 | if points is not None: 37 | if len(points.shape) < 3: 38 | points = tf.expand_dims(points, axis=0) 39 | 40 | if self.group_all: 41 | nsample = xyz.get_shape()[1] 42 | new_xyz, new_points, idx, grouped_xyz = utils.sample_and_group_all(xyz, points, self.use_xyz) 43 | else: 44 | new_xyz, new_points, idx, grouped_xyz = utils.sample_and_group( 45 | self.npoint, 46 | self.radius, 47 | self.nsample, 48 | xyz, 49 | points, 50 | self.knn, 51 | use_xyz=self.use_xyz 52 | ) 53 | 54 | for i, mlp_layer in enumerate(self.mlp_list): 55 | new_points = mlp_layer(new_points, training=training) 56 | 57 | new_points = tf.math.reduce_max(new_points, axis=2, keepdims=True) 58 | 59 | return new_xyz, tf.squeeze(new_points) 60 | 61 | 62 | class Pointnet_SA_MSG(Layer): 63 | 64 | def __init__( 65 | self, npoint, radius_list, nsample_list, mlp, use_xyz=True, activation=tf.nn.relu, bn = False 66 | ): 67 | 68 | super(Pointnet_SA_MSG, self).__init__() 69 | 70 | self.npoint = npoint 71 | self.radius_list = radius_list 72 | self.nsample_list = nsample_list 73 | self.mlp = mlp 74 | self.use_xyz = use_xyz 75 | self.activation = activation 76 | self.bn = bn 77 | 78 | self.mlp_list = [] 79 | 80 | def 
build(self, input_shape): 81 | 82 | for i in range(len(self.radius_list)): 83 | tmp_list = [] 84 | for i, n_filters in enumerate(self.mlp[i]): 85 | tmp_list.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 86 | self.mlp_list.append(tmp_list) 87 | 88 | super(Pointnet_SA_MSG, self).build(input_shape) 89 | 90 | def call(self, xyz, points, training=True): 91 | 92 | if points is not None: 93 | if len(points.shape) < 3: 94 | points = tf.expand_dims(points, axis=0) 95 | 96 | new_xyz = utils.gather_point(xyz, utils.farthest_point_sample(self.npoint, xyz)) 97 | 98 | new_points_list = [] 99 | 100 | for i in range(len(self.radius_list)): 101 | radius = self.radius_list[i] 102 | nsample = self.nsample_list[i] 103 | idx, pts_cnt = utils.query_ball_point(radius, nsample, xyz, new_xyz) 104 | grouped_xyz = utils.group_point(xyz, idx) 105 | grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) 106 | 107 | if points is not None: 108 | grouped_points = utils.group_point(points, idx) 109 | if self.use_xyz: 110 | grouped_points = tf.concat([grouped_points, grouped_xyz], axis=-1) 111 | else: 112 | grouped_points = grouped_xyz 113 | 114 | for i, mlp_layer in enumerate(self.mlp_list[i]): 115 | grouped_points = mlp_layer(grouped_points, training=training) 116 | 117 | new_points = tf.math.reduce_max(grouped_points, axis=2) 118 | new_points_list.append(new_points) 119 | 120 | new_points_concat = tf.concat(new_points_list, axis=-1) 121 | 122 | return new_xyz, new_points_concat 123 | 124 | 125 | class Pointnet_FP(Layer): 126 | 127 | def __init__( 128 | self, mlp, activation=tf.nn.relu, bn=False 129 | ): 130 | 131 | super(Pointnet_FP, self).__init__() 132 | 133 | self.mlp = mlp 134 | self.activation = activation 135 | self.bn = bn 136 | 137 | self.mlp_list = [] 138 | 139 | 140 | def build(self, input_shape): 141 | 142 | for i, n_filters in enumerate(self.mlp): 143 | self.mlp_list.append(utils.Conv2d(n_filters, activation=self.activation, bn=self.bn)) 144 | 
super(Pointnet_FP, self).build(input_shape) 145 | 146 | def call(self, xyz1, xyz2, points1, points2, training=True): 147 | 148 | if points1 is not None: 149 | if len(points1.shape) < 3: 150 | points1 = tf.expand_dims(points1, axis=0) 151 | if points2 is not None: 152 | if len(points2.shape) < 3: 153 | points2 = tf.expand_dims(points2, axis=0) 154 | 155 | dist, idx = utils.three_nn(xyz1, xyz2) 156 | dist = tf.maximum(dist, 1e-10) 157 | norm = tf.reduce_sum((1.0/dist),axis=2, keepdims=True) 158 | norm = tf.tile(norm,[1,1,3]) 159 | weight = (1.0/dist) / norm 160 | interpolated_points = utils.three_interpolate(points2, idx, weight) 161 | 162 | if points1 is not None: 163 | new_points1 = tf.concat(axis=2, values=[interpolated_points, points1]) # B,ndataset1,nchannel1+nchannel2 164 | else: 165 | new_points1 = interpolated_points 166 | new_points1 = tf.expand_dims(new_points1, 2) 167 | 168 | for i, mlp_layer in enumerate(self.mlp_list): 169 | new_points1 = mlp_layer(new_points1, training=training) 170 | 171 | new_points1 = tf.squeeze(new_points1) 172 | if len(new_points1.shape) < 3: 173 | new_points1 = tf.expand_dims(new_points1, axis=0) 174 | 175 | return new_points1 176 | -------------------------------------------------------------------------------- /pnet2_layers/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorflow.keras.layers import MaxPool1D, Layer, BatchNormalization 7 | 8 | from .cpp_modules import ( 9 | farthest_point_sample, 10 | gather_point, 11 | query_ball_point, 12 | group_point, 13 | knn_point, 14 | three_nn, 15 | three_interpolate 16 | ) 17 | 18 | 19 | def sample_and_group(npoint, radius, nsample, xyz, points, knn=False, use_xyz=True): 20 | 21 | new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz)) # (batch_size, npoint, 3) 22 | if knn: 23 | _,idx = knn_point(nsample, xyz, new_xyz) 24 | else: 25 | idx, 
pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz) 26 | grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3) 27 | grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization 28 | if points is not None: 29 | grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel) 30 | if use_xyz: 31 | new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nample, 3+channel) 32 | else: 33 | new_points = grouped_points 34 | else: 35 | new_points = grouped_xyz 36 | 37 | return new_xyz, new_points, idx, grouped_xyz 38 | 39 | 40 | def sample_and_group_all(xyz, points, use_xyz=True): 41 | 42 | batch_size = xyz.get_shape()[0] 43 | nsample = xyz.get_shape()[1] 44 | 45 | new_xyz = tf.constant(np.tile(np.array([0,0,0]).reshape((1,1,3)), (batch_size,1,1)),dtype=tf.float32) # (batch_size, 1, 3) 46 | 47 | idx = tf.constant(np.tile(np.array(range(nsample)).reshape((1,1,nsample)), (batch_size,1,1))) 48 | grouped_xyz = tf.reshape(xyz, (batch_size, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3) 49 | if points is not None: 50 | if use_xyz: 51 | new_points = tf.concat([xyz, points], axis=2) # (batch_size, 16, 259) 52 | else: 53 | new_points = points 54 | new_points = tf.expand_dims(new_points, 1) # (batch_size, 1, 16, 259) 55 | else: 56 | new_points = grouped_xyz 57 | return new_xyz, new_points, idx, grouped_xyz 58 | 59 | 60 | class Conv2d(Layer): 61 | 62 | def __init__(self, filters, strides=[1, 1], activation=tf.nn.relu, padding='VALID', initializer='glorot_normal', bn=False): 63 | super(Conv2d, self).__init__() 64 | 65 | self.filters = filters 66 | self.strides = strides 67 | self.activation = activation 68 | self.padding = padding 69 | self.initializer = initializer 70 | self.bn = bn 71 | 72 | def build(self, input_shape): 73 | 74 | self.w = self.add_weight( 75 | shape=(1, 1, input_shape[-1], self.filters), 76 | initializer=self.initializer, 77 | trainable=True, 78 | 
name='pnet_conv' 79 | ) 80 | 81 | if self.bn: self.bn_layer = BatchNormalization() 82 | 83 | super(Conv2d, self).build(input_shape) 84 | 85 | def call(self, inputs, training=True): 86 | 87 | points = tf.nn.conv2d(inputs, filters=self.w, strides=self.strides, padding=self.padding) 88 | 89 | if self.bn: points = self.bn_layer(points, training=training) 90 | 91 | if self.activation: points = self.activation(points) 92 | 93 | return points 94 | -------------------------------------------------------------------------------- /tf_ops/3d_interpolation/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // Find three nearest neigbors with square distance 19 | // input: xyz1 (b,n,3), xyz2(b,m,3) 20 | // output: dist (b,n,3), idx (b,n,3) 21 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) { 22 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | using namespace tensorflow; 11 | 12 | REGISTER_OP("ThreeNN") 13 | .Input("xyz1: float32") 14 | .Input("xyz2: float32") 15 | .Output("dist: float32") 16 | .Output("idx: int32") 17 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 18 | c->set_output(0, c->input(0)); 19 | c->set_output(1, c->input(0)); 20 | return Status::OK(); 21 | }); 22 | REGISTER_OP("ThreeInterpolate") 23 | 
.Input("points: float32") 24 | .Input("idx: int32") 25 | .Input("weight: float32") 26 | .Output("out: float32") 27 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 28 | ::tensorflow::shape_inference::ShapeHandle dims1; // (b,m,c) 29 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 30 | ::tensorflow::shape_inference::ShapeHandle dims2; // (b,n,3) 31 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 32 | // (b,n,c) 33 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)}); 34 | c->set_output(0, output); 35 | return Status::OK(); 36 | }); 37 | REGISTER_OP("ThreeInterpolateGrad") 38 | .Input("points: float32") 39 | .Input("idx: int32") 40 | .Input("weight: float32") 41 | .Input("grad_out: float32") 42 | .Output("grad_points: float32") 43 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 44 | c->set_output(0, c->input(0)); 45 | return Status::OK(); 46 | }); 47 | 48 | float randomf(){ 49 | return (rand()+0.5)/(RAND_MAX+1.0); 50 | } 51 | static double get_time(){ 52 | timespec tp; 53 | clock_gettime(CLOCK_MONOTONIC,&tp); 54 | return tp.tv_sec+tp.tv_nsec*1e-9; 55 | } 56 | 57 | // Find three nearest neigbors with square distance 58 | // input: xyz1 (b,n,3), xyz2(b,m,3) 59 | // output: dist (b,n,3), idx (b,n,3) 60 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) { 61 | for (int i=0;iinput(0); 163 | OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,n,3) xyz1 shape.")); 164 | int b = xyz1_tensor.shape().dim_size(0); 165 | int n = xyz1_tensor.shape().dim_size(1); 166 | 167 | const Tensor& xyz2_tensor = context->input(1); 168 | OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,m,3) xyz2 shape.")); 169 | int m = xyz2_tensor.shape().dim_size(1); 170 | 
171 | Tensor *dist_tensor = nullptr; 172 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,n,3}, &dist_tensor)); 173 | Tensor *idx_tensor = nullptr; 174 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,n,3}, &idx_tensor)); 175 | 176 | auto xyz1_flat = xyz1_tensor.flat(); 177 | const float *xyz1 = &(xyz1_flat(0)); 178 | auto xyz2_flat = xyz2_tensor.flat(); 179 | const float *xyz2 = &(xyz2_flat(0)); 180 | auto dist_flat = dist_tensor->flat(); 181 | float *dist = &(dist_flat(0)); 182 | auto idx_flat = idx_tensor->flat(); 183 | int *idx = &(idx_flat(0)); 184 | threenn_cpu(b,n,m,xyz1,xyz2,dist,idx); 185 | } 186 | }; 187 | REGISTER_KERNEL_BUILDER(Name("ThreeNN").Device(DEVICE_CPU), ThreeNNOp); 188 | 189 | 190 | 191 | class ThreeInterpolateOp: public OpKernel{ 192 | public: 193 | explicit ThreeInterpolateOp(OpKernelConstruction * context):OpKernel(context){} 194 | 195 | void Compute(OpKernelContext * context) override { 196 | const Tensor& points_tensor=context->input(0); 197 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolate expects (b,m,c) points shape")); 198 | int b = points_tensor.shape().dim_size(0); 199 | int m = points_tensor.shape().dim_size(1); 200 | int c = points_tensor.shape().dim_size(2); 201 | 202 | const Tensor& idx_tensor=context->input(1); 203 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b && idx_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) idx shape")); 204 | int n = idx_tensor.shape().dim_size(1); 205 | const Tensor& weight_tensor=context->input(2); 206 | OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) weight shape")); 207 | 208 | Tensor * out_tensor = nullptr; 209 | OP_REQUIRES_OK(context, 
context->allocate_output(0,TensorShape{b,n,c}, &out_tensor)); 210 | 211 | auto points_flat = points_tensor.flat(); 212 | const float *points = &(points_flat(0)); 213 | auto idx_flat = idx_tensor.flat(); 214 | const int *idx = &(idx_flat(0)); 215 | auto weight_flat = weight_tensor.flat(); 216 | const float *weight = &(weight_flat(0)); 217 | auto out_flat = out_tensor->flat(); 218 | float *out = &(out_flat(0)); 219 | threeinterpolate_cpu(b,m,c,n,points,idx,weight,out); 220 | } 221 | }; 222 | REGISTER_KERNEL_BUILDER(Name("ThreeInterpolate").Device(DEVICE_CPU),ThreeInterpolateOp); 223 | 224 | 225 | class ThreeInterpolateGradOp: public OpKernel{ 226 | public: 227 | explicit ThreeInterpolateGradOp(OpKernelConstruction * context):OpKernel(context){} 228 | 229 | void Compute(OpKernelContext * context) override { 230 | const Tensor& points_tensor=context->input(0); 231 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,m,c) points shape")); 232 | int b = points_tensor.shape().dim_size(0); 233 | int m = points_tensor.shape().dim_size(1); 234 | int c = points_tensor.shape().dim_size(2); 235 | 236 | const Tensor& idx_tensor=context->input(1); 237 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) idx shape")); 238 | int n = idx_tensor.shape().dim_size(1); 239 | const Tensor& weight_tensor=context->input(2); 240 | OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) weight shape")); 241 | 242 | const Tensor& grad_out_tensor=context->input(3); 243 | OP_REQUIRES(context,grad_out_tensor.dims()==3 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==n && grad_out_tensor.shape().dim_size(2)==c, errors::InvalidArgument("ThreeInterpolateGrad 
expects (b,n,c) grad_out shape")); 244 | 245 | Tensor * grad_points_tensor = nullptr; 246 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,c}, &grad_points_tensor)); 247 | 248 | auto points_flat = points_tensor.flat(); 249 | const float *points = &(points_flat(0)); 250 | auto idx_flat = idx_tensor.flat(); 251 | const int *idx = &(idx_flat(0)); 252 | auto weight_flat = weight_tensor.flat(); 253 | const float *weight = &(weight_flat(0)); 254 | auto grad_out_flat = grad_out_tensor.flat(); 255 | const float *grad_out = &(grad_out_flat(0)); 256 | auto grad_points_flat = grad_points_tensor->flat(); 257 | float *grad_points = &(grad_points_flat(0)); 258 | memset(grad_points, 0, sizeof(float)*b*m*c); 259 | threeinterpolate_grad_cpu(b,n,c,m,grad_out,idx,weight,grad_points); 260 | } 261 | }; 262 | REGISTER_KERNEL_BUILDER(Name("ThreeInterpolateGrad").Device(DEVICE_CPU),ThreeInterpolateGradOp); 263 | -------------------------------------------------------------------------------- /tf_ops/compile_ops.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | 3 | TF_CFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 4 | TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') 5 | CUDA_ROOT=/usr/local/cuda-10.1 6 | 7 | cd tf_ops 8 | 9 | g++ -std=c++11 -shared ./3d_interpolation/tf_interpolate.cpp -o ./3d_interpolation/tf_interpolate_so.so -I $CUDA_ROOT/include -lcudart -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2 10 | 11 | $CUDA_ROOT/bin/nvcc ./grouping/tf_grouping_g.cu -o ./grouping/tf_grouping_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 12 | g++ -std=c++11 -shared ./grouping/tf_grouping.cpp ./grouping/tf_grouping_g.cu.o -o ./grouping/tf_grouping_so.so -I $CUDA_ROOT/include -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2 13 | 14 | $CUDA_ROOT/bin/nvcc ./sampling/tf_sampling_g.cu -o 
./sampling/tf_sampling_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 15 | g++ -std=c++11 -shared ./sampling/tf_sampling.cpp ./sampling/tf_sampling_g.cu.o -o ./sampling/tf_sampling_so.so -I $CUDA_ROOT/include -L $CUDA_ROOT/lib64/ -fPIC ${TF_CFLAGS} ${TF_LFLAGS} -O2 16 | 17 | cd ../ 18 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | using namespace tensorflow; 12 | 13 | REGISTER_OP("QueryBallPoint") 14 | .Attr("radius: float") 15 | .Attr("nsample: int") 16 | .Input("xyz1: float32") 17 | .Input("xyz2: float32") 18 | .Output("idx: int32") 19 | .Output("pts_cnt: int32") 20 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoint * 3 22 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 23 | int nsample; 24 | TF_RETURN_IF_ERROR(c->GetAttr("nsample", &nsample)); 25 | ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), nsample}); 26 | c->set_output(0, output1); 27 | ::tensorflow::shape_inference::ShapeHandle output2 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 28 | c->set_output(1, output2); 29 | return Status::OK(); 30 | }); 31 | REGISTER_OP("SelectionSort") 32 | .Attr("k: int") 33 | .Input("dist: float32") 34 | .Output("outi: int32") 35 | .Output("out: float32") 36 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 37 | c->set_output(0, c->input(0)); 38 | c->set_output(1, c->input(0)); 39 | return 
Status::OK(); 40 | }); 41 | REGISTER_OP("GroupPoint") 42 | .Input("points: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * channels 47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints * nsample 49 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &dims2)); 50 | // batch_size * npoints * nsample * channels 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), c->Dim(dims2, 2), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GroupPointGrad") 56 | .Input("points: float32") 57 | .Input("idx: int32") 58 | .Input("grad_out: float32") 59 | .Output("grad_points: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | 66 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 67 | class QueryBallPointGpuOp : public OpKernel { 68 | public: 69 | explicit QueryBallPointGpuOp(OpKernelConstruction* context) : OpKernel(context) { 70 | OP_REQUIRES_OK(context, context->GetAttr("radius", &radius_)); 71 | OP_REQUIRES(context, radius_ > 0, errors::InvalidArgument("QueryBallPoint expects positive radius")); 72 | 73 | OP_REQUIRES_OK(context, context->GetAttr("nsample", &nsample_)); 74 | OP_REQUIRES(context, nsample_ > 0, errors::InvalidArgument("QueryBallPoint expects positive nsample")); 75 | } 76 | 77 | void Compute(OpKernelContext* context) override { 78 | const Tensor& xyz1_tensor = context->input(0); 79 | OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint 
expects (batch_size, ndataset, 3) xyz1 shape.")); 80 | int b = xyz1_tensor.shape().dim_size(0); 81 | int n = xyz1_tensor.shape().dim_size(1); 82 | 83 | const Tensor& xyz2_tensor = context->input(1); 84 | OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint expects (batch_size, npoint, 3) xyz2 shape.")); 85 | int m = xyz2_tensor.shape().dim_size(1); 86 | 87 | Tensor *idx_tensor = nullptr; 88 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,nsample_}, &idx_tensor)); 89 | Tensor *pts_cnt_tensor = nullptr; 90 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m}, &pts_cnt_tensor)); 91 | 92 | auto xyz1_flat = xyz1_tensor.flat(); 93 | const float *xyz1 = &(xyz1_flat(0)); 94 | auto xyz2_flat = xyz2_tensor.flat(); 95 | const float *xyz2 = &(xyz2_flat(0)); 96 | auto idx_flat = idx_tensor->flat(); 97 | int *idx = &(idx_flat(0)); 98 | auto pts_cnt_flat = pts_cnt_tensor->flat(); 99 | int *pts_cnt = &(pts_cnt_flat(0)); 100 | queryBallPointLauncher(b,n,m,radius_,nsample_,xyz1,xyz2,idx,pts_cnt); 101 | } 102 | private: 103 | float radius_; 104 | int nsample_; 105 | }; 106 | REGISTER_KERNEL_BUILDER(Name("QueryBallPoint").Device(DEVICE_GPU), QueryBallPointGpuOp); 107 | 108 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out); 109 | class SelectionSortGpuOp : public OpKernel { 110 | public: 111 | explicit SelectionSortGpuOp(OpKernelConstruction* context) : OpKernel(context) { 112 | OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); 113 | OP_REQUIRES(context, k_ > 0, errors::InvalidArgument("SelectionSort expects positive k")); 114 | } 115 | 116 | void Compute(OpKernelContext* context) override { 117 | const Tensor& dist_tensor = context->input(0); 118 | OP_REQUIRES(context, dist_tensor.dims()==3, errors::InvalidArgument("SelectionSort expects (b,m,n) dist shape.")); 119 | int b = dist_tensor.shape().dim_size(0); 120 | int m = 
dist_tensor.shape().dim_size(1); 121 | int n = dist_tensor.shape().dim_size(2); 122 | 123 | Tensor *outi_tensor = nullptr; 124 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,n}, &outi_tensor)); 125 | Tensor *out_tensor = nullptr; 126 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m,n}, &out_tensor)); 127 | 128 | auto dist_flat = dist_tensor.flat(); 129 | const float *dist = &(dist_flat(0)); 130 | auto outi_flat = outi_tensor->flat(); 131 | int *outi = &(outi_flat(0)); 132 | auto out_flat = out_tensor->flat(); 133 | float *out = &(out_flat(0)); 134 | selectionSortLauncher(b,n,m,k_,dist,outi,out); 135 | } 136 | private: 137 | int k_; 138 | }; 139 | REGISTER_KERNEL_BUILDER(Name("SelectionSort").Device(DEVICE_GPU), SelectionSortGpuOp); 140 | 141 | 142 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out); 143 | class GroupPointGpuOp: public OpKernel{ 144 | public: 145 | explicit GroupPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 146 | 147 | void Compute(OpKernelContext * context) override { 148 | const Tensor& points_tensor=context->input(0); 149 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPoint expects (batch_size, num_points, channel) points shape")); 150 | int b = points_tensor.shape().dim_size(0); 151 | int n = points_tensor.shape().dim_size(1); 152 | int c = points_tensor.shape().dim_size(2); 153 | 154 | const Tensor& idx_tensor=context->input(1); 155 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPoint expects (batch_size, npoints, nsample) idx shape")); 156 | int m = idx_tensor.shape().dim_size(1); 157 | int nsample = idx_tensor.shape().dim_size(2); 158 | 159 | Tensor * out_tensor = nullptr; 160 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,nsample,c}, &out_tensor)); 161 | 162 | auto points_flat = points_tensor.flat(); 163 
| const float *points = &(points_flat(0)); 164 | auto idx_flat = idx_tensor.flat(); 165 | const int *idx = &(idx_flat(0)); 166 | auto out_flat = out_tensor->flat(); 167 | float *out = &(out_flat(0)); 168 | groupPointLauncher(b,n,c,m,nsample,points,idx,out); 169 | } 170 | }; 171 | REGISTER_KERNEL_BUILDER(Name("GroupPoint").Device(DEVICE_GPU),GroupPointGpuOp); 172 | 173 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points); 174 | class GroupPointGradGpuOp: public OpKernel{ 175 | public: 176 | explicit GroupPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){} 177 | 178 | void Compute(OpKernelContext * context) override { 179 | const Tensor& points_tensor=context->input(0); 180 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPointGrad expects (batch_size, num_points, channel) points shape")); 181 | int b = points_tensor.shape().dim_size(0); 182 | int n = points_tensor.shape().dim_size(1); 183 | int c = points_tensor.shape().dim_size(2); 184 | 185 | const Tensor& idx_tensor=context->input(1); 186 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample) idx shape")); 187 | int m = idx_tensor.shape().dim_size(1); 188 | int nsample = idx_tensor.shape().dim_size(2); 189 | 190 | const Tensor& grad_out_tensor=context->input(2); 191 | OP_REQUIRES(context,grad_out_tensor.dims()==4 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==m && grad_out_tensor.shape().dim_size(2)==nsample && grad_out_tensor.shape().dim_size(3)==c, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample, channel) grad_out shape")); 192 | 193 | Tensor * grad_points_tensor = nullptr; 194 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,n,c}, &grad_points_tensor)); 195 | 196 | auto points_flat = points_tensor.flat(); 
197 | const float *points = &(points_flat(0)); 198 | auto idx_flat = idx_tensor.flat(); 199 | const int *idx = &(idx_flat(0)); 200 | auto grad_out_flat = grad_out_tensor.flat(); 201 | const float *grad_out = &(grad_out_flat(0)); 202 | auto grad_points_flat = grad_points_tensor->flat(); 203 | float *grad_points = &(grad_points_flat(0)); 204 | cudaMemset(grad_points, 0, sizeof(float)*b*n*c); 205 | groupPointGradLauncher(b,n,c,m,nsample,grad_out,idx,grad_points); 206 | } 207 | }; 208 | REGISTER_KERNEL_BUILDER(Name("GroupPointGrad").Device(DEVICE_GPU),GroupPointGradGpuOp); 209 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_g.cu: -------------------------------------------------------------------------------- 1 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 2 | // output: idx (b,m,nsample), pts_cnt (b,m) 3 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 4 | int batch_index = blockIdx.x; 5 | xyz1 += n*3*batch_index; 6 | xyz2 += m*3*batch_index; 7 | idx += m*nsample*batch_index; 8 | pts_cnt += m*batch_index; // counting how many unique points selected in local region 9 | 10 | int index = threadIdx.x; 11 | int stride = blockDim.x; 12 | 13 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 127 | //cudaDeviceSynchronize(); 128 | } 129 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 130 | selection_sort_gpu<<>>(b,n,m,k,dist,outi,out); 131 | //cudaDeviceSynchronize(); 132 | } 133 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){ 134 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 135 | //cudaDeviceSynchronize(); 136 | } 137 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points){ 
138 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 139 | //group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 140 | //cudaDeviceSynchronize(); 141 | } 142 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling.cpp: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 5 | */ 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | 12 | using namespace tensorflow; 13 | 14 | REGISTER_OP("ProbSample") 15 | .Input("inp: float32") 16 | .Input("inpr: float32") 17 | .Output("out: int32") 18 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 19 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ncategory 20 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &dims1)); 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 22 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &dims2)); 23 | // batch_size * npoints 24 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 25 | c->set_output(0, output); 26 | return Status::OK(); 27 | }); 28 | REGISTER_OP("FarthestPointSample") 29 | .Attr("npoint: int") 30 | .Input("inp: float32") 31 | .Output("out: int32") 32 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 33 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * npoint * 3 34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 35 | int npoint; 36 | TF_RETURN_IF_ERROR(c->GetAttr("npoint", &npoint)); 37 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), npoint}); 
38 | c->set_output(0, output); 39 | return Status::OK(); 40 | }); 41 | REGISTER_OP("GatherPoint") 42 | .Input("inp: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * 3 47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &dims1)); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 49 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &dims2)); 50 | // batch_size * npoints * 3 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GatherPointGrad") 56 | .Input("inp: float32") 57 | .Input("idx: int32") 58 | .Input("out_g: float32") 59 | .Output("inp_g: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out); 66 | class ProbSampleGpuOp: public OpKernel{ 67 | public: 68 | explicit ProbSampleGpuOp(OpKernelConstruction* context):OpKernel(context){} 69 | void Compute(OpKernelContext * context)override{ 70 | const Tensor& inp_tensor=context->input(0); 71 | const Tensor& inpr_tensor=context->input(1); 72 | auto inp_flat=inp_tensor.flat(); 73 | auto inpr_flat=inpr_tensor.flat(); 74 | const float * inp=&(inp_flat(0)); 75 | const float * inpr=&(inpr_flat(0)); 76 | OP_REQUIRES(context,inp_tensor.dims()==2,errors::InvalidArgument("ProbSample expects (batch_size,num_choices) inp shape")); 77 | int b=inp_tensor.shape().dim_size(0); 78 | int n=inp_tensor.shape().dim_size(1); 79 | OP_REQUIRES(context,inpr_tensor.dims()==2 && inpr_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ProbSample expects (batch_size,num_points) 
inpr shape")); 80 | int m=inpr_tensor.shape().dim_size(1); 81 | Tensor * out_tensor=NULL; 82 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 83 | auto out_flat=out_tensor->flat(); 84 | int * out=&(out_flat(0)); 85 | Tensor temp_tensor; 86 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{b,n},&temp_tensor)); 87 | auto temp_flat=temp_tensor.flat(); 88 | float * temp=&(temp_flat(0)); 89 | probsampleLauncher(b,n,m,inp,inpr,temp,out); 90 | } 91 | }; 92 | REGISTER_KERNEL_BUILDER(Name("ProbSample").Device(DEVICE_GPU), ProbSampleGpuOp); 93 | 94 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out); 95 | class FarthestPointSampleGpuOp: public OpKernel{ 96 | public: 97 | explicit FarthestPointSampleGpuOp(OpKernelConstruction* context):OpKernel(context) { 98 | OP_REQUIRES_OK(context, context->GetAttr("npoint", &npoint_)); 99 | OP_REQUIRES(context, npoint_ > 0, errors::InvalidArgument("FarthestPointSample expects positive npoint")); 100 | } 101 | void Compute(OpKernelContext * context)override{ 102 | int m = npoint_; 103 | 104 | const Tensor& inp_tensor=context->input(0); 105 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("FarthestPointSample expects (batch_size,num_points,3) inp shape")); 106 | int b=inp_tensor.shape().dim_size(0); 107 | int n=inp_tensor.shape().dim_size(1); 108 | auto inp_flat=inp_tensor.flat(); 109 | const float * inp=&(inp_flat(0)); 110 | Tensor * out_tensor; 111 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 112 | auto out_flat=out_tensor->flat(); 113 | int * out=&(out_flat(0)); 114 | Tensor temp_tensor; 115 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{32,n},&temp_tensor)); 116 | auto temp_flat=temp_tensor.flat(); 117 | float * temp=&(temp_flat(0)); 118 | farthestpointsamplingLauncher(b,n,m,inp,temp,out); 119 | } 120 
| private: 121 | int npoint_; 122 | }; 123 | REGISTER_KERNEL_BUILDER(Name("FarthestPointSample").Device(DEVICE_GPU),FarthestPointSampleGpuOp); 124 | 125 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out); 126 | class GatherPointGpuOp: public OpKernel{ 127 | public: 128 | explicit GatherPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 129 | void Compute(OpKernelContext * context)override{ 130 | const Tensor& inp_tensor=context->input(0); 131 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPoint expects (batch_size,num_points,3) inp shape")); 132 | int b=inp_tensor.shape().dim_size(0); 133 | int n=inp_tensor.shape().dim_size(1); 134 | const Tensor& idx_tensor=context->input(1); 135 | OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPoint expects (batch_size,num_result) idx shape")); 136 | int m=idx_tensor.shape().dim_size(1); 137 | auto inp_flat=inp_tensor.flat(); 138 | const float * inp=&(inp_flat(0)); 139 | auto idx_flat=idx_tensor.flat(); 140 | const int * idx=&(idx_flat(0)); 141 | Tensor * out_tensor=NULL; 142 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,3},&out_tensor)); 143 | auto out_flat=out_tensor->flat(); 144 | float * out=&(out_flat(0)); 145 | gatherpointLauncher(b,n,m,inp,idx,out); 146 | } 147 | }; 148 | REGISTER_KERNEL_BUILDER(Name("GatherPoint").Device(DEVICE_GPU),GatherPointGpuOp); 149 | 150 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g); 151 | class GatherPointGradGpuOp: public OpKernel{ 152 | public: 153 | explicit GatherPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){} 154 | void Compute(OpKernelContext * context)override{ 155 | const Tensor& inp_tensor=context->input(0); 156 | OP_REQUIRES(context,inp_tensor.dims()==3 && 
inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_points,3) inp")); 157 | int b=inp_tensor.shape().dim_size(0); 158 | int n=inp_tensor.shape().dim_size(1); 159 | const Tensor& idx_tensor=context->input(1); 160 | OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result) idx shape")); 161 | int m=idx_tensor.shape().dim_size(1); 162 | auto inp_flat=inp_tensor.flat(); 163 | const float * inp=&(inp_flat(0)); 164 | auto idx_flat=idx_tensor.flat(); 165 | const int * idx=&(idx_flat(0)); 166 | const Tensor& out_g_tensor=context->input(2); 167 | OP_REQUIRES(context,out_g_tensor.dims()==3 && out_g_tensor.shape().dim_size(0)==b && out_g_tensor.shape().dim_size(1)==m && out_g_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result,3) out_g shape")); 168 | auto out_g_flat=out_g_tensor.flat(); 169 | const float * out_g=&(out_g_flat(0)); 170 | Tensor * inp_g_tensor=NULL; 171 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&inp_g_tensor)); 172 | auto inp_g_flat=inp_g_tensor->flat(); 173 | float * inp_g=&(inp_g_flat(0)); 174 | cudaMemset(inp_g,0,b*n*3*4); 175 | scatteraddpointLauncher(b,n,m,out_g,idx,inp_g); 176 | } 177 | }; 178 | REGISTER_KERNEL_BUILDER(Name("GatherPointGrad").Device(DEVICE_GPU),GatherPointGradGpuOp); 179 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling_g.cu: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling GPU implementation 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 
5 | */ 6 | 7 | __global__ void cumsumKernel(int b,int n,const float * __restrict__ inp,float * __restrict__ out){ 8 | const int BlockSize=2048; 9 | const int paddingLevel=5; 10 | __shared__ float buffer4[BlockSize*4]; 11 | __shared__ float buffer[BlockSize+(BlockSize>>paddingLevel)]; 12 | for (int i=blockIdx.x;i>2; 18 | for (int k=threadIdx.x*4;k>2)+(k>>(2+paddingLevel))]=v4; 33 | }else{ 34 | float v=0; 35 | for (int k2=k;k2>2)+(k>>(2+paddingLevel))]=v; 43 | } 44 | } 45 | int u=0; 46 | for (;(2<>(u+1));k+=blockDim.x){ 49 | int i1=(((k<<1)+2)<>paddingLevel; 52 | i2+=i2>>paddingLevel; 53 | buffer[i1]+=buffer[i2]; 54 | } 55 | } 56 | u--; 57 | for (;u>=0;u--){ 58 | __syncthreads(); 59 | for (int k=threadIdx.x;k>(u+1));k+=blockDim.x){ 60 | int i1=(((k<<1)+3)<>paddingLevel; 63 | i2+=i2>>paddingLevel; 64 | buffer[i1]+=buffer[i2]; 65 | } 66 | } 67 | __syncthreads(); 68 | for (int k=threadIdx.x*4;k>2)-1)+(((k>>2)-1)>>paddingLevel); 71 | buffer4[k]+=buffer[k2]; 72 | buffer4[k+1]+=buffer[k2]; 73 | buffer4[k+2]+=buffer[k2]; 74 | buffer4[k+3]+=buffer[k2]; 75 | } 76 | } 77 | __syncthreads(); 78 | for (int k=threadIdx.x;k>paddingLevel)]+runningsum2; 82 | float r2=runningsum+t; 83 | runningsum2=t-(r2-runningsum); 84 | runningsum=r2; 85 | __syncthreads(); 86 | } 87 | } 88 | } 89 | 90 | __global__ void binarysearchKernel(int b,int n,int m,const float * __restrict__ dataset,const float * __restrict__ query, int * __restrict__ result){ 91 | int base=1; 92 | while (base=1;k>>=1) 99 | if (r>=k && dataset[i*n+r-k]>=q) 100 | r-=k; 101 | result[i*m+j]=r; 102 | } 103 | } 104 | } 105 | __global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){ 106 | if (m<=0) 107 | return; 108 | const int BlockSize=512; 109 | __shared__ float dists[BlockSize]; 110 | __shared__ int dists_i[BlockSize]; 111 | const int BufferSize=3072; 112 | __shared__ float buf[BufferSize*3]; 113 | for (int i=blockIdx.x;ibest){ 147 | 
best=d2; 148 | besti=k; 149 | } 150 | } 151 | dists[threadIdx.x]=best; 152 | dists_i[threadIdx.x]=besti; 153 | for (int u=0;(1<>(u+1))){ 156 | int i1=(threadIdx.x*2)<>>(b,n,inp,out); 196 | } 197 | //require b*n working space 198 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out){ 199 | cumsumKernel<<<32,512>>>(b,n,inp_p,temp); 200 | binarysearchKernel<<>>(b,n,m,temp,inp_r,out); 201 | } 202 | //require 32*n working space 203 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){ 204 | farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out); 205 | } 206 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out){ 207 | gatherpointKernel<<>>(b,n,m,inp,idx,out); 208 | } 209 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){ 210 | scatteraddpointKernel<<>>(b,n,m,out_g,idx,inp_g); 211 | } 212 | 213 | -------------------------------------------------------------------------------- /train_modelnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | 5 | sys.path.insert(0, './') 6 | 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | from models.cls_msg_model import CLS_MSG_Model 13 | from models.cls_ssg_model import CLS_SSG_Model 14 | 15 | tf.random.set_seed(1234) 16 | 17 | 18 | def load_dataset(in_file, batch_size): 19 | 20 | assert os.path.isfile(in_file), '[error] dataset path not found' 21 | 22 | n_points = 8192 23 | shuffle_buffer = 1000 24 | 25 | def _extract_fn(data_record): 26 | 27 | in_features = { 28 | 'points': tf.io.FixedLenFeature([n_points * 3], tf.float32), 29 | 'label': tf.io.FixedLenFeature([1], tf.int64) 30 | } 31 | 32 | return tf.io.parse_single_example(data_record, in_features) 33 | 34 | def _preprocess_fn(sample): 35 | 
36 | points = sample['points'] 37 | label = sample['label'] 38 | 39 | points = tf.reshape(points, (n_points, 3)) 40 | points = tf.random.shuffle(points) 41 | 42 | return points, label 43 | 44 | dataset = tf.data.TFRecordDataset(in_file) 45 | dataset = dataset.shuffle(shuffle_buffer) 46 | dataset = dataset.map(_extract_fn) 47 | dataset = dataset.map(_preprocess_fn) 48 | dataset = dataset.batch(batch_size, drop_remainder=True) 49 | 50 | return dataset 51 | 52 | 53 | def train(): 54 | 55 | if config['msg'] == True: 56 | model = CLS_MSG_Model(config['batch_size'], config['num_classes'], config['bn']) 57 | else: 58 | model = CLS_SSG_Model(config['batch_size'], config['num_classes'], config['bn']) 59 | 60 | train_ds = load_dataset(config['train_ds'], config['batch_size']) 61 | val_ds = load_dataset(config['val_ds'], config['batch_size']) 62 | 63 | callbacks = [ 64 | keras.callbacks.EarlyStopping( 65 | 'val_sparse_categorical_accuracy', min_delta=0.01, patience=10), 66 | keras.callbacks.TensorBoard( 67 | './logs/{}'.format(config['log_dir']), update_freq=50), 68 | keras.callbacks.ModelCheckpoint( 69 | './logs/{}/model/weights.ckpt'.format(config['log_dir']), 'val_sparse_categorical_accuracy', save_best_only=True) 70 | ] 71 | 72 | model.build(input_shape=(config['batch_size'], 8192, 3)) 73 | print(model.summary()) 74 | 75 | model.compile( 76 | optimizer=keras.optimizers.Adam(config['lr']), 77 | loss=keras.losses.SparseCategoricalCrossentropy(), 78 | metrics=[keras.metrics.SparseCategoricalAccuracy()] 79 | ) 80 | 81 | model.fit( 82 | train_ds, 83 | validation_data = val_ds, 84 | validation_steps = 20, 85 | validation_freq = 1, 86 | callbacks = callbacks, 87 | epochs = 100, 88 | verbose = 1 89 | ) 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | config = { 95 | 'train_ds' : 'data/modelnet_train.tfrecord', 96 | 'val_ds' : 'data/modelnet_val.tfrecord', 97 | 'log_dir' : 'msg_1', 98 | 'batch_size' : 4, 99 | 'lr' : 0.001, 100 | 'num_classes' : 40, 101 | 'msg' : True, 102 | 
'bn' : False 103 | } 104 | 105 | train() 106 | -------------------------------------------------------------------------------- /train_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | 5 | sys.path.insert(0, './') 6 | 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | from models.sem_seg_model import SEM_SEG_Model 13 | 14 | tf.random.set_seed(42) 15 | 16 | 17 | def load_dataset(in_file, batch_size): 18 | 19 | assert os.path.isfile(in_file), '[error] dataset path not found' 20 | 21 | n_points = 8192 22 | shuffle_buffer = 1000 23 | 24 | def _extract_fn(data_record): 25 | 26 | in_features = { 27 | 'points': tf.io.FixedLenFeature([n_points * 3], tf.float32), 28 | 'labels': tf.io.FixedLenFeature([n_points], tf.int64) 29 | } 30 | 31 | return tf.io.parse_single_example(data_record, in_features) 32 | 33 | def _preprocess_fn(sample): 34 | 35 | points = sample['points'] 36 | labels = sample['labels'] 37 | 38 | points = tf.reshape(points, (n_points, 3)) 39 | labels = tf.reshape(labels, (n_points, 1)) 40 | 41 | shuffle_idx = tf.range(points.shape[0]) 42 | shuffle_idx = tf.random.shuffle(shuffle_idx) 43 | points = tf.gather(points, shuffle_idx) 44 | labels = tf.gather(labels, shuffle_idx) 45 | 46 | return points, labels 47 | 48 | dataset = tf.data.TFRecordDataset(in_file) 49 | dataset = dataset.shuffle(shuffle_buffer) 50 | dataset = dataset.map(_extract_fn) 51 | dataset = dataset.map(_preprocess_fn) 52 | dataset = dataset.batch(batch_size, drop_remainder=True) 53 | 54 | return dataset 55 | 56 | 57 | def train(): 58 | 59 | model = SEM_SEG_Model(config['batch_size'], config['num_classes'], config['bn']) 60 | 61 | train_ds = load_dataset(config['train_ds'], config['batch_size']) 62 | val_ds = load_dataset(config['val_ds'], config['batch_size']) 63 | 64 | callbacks = [ 65 | keras.callbacks.TensorBoard( 66 | 
'./logs/{}'.format(config['log_dir']), update_freq=50), 67 | keras.callbacks.ModelCheckpoint( 68 | './logs/{}/model/weights'.format(config['log_dir']), 'val_sparse_categorical_accuracy', save_best_only=True) 69 | ] 70 | 71 | model.build((config['batch_size'], 8192, 3)) 72 | print(model.summary()) 73 | 74 | model.compile( 75 | optimizer=keras.optimizers.Adam(config['lr']), 76 | loss=keras.losses.SparseCategoricalCrossentropy(), 77 | metrics=[keras.metrics.SparseCategoricalAccuracy()] 78 | ) 79 | 80 | model.fit( 81 | train_ds, 82 | validation_data=val_ds, 83 | validation_steps=10, 84 | validation_freq=1, 85 | callbacks=callbacks, 86 | epochs=100, 87 | verbose=1 88 | ) 89 | 90 | 91 | if __name__ == '__main__': 92 | 93 | config = { 94 | 'train_ds' : 'data/scannet_train.tfrecord', 95 | 'val_ds' : 'data/scannet_val.tfrecord', 96 | 'log_dir' : 'scannet_1', 97 | 'log_freq' : 10, 98 | 'test_freq' : 100, 99 | 'batch_size' : 4, 100 | 'num_classes' : 21, 101 | 'lr' : 0.001, 102 | 'bn' : False, 103 | } 104 | 105 | train() 106 | --------------------------------------------------------------------------------