├── .gitignore
├── LICENSE
├── PointConv.py
├── README.md
├── evaluate_scannet.py
├── imgs
│   └── example.png
├── models
│   ├── pointconv_weight_density_n16.py
│   └── pointconv_weight_density_n16_dp.py
├── scannet
│   ├── README.md
│   ├── eulerangles.py
│   ├── pc_util.py
│   ├── scannet_dataset_rgb.py
│   ├── scannet_dataset_sw_rgb.py
│   ├── scannetv2_seg_dataset_rgb21c_pointid.py
│   ├── scannetv2_test.txt
│   ├── scannetv2_train.txt
│   ├── scannetv2_val.txt
│   ├── util.py
│   └── visualize
│       ├── util.py
│       ├── util_3d.py
│       └── visualize_labels_on_mesh.py
├── tf_ops
│   ├── 3d_interpolation
│   │   ├── interpolate.cpp
│   │   ├── tf_interpolate.cpp
│   │   ├── tf_interpolate.py
│   │   ├── tf_interpolate_compile.sh
│   │   ├── tf_interpolate_op_test.py
│   │   ├── tf_interpolate_so.so
│   │   └── visu_interpolation.py
│   ├── grouping
│   │   ├── .gitignore
│   │   ├── test
│   │   │   ├── compile.sh
│   │   │   ├── query_ball_point.cpp
│   │   │   ├── query_ball_point.cu
│   │   │   ├── query_ball_point_block.cu
│   │   │   ├── query_ball_point_grid.cu
│   │   │   ├── selection_sort.cpp
│   │   │   ├── selection_sort.cu
│   │   │   └── selection_sort_const.cu
│   │   ├── tf_grouping.cpp
│   │   ├── tf_grouping.py
│   │   ├── tf_grouping.pyc
│   │   ├── tf_grouping_compile.sh
│   │   ├── tf_grouping_g.cu
│   │   └── tf_grouping_op_test.py
│   └── sampling
│       ├── .gitignore
│       ├── tf_sampling.cpp
│       ├── tf_sampling.py
│       ├── tf_sampling.pyc
│       ├── tf_sampling_compile.sh
│       └── tf_sampling_g.cu
├── train_scannet_IoU.py
└── utils
    ├── pointconv_util.py
    ├── provider.py
    └── tf_util.py

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------

1 | *.pyc
2 | *.pickle
3 | *.so

-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------

1 | PointConv: Deep Convolutional Networks on 3D Point Clouds.
2 | 
3 | Copyright (c) 2019, Deep Machine Vision group, Oregon State University
4 | 
5 | The MIT License (MIT)
6 | 
7 | Copyright (c) 2019 Wenxuan Wu
8 | 
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 | 
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 | 
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 | -------------------------------------------------------------------------------- /PointConv.py: -------------------------------------------------------------------------------- 1 | """ 2 | PointConv operation 3 | Author: Wenxuan Wu 4 | Date: July 2018 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import math 12 | import numpy as np 13 | import tensorflow as tf 14 | import os 15 | import sys 16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(os.path.join(BASE_DIR, 'utils')) 18 | sys.path.append(os.path.join(BASE_DIR, 'tf_ops/3d_interpolation')) 19 | sys.path.append(os.path.join(BASE_DIR, 'tf_ops/grouping')) 20 | from tf_interpolate import three_nn, three_interpolate 21 | import tf_grouping 22 | import pointconv_util 23 | import tf_util 24 | 25 | def weight_net_hidden(xyz, hidden_units, scope, is_training, bn_decay=None, weight_decay = None, activation_fn=tf.nn.relu): 26 | 27 | with tf.variable_scope(scope) as sc: 28 | net = xyz 29 | for i, num_hidden_units in enumerate(hidden_units): 30 | net = tf_util.conv2d(net, num_hidden_units, [1, 1], 31 | padding = 'VALID', stride=[1, 1], 32 | bn = True, is_training = is_training, activation_fn=activation_fn, 33 | scope = 'wconv%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 34 | 35 | #net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='wconv_dp%d'%(i)) 36 | return net 37 | 38 | def weight_net(xyz, hidden_units, scope, is_training, bn_decay=None, weight_decay = None, activation_fn=tf.nn.relu): 39 | 40 | with tf.variable_scope(scope) as sc: 41 | net = xyz 42 | for i, num_hidden_units in enumerate(hidden_units): 43 | if i != len(hidden_units) -1: 44 | net = tf_util.conv2d(net, num_hidden_units, [1, 1], 45 | padding = 'VALID', stride=[1, 1], 46 | bn = True, is_training = is_training, activation_fn=activation_fn, 47 | scope = 'wconv%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 48 | else: 49 | net = tf_util.conv2d(net, num_hidden_units, [1, 1], 50 | padding = 'VALID', stride=[1, 1], 51 | bn = False, is_training = is_training, activation_fn=None, 52 | scope = 'wconv%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 53 | #net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='wconv_dp%d'%(i)) 54 | return net 55 | 56 | def nonlinear_transform(data_in, mlp, scope, is_training, bn_decay=None, weight_decay = None, activation_fn = tf.nn.relu): 57 | 58 | with tf.variable_scope(scope) as sc: 59 | 60 | net = data_in 61 | l = len(mlp) 62 | if l > 1: 63 | for i, out_ch in enumerate(mlp[0:(l-1)]): 64 | net = tf_util.conv2d(net, out_ch, [1, 1], 65 | padding = 'VALID', stride=[1, 1], 66 | bn = True, is_training = is_training, activation_fn=tf.nn.relu, 67 | scope = 'nonlinear%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 68 | 69 | #net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp_nonlinear%d'%(i)) 70 | net = tf_util.conv2d(net, mlp[-1], [1, 1], 71 | padding = 'VALID', stride=[1, 1], 72 | bn = False, is_training = is_training, 73 | scope = 'nonlinear%d'%(l-1), bn_decay=bn_decay, 74 | activation_fn=tf.nn.sigmoid, weight_decay = weight_decay) 75 | 76 | return net 77 | 78 | def feature_encoding_layer(xyz, feature, npoint, radius, sigma, K, mlp, is_training, bn_decay, weight_decay, scope, bn=True, use_xyz=True): 79 | ''' Input: 80 | xyz: (batch_size, ndataset, 3) TF tensor 81 | feature: (batch_size, ndataset, channel) TF tensor 82 | npoint: 
int32 -- #points sampled in farthest point sampling 83 | sigma: float32 -- KDE bandwidth 84 | K: int32 -- how many points in each local region 85 | mlp: list of int32 -- output size for MLP on each point 86 | use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features 87 | Return: 88 | new_xyz: (batch_size, npoint, 3) TF tensor 89 | new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor 90 | ''' 91 | with tf.variable_scope(scope) as sc: 92 | num_points = xyz.get_shape()[1] 93 | if num_points == npoint: 94 | new_xyz = xyz 95 | else: 96 | new_xyz = pointconv_util.sampling(npoint, xyz) 97 | 98 | grouped_xyz, grouped_feature, idx = pointconv_util.grouping(feature, K, xyz, new_xyz) 99 | 100 | density = pointconv_util.kernel_density_estimation_ball(xyz, radius, sigma) 101 | inverse_density = tf.div(1.0, density) 102 | grouped_density = tf.gather_nd(inverse_density, idx) # (batch_size, npoint, nsample, 1) 103 | #grouped_density = tf_grouping.group_point(inverse_density, idx) 104 | inverse_max_density = tf.reduce_max(grouped_density, axis = 2, keepdims = True) 105 | density_scale = tf.div(grouped_density, inverse_max_density) 106 | 107 | #density_scale = tf_grouping.group_point(density, idx) 108 | 109 | for i, num_out_channel in enumerate(mlp): 110 | if i != len(mlp) - 1: 111 | grouped_feature = tf_util.conv2d(grouped_feature, num_out_channel, [1,1], 112 | padding='VALID', stride=[1,1], 113 | bn=bn, is_training=is_training, 114 | scope='conv%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 115 | 116 | weight = weight_net_hidden(grouped_xyz, [32], scope = 'weight_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay) 117 | 118 | density_scale = nonlinear_transform(density_scale, [16, 1], scope = 'density_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay) 119 | 120 | new_points = tf.multiply(grouped_feature, density_scale) 121 | 122 | new_points = tf.transpose(new_points, [0, 1, 3, 2]) 123 | 124 | new_points = tf.matmul(new_points, weight) 125 | 126 | new_points = tf_util.conv2d(new_points, mlp[-1], [1,new_points.get_shape()[2].value], 127 | padding='VALID', stride=[1,1], 128 | bn=bn, is_training=is_training, 129 | scope='after_conv', bn_decay=bn_decay, weight_decay = weight_decay) 130 | 131 | new_points = tf.squeeze(new_points, [2]) # (batch_size, npoints, mlp2[-1]) 132 | 133 | return new_xyz, new_points 134 | 135 | def feature_decoding_layer(xyz1, xyz2, points1, points2, radius, sigma, K, mlp, is_training, bn_decay, weight_decay, scope, bn=True, use_xyz = True): 136 | ''' Input: 137 | xyz1: (batch_size, ndataset1, 3) TF tensor 138 | xyz2: (batch_size, ndataset2, 3) TF tensor, sparser than xyz1 139 | points1: (batch_size, ndataset1, nchannel1) TF tensor 140 | points2: (batch_size, ndataset2, nchannel2) TF tensor 141 | sigma: float32 -- KDE bandwidth 142 | K: int32 -- how many points in each local region 143 | mlp: list of int32 -- output size for MLP on each point 144 | Return: 145 | new_points: (batch_size, ndataset1, mlp[-1]) TF tensor 146 | ''' 147 | with tf.variable_scope(scope) as sc: 148 | dist, idx = three_nn(xyz1, xyz2) 149 | dist = tf.maximum(dist, 1e-10) 150 | norm = tf.reduce_sum((1.0/dist),axis=2,keepdims=True) 151 | norm = tf.tile(norm,[1,1,3]) 152 | weight = (1.0/dist) / norm 153 | interpolated_points = three_interpolate(points2, idx, weight) 154 | 155 | #setup for deConv 156 | grouped_xyz, grouped_feature, idx = pointconv_util.grouping(interpolated_points, K, xyz1, xyz1, 
use_xyz=use_xyz) 157 | 158 | density = pointconv_util.kernel_density_estimation_ball(xyz1, radius, sigma) 159 | inverse_density = tf.div(1.0, density) 160 | grouped_density = tf.gather_nd(inverse_density, idx) # (batch_size, npoint, nsample, 1) 161 | #grouped_density = tf_grouping.group_point(inverse_density, idx) 162 | inverse_max_density = tf.reduce_max(grouped_density, axis = 2, keepdims = True) 163 | density_scale = tf.div(grouped_density, inverse_max_density) 164 | 165 | #density_scale = tf_grouping.group_point(density, idx) 166 | 167 | weight = weight_net_hidden(grouped_xyz, [32], scope = 'decode_weight_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay) 168 | 169 | density_scale = nonlinear_transform(density_scale, [16, 1], scope = 'decode_density_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay) 170 | 171 | new_points = tf.multiply(grouped_feature, density_scale) 172 | 173 | new_points = tf.transpose(new_points, [0, 1, 3, 2]) 174 | 175 | new_points = tf.matmul(new_points, weight) 176 | 177 | new_points = tf_util.conv2d(new_points, mlp[0], [1,new_points.get_shape()[2].value], 178 | padding='VALID', stride=[1,1], 179 | bn=bn, is_training=is_training, 180 | scope='decode_after_conv', bn_decay=bn_decay, weight_decay = weight_decay) 181 | 182 | if points1 is not None: 183 | new_points1 = tf.concat(axis=-1, values=[new_points, tf.expand_dims(points1, axis = 2)]) # B,ndataset1,nchannel1+nchannel2 184 | else: 185 | new_points1 = new_points 186 | 187 | for i, num_out_channel in enumerate(mlp): 188 | if i != 0: 189 | new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1,1], 190 | padding='VALID', stride=[1,1], 191 | bn=bn, is_training=is_training, 192 | scope='conv_%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay) 193 | new_points1 = tf.squeeze(new_points1, [2]) # B,ndataset1,mlp[-1] 194 | return new_points1 195 | 196 | def feature_decoding_layer_depthwise(xyz1, xyz2, points1, points2, radius, sigma, K, mlp, is_training, bn_decay, weight_decay, scope, bn=True, use_xyz = True): 197 | ''' Input: 198 | depthwise version of pointconv 199 | xyz1: (batch_size, ndataset1, 3) TF tensor 200 | xyz2: (batch_size, ndataset2, 3) TF tensor, sparser than xyz1 201 | points1: (batch_size, ndataset1, nchannel1) TF tensor 202 | points2: (batch_size, ndataset2, nchannel2) TF tensor 203 | sigma: float32 -- KDE bandwidth 204 | K: int32 -- how many points in each local region 205 | mlp: list of int32 -- output size for MLP on each point 206 | Return: 207 | new_points: (batch_size, ndataset1, mlp[-1]) TF tensor 208 | ''' 209 | with tf.variable_scope(scope) as sc: 210 | dist, idx = three_nn(xyz1, xyz2) 211 | dist = tf.maximum(dist, 1e-10) 212 | norm = tf.reduce_sum((1.0/dist),axis=2,keepdims=True) 213 | norm = tf.tile(norm,[1,1,3]) 214 | weight = (1.0/dist) / norm 215 | interpolated_points = three_interpolate(points2, idx, weight) 216 | 217 | #setup for deConv 218 | grouped_xyz, grouped_feature, idx = pointconv_util.grouping(interpolated_points, K, xyz1, xyz1, use_xyz=use_xyz) 219 | 220 | density = pointconv_util.kernel_density_estimation_ball(xyz1, radius, sigma) 221 | inverse_density = tf.div(1.0, density) 222 | grouped_density = tf.gather_nd(inverse_density, idx) # (batch_size, npoint, nsample, 1) 223 | #grouped_density = tf_grouping.group_point(inverse_density, idx) 224 | inverse_max_density = tf.reduce_max(grouped_density, axis = 2, keepdims = True) 225 | density_scale = tf.div(grouped_density, inverse_max_density) 226 | 227 | 
#density_scale = tf_grouping.group_point(density, idx)
228 | 
229 |         weight = weight_net(grouped_xyz, [32, grouped_feature.get_shape()[3].value], scope = 'decode_weight_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay)
230 | 
231 |         density_scale = nonlinear_transform(density_scale, [16, 1], scope = 'decode_density_net', is_training=is_training, bn_decay = bn_decay, weight_decay = weight_decay)
232 | 
233 |         new_points = tf.multiply(grouped_feature, density_scale)
234 | 
235 |         new_points = tf.multiply(new_points, weight) # apply the learned weights on top of the density-scaled features
236 | 
237 |         new_points = tf_util.reduce_sum2d_conv(new_points, axis = 2, scope = 'fp_sumpool', bn=True,
238 |                                           bn_decay = bn_decay, is_training = is_training, keepdims = False)
239 | 
240 |         if points1 is not None:
241 |             new_points1 = tf.concat(axis=-1, values=[new_points, points1]) # B,ndataset1,nchannel1+nchannel2
242 |         else:
243 |             new_points1 = new_points
244 |         new_points1 = tf.expand_dims(new_points1, 2)
245 |         for i, num_out_channel in enumerate(mlp):
246 |             new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1,1],
247 |                                          padding='VALID', stride=[1,1],
248 |                                          bn=bn, is_training=is_training,
249 |                                          scope='conv_%d'%(i), bn_decay=bn_decay, weight_decay = weight_decay)
250 |         new_points1 = tf.squeeze(new_points1, [2]) # B,ndataset1,mlp[-1]
251 |         return new_points1
252 | 
253 | def placeholder_inputs(batch_size, num_point, channel):
254 |     pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
255 |     feature_pts_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, channel))
256 |     labels_pl = tf.placeholder(tf.int32, shape=(batch_size, num_point))
257 |     return pointclouds_pl, feature_pts_pl, labels_pl
258 | 
259 | if __name__=='__main__':
260 |     import numpy as np
261 |     pts = np.random.random((32, 2048, 3)).astype('float32')
262 |     fpts = pts
263 |     sigma = 0.1
264 |     radius = 0.1 # ball-query radius for the kernel density estimation in the layers below
265 |     N = 512
266 |     K = 64
267 |     D = 1
268 |     C_list = [64, 128]
269 |     mlp_w = [64]
270 |     mlp_d = [64]
271 |     is_training = tf.placeholder(tf.bool, shape=())
272 | 
273 |     import pdb
274 |     pdb.set_trace()
275 | 
276 |     with tf.device('/gpu:1'):
277 |         #points = tf.constant(pts)
278 |         #features = tf.constant(fpts)
279 |         points_pl, features_pl, labels_pl = placeholder_inputs(32, 2048, 3)
280 |         sub_pts, features = feature_encoding_layer(points_pl, features_pl, N, radius, sigma, K, [10, 20], is_training, bn_decay = 0.1, weight_decay = 0.1, scope = "FE")
281 |         feature_decode = feature_decoding_layer(points_pl, sub_pts, features_pl, features, radius, sigma, K, [10, 23], is_training, bn_decay=0.1, weight_decay = 0.1, scope= "FD")
282 | 
283 | 
284 | 
285 | 
286 | 

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

1 | # PointConv
2 | **PointConv: Deep Convolutional Networks on 3D Point Clouds.** CVPR 2019
3 | Wenxuan Wu, Zhongang Qi, Li Fuxin.
4 | 
5 | 
![example](./imgs/example.png)
6 | 
7 | ## Introduction
8 | This project is based on our CVPR 2019 paper. You can find the [arXiv](https://arxiv.org/abs/1811.07246) version here.
9 | ```
10 | @article{wu2018pointconv,
11 |   title={PointConv: Deep Convolutional Networks on 3D Point Clouds},
12 |   author={Wu, Wenxuan and Qi, Zhongang and Fuxin, Li},
13 |   journal={arXiv preprint arXiv:1811.07246},
14 |   year={2018}
15 | }
16 | ```
17 | Unlike images, which are represented on regular dense grids, 3D point clouds are irregular and unordered, so applying convolution to them is difficult. In this paper, we extend the dynamic filter to a new convolution operation, named PointConv. PointConv can be applied on point clouds to build deep convolutional networks. We treat convolution kernels as nonlinear functions of the local coordinates of 3D points, composed of weight and density functions. With respect to a given point, the weight functions are learned with multi-layer perceptron networks and the density functions through kernel density estimation. A novel reformulation is proposed for efficiently computing the weight functions, which allows us to dramatically scale up the network and significantly improve its performance. The learned convolution kernel can be used to compute translation-invariant and permutation-invariant convolution on any point set in 3D space. In addition, PointConv can be used as a deconvolution operator to propagate features from a subsampled point cloud back to its original resolution. Experiments on ModelNet40, ShapeNet, and ScanNet show that deep convolutional neural networks built on PointConv achieve state-of-the-art results on challenging semantic segmentation benchmarks for 3D point clouds. Our experiments converting CIFAR-10 into a point cloud also show that networks built on PointConv can match the performance of convolutional networks on 2D images of similar structure.
18 | 
19 | ## Installation
20 | The code is based on [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2). Please install [TensorFlow](https://www.tensorflow.org/install/) and follow the instructions in [PointNet++](https://github.com/charlesq34/pointnet2) to compile the customized TF operators.
21 | The code has been tested with Python 2.7, TensorFlow 1.11.0, CUDA 9.0 and cuDNN 7.3 on Ubuntu 16.04.
22 | 
23 | ## Usage
24 | ### ModelNet40 Classification
25 | Please check [pointconv_pytorch](https://github.com/DylanWusee/pointconv_pytorch) for details on the ModelNet40 classification task using PyTorch.
26 | 
27 | ### ScanNet Dataset Segmentation
28 | 
29 | Download the ScanNet v2 dataset from [here](http://www.scan-net.org/), and see `scannet/README.md` for preprocessing details.
30 | 
31 | To train a model to segment ScanNet scenes:
32 | 
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 python train_scannet_IoU.py --model pointconv_weight_density_n16 --log_dir pointconv_scannet_ --batch_size 8
35 | ```
36 | 
37 | After training, to evaluate the segmentation IoU:
38 | 
39 | ```
40 | CUDA_VISIBLE_DEVICES=0 python evaluate_scannet.py --model pointconv_weight_density_n16 --batch_size 8 --model_path pointconv_scannet_%s --ply_path DataSet/ScanNetv2/scans
41 | ```
42 | 
43 | Set `--model_path` to your `.ckpt` file path and `--ply_path` to the directory holding the original ScanNet v2 `.ply` files.
44 | 
45 | ## License
46 | This repository is released under the MIT License (see the LICENSE file for details).
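For intuition, here is a minimal NumPy sketch of the PointConv operation described above. It is illustrative only: the real layers in `PointConv.py` run batched over all sampled points and additionally pass the density scale through a small sigmoid MLP, and the names `pointconv_neighborhood` and `weight_mlp` are introduced here for the sketch, not repo APIs.

```python
import numpy as np

def pointconv_neighborhood(rel_xyz, feats, weight_mlp, density):
    """One PointConv neighborhood: K neighbors around a sampled point.

    rel_xyz: (K, 3) neighbor coordinates relative to the center point
    feats:   (K, C_in) neighbor features
    weight_mlp: callable (K, 3) -> (K, C_mid), the learned weight function
    density: (K,) kernel density estimates at the neighbors
    """
    inv_density = 1.0 / density
    scale = inv_density / inv_density.max()  # normalized inverse density in (0, 1]
    scaled_feats = feats * scale[:, None]    # compensate for non-uniform sampling
    weights = weight_mlp(rel_xyz)            # weights depend only on local coordinates
    return scaled_feats.T @ weights          # (C_in, C_mid); a 1x1 conv then maps to C_out
```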
47 | 
-------------------------------------------------------------------------------- /evaluate_scannet.py: --------------------------------------------------------------------------------

1 | """
2 | Evaluation on ScanNet: generate the necessary .ply and .txt files
3 | Author: Wenxuan Wu
4 | Date: July 2018
5 | """
6 | 
7 | import argparse
8 | import math
9 | from datetime import datetime
10 | import h5py
11 | from plyfile import PlyData, PlyElement
12 | import numpy as np
13 | import tensorflow as tf
14 | import socket
15 | import importlib
16 | import os
17 | import sys
18 | 
19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
20 | sys.path.append(BASE_DIR)
21 | sys.path.append(os.path.join(BASE_DIR, 'models'))
22 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
23 | sys.path.append(os.path.join(BASE_DIR, 'scannet'))
24 | #sys.path.append(os.path.join(BASE_DIR, 'scannet/preprocessing'))
25 | sys.path.append(os.path.join(BASE_DIR, 'scannet/visualize'))
26 | import provider
27 | import tf_util
28 | import scannet_dataset_sw_rgb
29 | import pc_util
30 | from visualize_labels_on_mesh import visualize
31 | 
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
34 | parser.add_argument('--model', default='model', help='Model name [default: model]')
35 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size during evaluation [default: 8]')
36 | parser.add_argument('--num_point', type=int, default=8192, help='Point number per block [default: 8192]')
37 | parser.add_argument('--model_path', default='log/model.ckpt', help='Model checkpoint file path [default: log/model.ckpt]')
38 | parser.add_argument('--ply_path', default='scannet', help='Path to the original ScanNet .ply files')
39 | parser.add_argument('--dump_dir', default='dump', help='Dump folder path [default: dump]')
40 | parser.add_argument('--num_votes', type=int, default=5, help='Aggregate classification scores from multiple sampling rounds [default: 5]')
41 | parser.add_argument('--with_rgb', help='With RGB or not', action='store_true')
42 | FLAGS = parser.parse_args()
43 | 
44 | BATCH_SIZE = FLAGS.batch_size
45 | NUM_POINT = FLAGS.num_point
46 | MODEL_PATH = FLAGS.model_path
47 | GPU_INDEX = FLAGS.gpu
48 | WITH_RGB = FLAGS.with_rgb
49 | PLY_PATH = FLAGS.ply_path
50 | MODEL = importlib.import_module(FLAGS.model) # import network module
51 | DUMP_DIR = FLAGS.dump_dir + datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
52 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR)
53 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w')
54 | LOG_FOUT.write(str(FLAGS)+'\n')
55 | 
56 | BANDWIDTH = 0.05
57 | 
58 | NUM_CLASSES = 21
59 | HOSTNAME = socket.gethostname()
60 | 
61 | DATA_PATH = os.path.join(BASE_DIR, 'scannet')
62 | print("start loading whole scene data ...")
63 | TEST_DATASET_WHOLE_SCENE = scannet_dataset_sw_rgb.ScannetDatasetWholeScene_evaluation(root=DATA_PATH, split='val', with_rgb = WITH_RGB)
64 | 
65 | def log_string(out_str):
66 |     LOG_FOUT.write(out_str+'\n')
67 |     LOG_FOUT.flush()
68 |     print(out_str)
69 | 
70 | def evaluate(num_votes):
71 |     with tf.device('/gpu:'+str(GPU_INDEX)):
72 |         if WITH_RGB:
73 |             pointclouds_pl = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_POINT, 6))
74 |         else:
75 |             pointclouds_pl = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_POINT, 3))
76 |         labels_pl = tf.placeholder(tf.int32, shape=(BATCH_SIZE, NUM_POINT))
77 |         smpws_pl = tf.placeholder(tf.float32, shape=(BATCH_SIZE, 
NUM_POINT)) 78 | is_training_pl = tf.placeholder(tf.bool, shape=()) 79 | 80 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, NUM_CLASSES, BANDWIDTH) 81 | MODEL.get_loss(pred, labels_pl, smpws_pl) 82 | losses = tf.get_collection('losses') 83 | total_loss = tf.add_n(losses, name='total_loss') 84 | saver = tf.train.Saver() 85 | 86 | # Create a session 87 | config = tf.ConfigProto() 88 | config.gpu_options.allow_growth = True 89 | config.allow_soft_placement = True 90 | config.log_device_placement = False 91 | sess = tf.Session(config=config) 92 | 93 | # Restore variables from disk. 94 | saver.restore(sess, MODEL_PATH) 95 | log_string("Model restored.") 96 | 97 | ops = {'pointclouds_pl': pointclouds_pl, 98 | 'labels_pl': labels_pl, 99 | 'is_training_pl': is_training_pl, 100 | 'pred': pred} 101 | 102 | eval_one_epoch(sess, ops, num_votes) 103 | 104 | def add_vote(vote_label_pool, point_idx, pred_label): 105 | B = pred_label.shape[0] 106 | N = pred_label.shape[1] 107 | for b in range(B): 108 | for n in range(N): 109 | vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1 110 | return vote_label_pool 111 | 112 | test_class = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]) 113 | 114 | def eval_one_epoch(sess, ops, num_votes=1, topk=1): 115 | is_training = False 116 | file_list = "./scannet/scannetv2_val.txt" 117 | with open(file_list) as fl: 118 | scene_id = fl.read().splitlines() 119 | 120 | num_batches = len(TEST_DATASET_WHOLE_SCENE) 121 | 122 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 123 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 124 | total_iou_deno_class = [0 for _ in range(NUM_CLASSES)] 125 | 126 | log_string(str(datetime.now())) 127 | log_string('---- EVALUATION WHOLE SCENE----') 128 | 129 | for batch_idx in range(num_batches): 130 | print("visualize %d %s ..."%(batch_idx, scene_id[batch_idx])) 131 | whole_scene_points_index = TEST_DATASET_WHOLE_SCENE.scene_points_id[batch_idx] 132 | whole_scene_points_num = TEST_DATASET_WHOLE_SCENE.scene_points_num[batch_idx] 133 | whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx] 134 | vote_label_pool = np.zeros((whole_scene_label.shape[0], NUM_CLASSES)) 135 | for vote_idx in range(num_votes): 136 | scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx] 137 | num_blocks = scene_data.shape[0] 138 | s_batch_num = (num_blocks + BATCH_SIZE - 1) // BATCH_SIZE 139 | if WITH_RGB: 140 | batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 6)) 141 | else: 142 | batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 3)) 143 | batch_label = np.zeros((BATCH_SIZE, NUM_POINT)) 144 | batch_point_index = np.zeros((BATCH_SIZE, NUM_POINT)) 145 | for sbatch in range(s_batch_num): 146 | start_idx = sbatch * BATCH_SIZE 147 | end_idx = min((sbatch + 1)*BATCH_SIZE, num_blocks) 148 | real_batch_size = end_idx - start_idx 149 | batch_data[0:real_batch_size,...] = scene_data[start_idx:end_idx, ...] 150 | batch_label[0:real_batch_size,...] = scene_label[start_idx:end_idx, ...] 151 | batch_point_index[0:real_batch_size,...] = scene_point_index[start_idx:end_idx, ...] 
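                # Voting: each sliding-window block casts one class vote per point it
                # covers; add_vote accumulates the votes in vote_label_pool, and the
                # final per-point label is taken as the argmax over all votes.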
152 | 153 | if WITH_RGB: 154 | batch_data[:, :, 3:6] /= 1.0 #255.0 155 | 156 | feed_dict = {ops['pointclouds_pl']: batch_data, 157 | ops['labels_pl']: batch_label, 158 | ops['is_training_pl']: is_training} 159 | pred_val = sess.run(ops['pred'], feed_dict=feed_dict)#BxNxNUM_CLASSES 160 | batch_pred_label = np.argmax(pred_val[:, :, 1:], 2) + 1#BxN 161 | vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size,...], batch_pred_label[0:real_batch_size,...]) 162 | 163 | pred_label = np.argmax(vote_label_pool, 1) 164 | for l in range(NUM_CLASSES): 165 | total_seen_class[l] += np.sum((whole_scene_label==l)) 166 | total_correct_class[l] += np.sum((pred_label==l) & (whole_scene_label==l)) 167 | total_iou_deno_class[l] += np.sum(((pred_label==l) | (whole_scene_label==l)) & (whole_scene_label > 0)) 168 | 169 | 170 | print(total_correct_class) 171 | print(total_iou_deno_class) 172 | print(total_seen_class) 173 | whole_scene_data = np.zeros(whole_scene_points_num) 174 | whole_scene_data[whole_scene_points_index] = test_class[pred_label.astype(np.int32)] 175 | 176 | filename = os.path.join(DUMP_DIR, scene_id[batch_idx] + '.txt') 177 | with open(filename, 'w') as pl_save: 178 | for i in whole_scene_data: 179 | pl_save.write(str(int(i))+'\n') 180 | pl_save.close() 181 | 182 | pred_file = filename 183 | mesh_file = os.path.join(PLY_PATH, scene_id[batch_idx], scene_id[batch_idx]+ '_vh_clean_2.ply') 184 | output_file = os.path.join(DUMP_DIR, scene_id[batch_idx] + '.ply') 185 | visualize(pred_file, mesh_file, output_file) 186 | 187 | IoU = np.array(total_correct_class[1:])/(np.array(total_iou_deno_class[1:],dtype=np.float)+1e-6) 188 | log_string('eval point avg class IoU: %f' % (np.mean(IoU))) 189 | IoU_Class = 'Each Class IoU:::\n' 190 | for i in range(IoU.shape[0]): 191 | IoU_Class += 'Class %d : %.4f\n'%(i+1, IoU[i]) 192 | log_string(IoU_Class) 193 | 194 | print("Done!") 195 | 196 | if __name__=='__main__': 197 | with tf.Graph().as_default(): 198 | evaluate(num_votes=FLAGS.num_votes) 199 | LOG_FOUT.close() 200 | -------------------------------------------------------------------------------- /imgs/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanWusee/pointconv/f39dc3e101af2f52544181ee20c14f73279b48ae/imgs/example.png -------------------------------------------------------------------------------- /models/pointconv_weight_density_n16.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(__file__) 4 | sys.path.append(BASE_DIR) 5 | sys.path.append(os.path.join(BASE_DIR, '../')) 6 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 7 | import tensorflow as tf 8 | import numpy as np 9 | import tf_util 10 | from PointConv import feature_encoding_layer, feature_decoding_layer 11 | 12 | def placeholder_inputs(batch_size, num_point): 13 | pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 14 | labels_pl = tf.placeholder(tf.int32, shape=(batch_size, num_point)) 15 | smpws_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point)) 16 | return pointclouds_pl, labels_pl, smpws_pl 17 | 18 | 19 | def get_model(point_cloud, is_training, num_class, sigma, bn_decay=None, weight_decay = None): 20 | """ Semantic segmentation PointNet, input is BxNx3, output Bxnum_class """ 21 | 22 | batch_size = point_cloud.get_shape()[0].value 23 | num_point = point_cloud.get_shape()[1].value 24 | end_points = {} 
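    # The raw xyz coordinates double as the initial features (l0_points = point_cloud).
    # In the encoder, npoint shrinks 1024 -> 256 -> 64 -> 36 while the ball radius and
    # the KDE bandwidth sigma double at each level, so deeper layers aggregate over
    # coarser, wider neighborhoods; the decoding layers below mirror this hierarchy.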
25 | l0_xyz = point_cloud 26 | l0_points = point_cloud 27 | 28 | # Feature encoding layers 29 | l1_xyz, l1_points = feature_encoding_layer(l0_xyz, l0_points, npoint=1024, radius = 0.1, sigma = sigma, K=32, mlp=[32,32,64], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer1') 30 | l2_xyz, l2_points = feature_encoding_layer(l1_xyz, l1_points, npoint=256, radius = 0.2, sigma = 2 * sigma, K=32, mlp=[64,64,128], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer2') 31 | l3_xyz, l3_points = feature_encoding_layer(l2_xyz, l2_points, npoint=64, radius = 0.4, sigma = 4 * sigma, K=32, mlp=[128,128,256], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer3') 32 | l4_xyz, l4_points = feature_encoding_layer(l3_xyz, l3_points, npoint=36, radius = 0.8, sigma = 8 * sigma, K=32, mlp=[256,256,512], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer4') 33 | 34 | # Feature decoding layers 35 | l3_points = feature_decoding_layer(l3_xyz, l4_xyz, l3_points, l4_points, 0.8, 8 * sigma, 16, [512,512], is_training, bn_decay, weight_decay, scope='fa_layer1') 36 | l2_points = feature_decoding_layer(l2_xyz, l3_xyz, l2_points, l3_points, 0.4, 4 * sigma, 16, [256,256], is_training, bn_decay, weight_decay, scope='fa_layer2') 37 | l1_points = feature_decoding_layer(l1_xyz, l2_xyz, l1_points, l2_points, 0.2, 2 * sigma, 16, [256,128], is_training, bn_decay, weight_decay, scope='fa_layer3') 38 | l0_points = feature_decoding_layer(l0_xyz, l1_xyz, l0_points, l1_points, 0.1, sigma, 16, [128,128,128], is_training, bn_decay, weight_decay, scope='fa_layer4') 39 | 40 | # FC layers 41 | net = tf_util.conv1d(l0_points, 128, 1, padding='VALID', bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay, weight_decay=weight_decay) 42 | end_points['feats'] = net 43 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1') 44 | net = tf_util.conv1d(net, num_class, 1, padding='VALID', activation_fn=None, weight_decay=weight_decay, scope='fc2') 45 | 46 | return net, end_points 47 | 48 | 49 | def get_loss(pred, label, smpw): 50 | """ pred: BxNxC, 51 | label: BxN, 52 | smpw: BxN """ 53 | classify_loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=pred, weights=smpw) 54 | weight_reg = tf.add_n(tf.get_collection('losses')) 55 | classify_loss_mean = tf.reduce_mean(classify_loss, name='classify_loss_mean') 56 | total_loss = classify_loss_mean + weight_reg 57 | tf.summary.scalar('classify loss', classify_loss) 58 | tf.summary.scalar('total loss', total_loss) 59 | return total_loss 60 | 61 | if __name__=='__main__': 62 | import pdb 63 | pdb.set_trace() 64 | 65 | with tf.Graph().as_default(): 66 | inputs = tf.zeros((32,2048,3)) 67 | net, _ = get_model(inputs, tf.constant(True), 10, 1.0) 68 | print(net) 69 | -------------------------------------------------------------------------------- /models/pointconv_weight_density_n16_dp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(__file__) 4 | sys.path.append(BASE_DIR) 5 | sys.path.append(os.path.join(BASE_DIR, '../')) 6 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 7 | import tensorflow as tf 8 | import numpy as np 9 | import tf_util 10 | from PointConv import feature_encoding_layer, feature_decoding_layer_depthwise 11 | 12 | def placeholder_inputs(batch_size, num_point): 13 | pointclouds_pl = tf.placeholder(tf.float32, 
shape=(batch_size, num_point, 3)) 14 | labels_pl = tf.placeholder(tf.int32, shape=(batch_size, num_point)) 15 | smpws_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point)) 16 | return pointclouds_pl, labels_pl, smpws_pl 17 | 18 | 19 | def get_model(point_cloud, is_training, num_class, sigma, bn_decay=None, weight_decay = None): 20 | """ Semantic segmentation PointNet, input is BxNx3, output Bxnum_class """ 21 | 22 | batch_size = point_cloud.get_shape()[0].value 23 | num_point = point_cloud.get_shape()[1].value 24 | end_points = {} 25 | l0_xyz = point_cloud 26 | l0_points = point_cloud 27 | 28 | # Feature encoding layers 29 | l1_xyz, l1_points = feature_encoding_layer(l0_xyz, l0_points, npoint=1024, radius = 0.1, sigma = sigma, K=32, mlp=[32,32,64], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer1') 30 | l2_xyz, l2_points = feature_encoding_layer(l1_xyz, l1_points, npoint=256, radius = 0.2, sigma = 2 * sigma, K=32, mlp=[64,64,128], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer2') 31 | l3_xyz, l3_points = feature_encoding_layer(l2_xyz, l2_points, npoint=64, radius = 0.4, sigma = 4 * sigma, K=32, mlp=[128,128,256], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer3') 32 | l4_xyz, l4_points = feature_encoding_layer(l3_xyz, l3_points, npoint=36, radius = 0.8, sigma = 8 * sigma, K=32, mlp=[256,256,512], is_training=is_training, bn_decay=bn_decay, weight_decay = weight_decay, scope='layer4') 33 | 34 | # Feature decoding layers 35 | l3_points = feature_decoding_layer_depthwise(l3_xyz, l4_xyz, l3_points, l4_points, 0.8, 8 * sigma, 16, [512,512], is_training, bn_decay, weight_decay, scope='fa_layer1') 36 | l2_points = feature_decoding_layer_depthwise(l2_xyz, l3_xyz, l2_points, l3_points, 0.4, 4 * sigma, 16, [256,256], is_training, bn_decay, weight_decay, scope='fa_layer2') 37 | l1_points = feature_decoding_layer_depthwise(l1_xyz, l2_xyz, l1_points, l2_points, 0.2, 2 * sigma, 16, [256,128], is_training, bn_decay, weight_decay, scope='fa_layer3') 38 | l0_points = feature_decoding_layer_depthwise(l0_xyz, l1_xyz, l0_points, l1_points, 0.1, sigma, 16, [128,128,128], is_training, bn_decay, weight_decay, scope='fa_layer4') 39 | 40 | # FC layers 41 | net = tf_util.conv1d(l0_points, 128, 1, padding='VALID', bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay, weight_decay=weight_decay) 42 | end_points['feats'] = net 43 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1') 44 | net = tf_util.conv1d(net, num_class, 1, padding='VALID', is_training=is_training, activation_fn=None, weight_decay=weight_decay, scope='fc2') 45 | 46 | return net, end_points 47 | 48 | 49 | def get_loss(pred, label, smpw): 50 | """ pred: BxNxC, 51 | label: BxN, 52 | smpw: BxN """ 53 | classify_loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=pred, weights=smpw) 54 | weight_reg = tf.add_n(tf.get_collection('losses')) 55 | classify_loss_mean = tf.reduce_mean(classify_loss, name='classify_loss_mean') 56 | total_loss = classify_loss_mean + weight_reg 57 | tf.summary.scalar('classify loss', classify_loss) 58 | tf.summary.scalar('total loss', total_loss) 59 | return total_loss 60 | 61 | if __name__=='__main__': 62 | import pdb 63 | pdb.set_trace() 64 | 65 | with tf.Graph().as_default(): 66 | inputs = tf.zeros((32,2048,3)) 67 | net, _ = get_model(inputs, tf.constant(True), 10, 1.0) 68 | print(net) 69 | 
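# For reference, a NumPy sketch of the classification term computed by get_loss
# above (illustrative only: tf.losses.sparse_softmax_cross_entropy performs this
# reduction internally, and get_loss adds the collected weight-decay terms on top;
# the name weighted_ce_np is ours, not part of the repository).

def weighted_ce_np(logits, labels, smpw):
    """logits: (B, N, C) scores; labels: (B, N) int ids; smpw: (B, N) weights."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # stabilize the softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    # TF's default reduction divides the weighted sum by the number of points
    # with non-zero weight (unlabeled points carry smpw == 0).
    return (smpw * nll).sum() / max(np.count_nonzero(smpw), 1)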
-------------------------------------------------------------------------------- /scannet/README.md: --------------------------------------------------------------------------------

1 | ## ScanNet v2 Data
2 | 
3 | Please download the original dataset from the website: http://www.scan-net.org/
4 | 
5 | To prepare the ScanNet dataset for training and evaluation, modify [line 82](https://github.com/DylanWusee/pointconv/blob/2a59507ef8798d52225885865ecc4b50face78c9/scannet/scannetv2_seg_dataset_rgb21c_pointid.py#L82) in `scannetv2_seg_dataset_rgb21c_pointid.py` to point to your ScanNet v2 dataset path.
6 | 
7 | Then run
8 | 
9 | ```
10 | python scannetv2_seg_dataset_rgb21c_pointid.py
11 | ```
12 | 
13 | This will generate three pickle files: `scannet_train_rgb21c_pointid.pickle`, `scannet_val_rgb21c_pointid.pickle`, and `scannet_test_rgb21c_pointid.pickle`. The first two are used for training and validation.

-------------------------------------------------------------------------------- /scannet/eulerangles.py: --------------------------------------------------------------------------------

1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2 | # vi: set ft=python sts=4 ts=4 sw=4 et:
3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4 | #
5 | #   See COPYING file distributed along with the NiBabel package for the
6 | #   copyright and license terms.
7 | #
8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9 | ''' Module implementing Euler angle rotations and their conversions
10 | 
11 | See:
12 | 
13 | * http://en.wikipedia.org/wiki/Rotation_matrix
14 | * http://en.wikipedia.org/wiki/Euler_angles
15 | * http://mathworld.wolfram.com/EulerAngles.html
16 | 
17 | See also: *Representing Attitude with Euler Angles and Quaternions: A
18 | Reference* (2006) by James Diebel. A cached PDF link last found here:
19 | 
20 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134
21 | 
22 | Euler's rotation theorem tells us that any rotation in 3D can be
23 | described by 3 angles. Let's call the 3 angles the *Euler angle vector*
24 | and call the angles in the vector :math:`alpha`, :math:`beta` and
25 | :math:`gamma`. The vector is [ :math:`alpha`,
26 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the
27 | parameters specifies the order in which the rotations occur (so the
28 | rotation corresponding to :math:`alpha` is applied first).
29 | 
30 | In order to specify the meaning of an *Euler angle vector* we need to
31 | specify the axes around which each of the rotations corresponding to
32 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur.
33 | 
34 | There are therefore three axes for the rotations :math:`alpha`,
35 | :math:`beta` and :math:`gamma`; let's call them :math:`i`, :math:`j`,
36 | :math:`k`.
37 | 
38 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3
39 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes a 3 x 3
40 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the
41 | whole rotation expressed by the Euler angle vector [ :math:`alpha`,
42 | :math:`beta`, :math:`gamma` ], `R`, is given by::
43 | 
44 |    R = np.dot(G, np.dot(B, A))
45 | 
46 | See http://mathworld.wolfram.com/EulerAngles.html
47 | 
48 | The order :math:`G B A` expresses the fact that the rotations are
49 | performed in the order of the vector (:math:`alpha` around axis `i` =
50 | `A` first).
51 | 52 | To convert a given Euler angle vector to a meaningful rotation, and a 53 | rotation matrix, we need to define: 54 | 55 | * the axes `i`, `j`, `k` 56 | * whether a rotation matrix should be applied on the left of a vector to 57 | be transformed (vectors are column vectors) or on the right (vectors 58 | are row vectors). 59 | * whether the rotations move the axes as they are applied (intrinsic 60 | rotations) - compared the situation where the axes stay fixed and the 61 | vectors move within the axis frame (extrinsic) 62 | * the handedness of the coordinate system 63 | 64 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities 65 | 66 | We are using the following conventions: 67 | 68 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus 69 | an Euler angle vector [ :math:`alpha`, :math:`beta`. :math:`gamma` ] 70 | in our convention implies a :math:`alpha` radian rotation around the 71 | `z` axis, followed by a :math:`beta` rotation around the `y` axis, 72 | followed by a :math:`gamma` rotation around the `x` axis. 73 | * the rotation matrix applies on the left, to column vectors on the 74 | right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix 75 | with N column vectors, the transformed vector set `vdash` is given by 76 | ``vdash = np.dot(R, v)``. 77 | * extrinsic rotations - the axes are fixed, and do not move with the 78 | rotations. 79 | * a right-handed coordinate system 80 | 81 | The convention of rotation around ``z``, followed by rotation around 82 | ``y``, followed by rotation around ``x``, is known (confusingly) as 83 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles. 84 | ''' 85 | 86 | import math 87 | 88 | import sys 89 | if sys.version_info >= (3,0): 90 | from functools import reduce 91 | 92 | import numpy as np 93 | 94 | 95 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0 96 | 97 | 98 | def euler2mat(z=0, y=0, x=0): 99 | ''' Return matrix for rotations around z, y and x axes 100 | 101 | Uses the z, then y, then x convention above 102 | 103 | Parameters 104 | ---------- 105 | z : scalar 106 | Rotation angle in radians around z-axis (performed first) 107 | y : scalar 108 | Rotation angle in radians around y-axis 109 | x : scalar 110 | Rotation angle in radians around x-axis (performed last) 111 | 112 | Returns 113 | ------- 114 | M : array shape (3,3) 115 | Rotation matrix giving same rotation as for given angles 116 | 117 | Examples 118 | -------- 119 | >>> zrot = 1.3 # radians 120 | >>> yrot = -0.1 121 | >>> xrot = 0.2 122 | >>> M = euler2mat(zrot, yrot, xrot) 123 | >>> M.shape == (3, 3) 124 | True 125 | 126 | The output rotation matrix is equal to the composition of the 127 | individual rotations 128 | 129 | >>> M1 = euler2mat(zrot) 130 | >>> M2 = euler2mat(0, yrot) 131 | >>> M3 = euler2mat(0, 0, xrot) 132 | >>> composed_M = np.dot(M3, np.dot(M2, M1)) 133 | >>> np.allclose(M, composed_M) 134 | True 135 | 136 | You can specify rotations by named arguments 137 | 138 | >>> np.all(M3 == euler2mat(x=xrot)) 139 | True 140 | 141 | When applying M to a vector, the vector should column vector to the 142 | right of M. If the right hand side is a 2D array rather than a 143 | vector, then each column of the 2D array represents a vector. 144 | 145 | >>> vec = np.array([1, 0, 0]).reshape((3,1)) 146 | >>> v2 = np.dot(M, vec) 147 | >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array 148 | >>> vecs2 = np.dot(M, vecs) 149 | 150 | Rotations are counter-clockwise. 
151 | 152 | >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3)) 153 | >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]]) 154 | True 155 | >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3)) 156 | >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]]) 157 | True 158 | >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3)) 159 | >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]]) 160 | True 161 | 162 | Notes 163 | ----- 164 | The direction of rotation is given by the right-hand rule (orient 165 | the thumb of the right hand along the axis around which the rotation 166 | occurs, with the end of the thumb at the positive end of the axis; 167 | curl your fingers; the direction your fingers curl is the direction 168 | of rotation). Therefore, the rotations are counterclockwise if 169 | looking along the axis of rotation from positive to negative. 170 | ''' 171 | Ms = [] 172 | if z: 173 | cosz = math.cos(z) 174 | sinz = math.sin(z) 175 | Ms.append(np.array( 176 | [[cosz, -sinz, 0], 177 | [sinz, cosz, 0], 178 | [0, 0, 1]])) 179 | if y: 180 | cosy = math.cos(y) 181 | siny = math.sin(y) 182 | Ms.append(np.array( 183 | [[cosy, 0, siny], 184 | [0, 1, 0], 185 | [-siny, 0, cosy]])) 186 | if x: 187 | cosx = math.cos(x) 188 | sinx = math.sin(x) 189 | Ms.append(np.array( 190 | [[1, 0, 0], 191 | [0, cosx, -sinx], 192 | [0, sinx, cosx]])) 193 | if Ms: 194 | return reduce(np.dot, Ms[::-1]) 195 | return np.eye(3) 196 | 197 | 198 | def mat2euler(M, cy_thresh=None): 199 | ''' Discover Euler angle vector from 3x3 matrix 200 | 201 | Uses the conventions above. 202 | 203 | Parameters 204 | ---------- 205 | M : array-like, shape (3,3) 206 | cy_thresh : None or scalar, optional 207 | threshold below which to give up on straightforward arctan for 208 | estimating x rotation. If None (default), estimate from 209 | precision of input. 210 | 211 | Returns 212 | ------- 213 | z : scalar 214 | y : scalar 215 | x : scalar 216 | Rotations in radians around z, y, x axes, respectively 217 | 218 | Notes 219 | ----- 220 | If there was no numerical error, the routine could be derived using 221 | Sympy expression for z then y then x rotation matrix, which is:: 222 | 223 | [ cos(y)*cos(z), -cos(y)*sin(z), sin(y)], 224 | [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)], 225 | [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z), cos(x)*cos(y)] 226 | 227 | with the obvious derivations for z, y, and x 228 | 229 | z = atan2(-r12, r11) 230 | y = asin(r13) 231 | x = atan2(-r23, r33) 232 | 233 | Problems arise when cos(y) is close to zero, because both of:: 234 | 235 | z = atan2(cos(y)*sin(z), cos(y)*cos(z)) 236 | x = atan2(cos(y)*sin(x), cos(x)*cos(y)) 237 | 238 | will be close to atan2(0, 0), and highly unstable. 239 | 240 | The ``cy`` fix for numerical instability below is from: *Graphics 241 | Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN: 242 | 0123361559. Specifically it comes from EulerAngles.c by Ken 243 | Shoemake, and deals with the case where cos(y) is close to zero: 244 | 245 | See: http://www.graphicsgems.org/ 246 | 247 | The code appears to be licensed (from the website) as "can be used 248 | without restrictions". 
249 | ''' 250 | M = np.asarray(M) 251 | if cy_thresh is None: 252 | try: 253 | cy_thresh = np.finfo(M.dtype).eps * 4 254 | except ValueError: 255 | cy_thresh = _FLOAT_EPS_4 256 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat 257 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2) 258 | cy = math.sqrt(r33*r33 + r23*r23) 259 | if cy > cy_thresh: # cos(y) not close to zero, standard form 260 | z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z)) 261 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 262 | x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y)) 263 | else: # cos(y) (close to) zero, so x -> 0.0 (see above) 264 | # so r21 -> sin(z), r22 -> cos(z) and 265 | z = math.atan2(r21, r22) 266 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 267 | x = 0.0 268 | return z, y, x 269 | 270 | 271 | def euler2quat(z=0, y=0, x=0): 272 | ''' Return quaternion corresponding to these Euler angles 273 | 274 | Uses the z, then y, then x convention above 275 | 276 | Parameters 277 | ---------- 278 | z : scalar 279 | Rotation angle in radians around z-axis (performed first) 280 | y : scalar 281 | Rotation angle in radians around y-axis 282 | x : scalar 283 | Rotation angle in radians around x-axis (performed last) 284 | 285 | Returns 286 | ------- 287 | quat : array shape (4,) 288 | Quaternion in w, x, y z (real, then vector) format 289 | 290 | Notes 291 | ----- 292 | We can derive this formula in Sympy using: 293 | 294 | 1. Formula giving quaternion corresponding to rotation of theta radians 295 | about arbitrary axis: 296 | http://mathworld.wolfram.com/EulerParameters.html 297 | 2. Generated formulae from 1.) for quaternions corresponding to 298 | theta radians rotations about ``x, y, z`` axes 299 | 3. Apply quaternion multiplication formula - 300 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to 301 | formulae from 2.) to give formula for combined rotations. 302 | ''' 303 | z = z/2.0 304 | y = y/2.0 305 | x = x/2.0 306 | cz = math.cos(z) 307 | sz = math.sin(z) 308 | cy = math.cos(y) 309 | sy = math.sin(y) 310 | cx = math.cos(x) 311 | sx = math.sin(x) 312 | return np.array([ 313 | cx*cy*cz - sx*sy*sz, 314 | cx*sy*sz + cy*cz*sx, 315 | cx*cz*sy - sx*cy*sz, 316 | cx*cy*sz + sx*cz*sy]) 317 | 318 | 319 | def quat2euler(q): 320 | ''' Return Euler angles corresponding to quaternion `q` 321 | 322 | Parameters 323 | ---------- 324 | q : 4 element sequence 325 | w, x, y, z of quaternion 326 | 327 | Returns 328 | ------- 329 | z : scalar 330 | Rotation angle in radians around z-axis (performed first) 331 | y : scalar 332 | Rotation angle in radians around y-axis 333 | x : scalar 334 | Rotation angle in radians around x-axis (performed last) 335 | 336 | Notes 337 | ----- 338 | It's possible to reduce the amount of calculation a little, by 339 | combining parts of the ``quat2mat`` and ``mat2euler`` functions, but 340 | the reduction in computation is small, and the code repetition is 341 | large. 
342 | ''' 343 | # delayed import to avoid cyclic dependencies 344 | import nibabel.quaternions as nq 345 | return mat2euler(nq.quat2mat(q)) 346 | 347 | 348 | def euler2angle_axis(z=0, y=0, x=0): 349 | ''' Return angle, axis corresponding to these Euler angles 350 | 351 | Uses the z, then y, then x convention above 352 | 353 | Parameters 354 | ---------- 355 | z : scalar 356 | Rotation angle in radians around z-axis (performed first) 357 | y : scalar 358 | Rotation angle in radians around y-axis 359 | x : scalar 360 | Rotation angle in radians around x-axis (performed last) 361 | 362 | Returns 363 | ------- 364 | theta : scalar 365 | angle of rotation 366 | vector : array shape (3,) 367 | axis around which rotation occurs 368 | 369 | Examples 370 | -------- 371 | >>> theta, vec = euler2angle_axis(0, 1.5, 0) 372 | >>> print(theta) 373 | 1.5 374 | >>> np.allclose(vec, [0, 1, 0]) 375 | True 376 | ''' 377 | # delayed import to avoid cyclic dependencies 378 | import nibabel.quaternions as nq 379 | return nq.quat2angle_axis(euler2quat(z, y, x)) 380 | 381 | 382 | def angle_axis2euler(theta, vector, is_normalized=False): 383 | ''' Convert angle, axis pair to Euler angles 384 | 385 | Parameters 386 | ---------- 387 | theta : scalar 388 | angle of rotation 389 | vector : 3 element sequence 390 | vector specifying axis for rotation. 391 | is_normalized : bool, optional 392 | True if vector is already normalized (has norm of 1). Default 393 | False 394 | 395 | Returns 396 | ------- 397 | z : scalar 398 | y : scalar 399 | x : scalar 400 | Rotations in radians around z, y, x axes, respectively 401 | 402 | Examples 403 | -------- 404 | >>> z, y, x = angle_axis2euler(0, [1, 0, 0]) 405 | >>> np.allclose((z, y, x), 0) 406 | True 407 | 408 | Notes 409 | ----- 410 | It's possible to reduce the amount of calculation a little, by 411 | combining parts of the ``angle_axis2mat`` and ``mat2euler`` 412 | functions, but the reduction in computation is small, and the code 413 | repetition is large. 414 | ''' 415 | # delayed import to avoid cyclic dependencies 416 | import nibabel.quaternions as nq 417 | M = nq.angle_axis2mat(theta, vector, is_normalized) 418 | return mat2euler(M) 419 | -------------------------------------------------------------------------------- /scannet/pc_util.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for processing point clouds. 2 | 3 | Author: Charles R. 
Qi, Hao Su 4 | Date: November 2016 5 | """ 6 | 7 | import os 8 | import sys 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | 12 | # Draw point cloud 13 | from eulerangles import euler2mat 14 | 15 | # Point cloud IO 16 | import numpy as np 17 | from plyfile import PlyData, PlyElement 18 | 19 | 20 | # ---------------------------------------- 21 | # Point Cloud/Volume Conversions 22 | # ---------------------------------------- 23 | def point_cloud_label_to_surface_voxel_label(point_cloud, label, res=0.0484): 24 | coordmax = np.max(point_cloud,axis=0) 25 | coordmin = np.min(point_cloud,axis=0) 26 | nvox = np.ceil((coordmax-coordmin)/res) 27 | vidx = np.ceil((point_cloud-coordmin)/res) 28 | vidx = vidx[:,0]+vidx[:,1]*nvox[0]+vidx[:,2]*nvox[0]*nvox[1] 29 | uvidx = np.unique(vidx) 30 | if label.ndim==1: 31 | uvlabel = [np.argmax(np.bincount(label[vidx==uv].astype(np.uint32))) for uv in uvidx] 32 | else: 33 | assert(label.ndim==2) 34 | uvlabel = np.zeros(len(uvidx),label.shape[1]) 35 | for i in range(label.shape[1]): 36 | uvlabel[:,i] = np.array([np.argmax(np.bincount(label[vidx==uv,i].astype(np.uint32))) for uv in uvidx]) 37 | return uvidx, uvlabel, nvox 38 | 39 | def point_cloud_label_to_surface_voxel_label_fast(point_cloud, label, res=0.0484): 40 | coordmax = np.max(point_cloud,axis=0) 41 | coordmin = np.min(point_cloud,axis=0) 42 | nvox = np.ceil((coordmax-coordmin)/res) 43 | vidx = np.ceil((point_cloud-coordmin)/res) 44 | vidx = vidx[:,0]+vidx[:,1]*nvox[0]+vidx[:,2]*nvox[0]*nvox[1] 45 | uvidx, vpidx = np.unique(vidx,return_index=True) 46 | if label.ndim==1: 47 | uvlabel = label[vpidx] 48 | else: 49 | assert(label.ndim==2) 50 | uvlabel = label[vpidx,:] 51 | return uvidx, uvlabel, nvox 52 | 53 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True): 54 | """ Input is BxNx3 batch of point cloud 55 | Output is Bx(vsize^3) 56 | """ 57 | vol_list = [] 58 | for b in range(point_clouds.shape[0]): 59 | vol = point_cloud_to_volume(np.squeeze(point_clouds[b,:,:]), vsize, radius) 60 | if flatten: 61 | vol_list.append(vol.flatten()) 62 | else: 63 | vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0)) 64 | if flatten: 65 | return np.vstack(vol_list) 66 | else: 67 | return np.concatenate(vol_list, 0) 68 | 69 | 70 | def point_cloud_to_volume(points, vsize, radius=1.0): 71 | """ input is Nx3 points. 72 | output is vsize*vsize*vsize 73 | assumes points are in range [-radius, radius] 74 | """ 75 | vol = np.zeros((vsize,vsize,vsize)) 76 | voxel = 2*radius/float(vsize) 77 | locations = (points + radius)/voxel 78 | locations = locations.astype(int) 79 | vol[locations[:,0],locations[:,1],locations[:,2]] = 1.0 80 | return vol 81 | 82 | #a = np.zeros((16,1024,3)) 83 | #print point_cloud_to_volume_batch(a, 12, 1.0, False).shape 84 | 85 | def volume_to_point_cloud(vol): 86 | """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize 87 | return Nx3 numpy array. 
88 |     """
89 |     vsize = vol.shape[0]
90 |     assert(vol.shape[1] == vsize and vol.shape[2] == vsize)
91 |     points = []
92 |     for a in range(vsize):
93 |         for b in range(vsize):
94 |             for c in range(vsize):
95 |                 if vol[a,b,c] == 1:
96 |                     points.append(np.array([a,b,c]))
97 |     if len(points) == 0:
98 |         return np.zeros((0,3))
99 |     points = np.vstack(points)
100 |     return points
101 | 
102 | def point_cloud_to_volume_v2_batch(point_clouds, vsize=12, radius=1.0, num_sample=128):
103 |     """ Input is BxNx3 a batch of point cloud
104 |         Output is BxVxVxVxnum_samplex3
105 |         Added on Feb 19
106 |     """
107 |     vol_list = []
108 |     for b in range(point_clouds.shape[0]):
109 |         vol = point_cloud_to_volume_v2(point_clouds[b,:,:], vsize, radius, num_sample)
110 |         vol_list.append(np.expand_dims(vol, 0))
111 |     return np.concatenate(vol_list, 0)
112 | 
113 | def point_cloud_to_volume_v2(points, vsize, radius=1.0, num_sample=128):
114 |     """ input is Nx3 points
115 |         output is vsize*vsize*vsize*num_sample*3
116 |         assumes points are in range [-radius, radius]
117 |         samples num_sample points in each voxel, if there are less than
118 |         num_sample points, replicate the points
119 |         Added on Feb 19
120 |     """
121 |     vol = np.zeros((vsize,vsize,vsize,num_sample,3))
122 |     voxel = 2*radius/float(vsize)
123 |     locations = (points + radius)/voxel
124 |     locations = locations.astype(int)
125 |     loc2pc = {}
126 |     for n in range(points.shape[0]):
127 |         loc = tuple(locations[n,:])
128 |         if loc not in loc2pc:
129 |             loc2pc[loc] = []
130 |         loc2pc[loc].append(points[n,:])
131 |     #print loc2pc
132 | 
133 |     for i in range(vsize):
134 |         for j in range(vsize):
135 |             for k in range(vsize):
136 |                 if (i,j,k) not in loc2pc:
137 |                     vol[i,j,k,:,:] = np.zeros((num_sample,3))
138 |                 else:
139 |                     pc = loc2pc[(i,j,k)] # a list of (3,) arrays
140 |                     pc = np.vstack(pc) # kx3
141 |                     # Sample/pad to num_sample points
142 |                     if pc.shape[0]>num_sample:
143 |                         choices = np.random.choice(pc.shape[0], num_sample, replace=False)
144 |                         pc = pc[choices,:]
145 |                     elif pc.shape[0]<num_sample:
146 |                         pc = np.lib.pad(pc, ((0,num_sample-pc.shape[0]),(0,0)), 'edge')
147 |                     # Normalize points to local voxel coordinates
148 |                     pc_center = (np.array([i,j,k])+0.5)*voxel - radius
149 |                     pc = (pc - pc_center) / voxel # shift and scale
150 |                     vol[i,j,k,:,:] = pc
151 |     return vol
152 | 
153 | 
154 | def point_cloud_to_image_batch(point_clouds, imgsize, radius=1.0, num_sample=128):
155 |     """ Input is BxNx3 a batch of point cloud
156 |         Output is BxIxIxnum_samplex3
157 |         Added on Feb 19
158 |     """
159 |     img_list = []
160 |     for b in range(point_clouds.shape[0]):
161 |         img = point_cloud_to_image(point_clouds[b,:,:], imgsize, radius, num_sample)
162 |         img_list.append(np.expand_dims(img, 0))
163 |     return np.concatenate(img_list, 0)
164 | 
165 | def point_cloud_to_image(points, imgsize, radius=1.0, num_sample=128):
166 |     """ input is Nx3 points
167 |         output is imgsize*imgsize*num_sample*3
168 |         assumes points are in range [-radius, radius]
169 |         samples num_sample points in each pixel, if there are less than
170 |         num_sample points, replicate the points
171 |         Added on Feb 19
172 |     """
173 |     img = np.zeros((imgsize, imgsize, num_sample, 3))
174 |     pixel = 2*radius/float(imgsize)
175 |     locations = (points[:,0:2] + radius)/pixel # Nx2
176 |     locations = locations.astype(int)
177 |     loc2pc = {}
178 |     for n in range(points.shape[0]):
179 |         loc = tuple(locations[n,:])
180 |         if loc not in loc2pc:
181 |             loc2pc[loc] = []
182 |         loc2pc[loc].append(points[n,:])
183 | 
184 |     for i in range(imgsize):
185 |         for j in range(imgsize):
186 |             if (i,j) not in loc2pc:
187 |                 img[i,j,:,:] = np.zeros((num_sample,3))
188 |             else:
189 |                 pc = loc2pc[(i,j)] # a list of (3,) arrays
190 |                 pc = np.vstack(pc) # kx3
191 |                 # Sample/pad to num_sample points
192 |                 if pc.shape[0]>num_sample:
193 |                     choices = np.random.choice(pc.shape[0], num_sample, replace=False)
194 |                     pc = pc[choices,:]
195 |                 elif pc.shape[0]<num_sample:
196 |                     pc = np.lib.pad(pc, ((0,num_sample-pc.shape[0]),(0,0)), 'edge')
197 |                 # Normalize points to local pixel coordinates
198 |                 pc_center = (np.array([i,j])+0.5)*pixel - radius
199 |                 pc[:,0:2] = (pc[:,0:2] - pc_center)/pixel # shift and scale
200 |                 img[i,j,:,:] = pc
201 |     return img
202 | 
203 | # ----------------------------------------
204 | # Point Cloud IO
205 | # ----------------------------------------
206 | 
207 | def read_ply(filename):
208 |     """ read XYZ point cloud from filename PLY file """
209 |     plydata = PlyData.read(filename)
210 |     pc = plydata['vertex'].data
211 |     pc_array = np.array([[x, y, z] for x,y,z in pc])
212 |     return pc_array
213 | 
214 | 
215 | def write_ply(points, filename, text=True):
216 |     """ input: Nx3, write points to filename as PLY format. """
217 |     points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
218 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
219 |     el = PlyElement.describe(vertex, 'vertex')
220 |     PlyData([el], text=text).write(filename)
221 | 
233 | def write_ply_label(points, labels, filename, num_classes=None, text=True):
234 |     """ Color (N,3) points with labels (N) within range 0 ~ num_classes-1 as PLY file """
235 |     import matplotlib.pyplot as pyplot
236 |     labels = labels.astype(int)
237 |     N = points.shape[0]
238 |     if num_classes is None:
239 |         num_classes = np.max(labels)+1
240 |     else:
241 |         assert(num_classes>np.max(labels))
242 |     colors = np.array([pyplot.cm.hsv(i/float(num_classes)) for i in range(num_classes)])
243 | 
244 |     new_colors = []
245 |     for i in range(points.shape[0]):
246 |         c = colors[labels[i]]
247 |         c = [int(x*255) for x in c]
248 |         new_colors.append(c)
249 | 
250 |     colors = np.array(new_colors)
251 |     points = [(points[i,0], points[i,1], points[i,2], colors[i,0], colors[i,1], colors[i,2]) for i in range(points.shape[0])]
252 | 
253 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4'), ('red', 'u1'), ('green', 'u1'),('blue', 'u1')])
254 |     el = PlyElement.describe(vertex, 'vertex')
255 |     PlyData([el], text=text).write(filename)
256 | 
257 | import matplotlib.pyplot as pyplot
258 | 
259 | def getColor(labels, num_classes):
260 |     colors = np.array([pyplot.cm.hsv(i/float(num_classes)) for i in range(num_classes)])
261 | 
262 |     new_colors = []
263 |     for i in range(labels.shape[0]):
264 |         c = colors[labels[i]]
265 |         c = [int(x*255) for x in c]
266 |         new_colors.append(c)
267 | 
268 |     colors = np.array(new_colors, dtype = np.float32)
269 | 
270 |     return colors
271 | 
272 | def write_ply_label2(points, labels, filename, num_classes=None, text=True):
273 |     """ Color (N,3) points with labels (N) within range 0 ~ num_classes-1 as PLY file """
274 |     import matplotlib.pyplot as pyplot
275 |     labels = labels.astype(int)
276 |     N = points.shape[0]
277 | 
278 |     colors = 
getColor(labels, num_classes)
252 |     c0=colors[:,0]
253 |     c1=colors[:,1]
254 |     c2=colors[:,2]
255 |     c0/=(c0.max()+1e-14)/255.0
256 |     c1/=(c1.max()+1e-14)/255.0
257 |     c2/=(c2.max()+1e-14)/255.0
258 | 
259 |     c0=np.require(c0,'float32','C')
260 |     c1=np.require(c1,'float32','C')
261 |     c2=np.require(c2,'float32','C')
262 | 
263 |     points = [(points[i,0], points[i,1], points[i,2], c0[i], c1[i], c2[i]) for i in range(points.shape[0])]
264 | 
265 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4'), ('red', 'u1'), ('green', 'u1'),('blue', 'u1')])
266 |     el = PlyElement.describe(vertex, 'vertex')
267 |     PlyData([el], text=text).write(filename)
268 | 
269 | 
270 | # ----------------------------------------
271 | # Simple Point cloud and Volume Renderers
272 | # ----------------------------------------
273 | 
274 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
275 |                      xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True):
276 |     """ Render point cloud to a grayscale image.
277 |         Input:
278 |             points: Nx3 numpy array (+y is up direction)
279 |         Output:
280 |             gray image as numpy array of size canvasSize x canvasSize
281 |     """
282 |     image = np.zeros((canvasSize, canvasSize))
283 |     if input_points is None or input_points.shape[0] == 0:
284 |         return image
285 | 
286 |     points = input_points[:, switch_xyz]
287 |     M = euler2mat(zrot, yrot, xrot)
288 |     points = (np.dot(M, points.transpose())).transpose()
289 | 
290 |     # Normalize the point cloud
291 |     # We normalize scale to fit points in a unit sphere
292 |     if normalize:
293 |         centroid = np.mean(points, axis=0)
294 |         points -= centroid
295 |         furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
296 |         points /= furthest_distance
297 | 
298 |     # Pre-compute the Gaussian disk
299 |     radius = (diameter-1)/2.0
300 |     disk = np.zeros((diameter, diameter))
301 |     for i in range(diameter):
302 |         for j in range(diameter):
303 |             if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius:
304 |                 disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2))
305 |     mask = np.argwhere(disk > 0)
306 |     dx = mask[:, 0]
307 |     dy = mask[:, 1]
308 |     dv = disk[disk > 0]
309 | 
310 |     # Order points by z-buffer
311 |     zorder = np.argsort(points[:, 2])
312 |     points = points[zorder, :]
313 |     points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2])))
314 |     max_depth = np.max(points[:, 2])
315 | 
316 |     for i in range(points.shape[0]):
317 |         j = points.shape[0] - i - 1
318 |         x = points[j, 0]
319 |         y = points[j, 1]
320 |         xc = canvasSize/2 + (x*space)
321 |         yc = canvasSize/2 + (y*space)
322 |         xc = int(np.round(xc))
323 |         yc = int(np.round(yc))
324 | 
325 |         px = dx + xc
326 |         py = dy + yc
327 | 
328 |         image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3
329 | 
330 |     image = image / np.max(image)
331 |     return image
332 | 
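# Editor's sketch (an addition): rendering a random cloud with draw_point_cloud.
# The function name and values are illustrative only.
def _demo_draw_point_cloud():
    pts = np.random.randn(1024, 3)
    img = draw_point_cloud(pts, zrot=90/180.0*np.pi)
    print(img.shape, img.min(), img.max())  # (500, 500), values in [0, 1]

333 | def point_cloud_three_views(points):
334 |     """ input points Nx3 numpy array (+y is up direction).
335 |         return a numpy array gray image of size 500x1500. 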
""" 363 | # +y is up direction 364 | # xrot is azimuth 365 | # yrot is in-plane 366 | # zrot is elevation 367 | img1 = draw_point_cloud(points, zrot=110/180.0*np.pi, xrot=45/180.0*np.pi, yrot=0/180.0*np.pi) 368 | img2 = draw_point_cloud(points, zrot=70/180.0*np.pi, xrot=135/180.0*np.pi, yrot=0/180.0*np.pi) 369 | img3 = draw_point_cloud(points, zrot=180.0/180.0*np.pi, xrot=90/180.0*np.pi, yrot=0/180.0*np.pi) 370 | image_large = np.concatenate([img1, img2, img3], 1) 371 | return image_large 372 | 373 | 374 | def point_cloud_three_views_demo(): 375 | """ Demo for draw_point_cloud function """ 376 | from PIL import Image 377 | points = read_ply('../third_party/mesh_sampling/piano.ply') 378 | im_array = point_cloud_three_views(points) 379 | img = Image.fromarray(np.uint8(im_array*255.0)) 380 | img.save('piano.jpg') 381 | 382 | if __name__=="__main__": 383 | point_cloud_three_views_demo() 384 | 385 | 386 | def pyplot_draw_point_cloud(points, output_filename): 387 | """ points is a Nx3 numpy array """ 388 | import matplotlib.pyplot as plt 389 | fig = plt.figure() 390 | ax = fig.add_subplot(111, projection='3d') 391 | ax.scatter(points[:,0], points[:,1], points[:,2]) 392 | ax.set_xlabel('x') 393 | ax.set_ylabel('y') 394 | ax.set_zlabel('z') 395 | #savefig(output_filename) 396 | 397 | def pyplot_draw_volume(vol, output_filename): 398 | """ vol is of size vsize*vsize*vsize 399 | output an image to output_filename 400 | """ 401 | points = volume_to_point_cloud(vol) 402 | pyplot_draw_point_cloud(points, output_filename) 403 | 404 | def write_ply_color(points, labels, out_filename, num_classes=None): 405 | """ Color (N,3) points with labels (N) within range 0 ~ num_classes-1 as OBJ file """ 406 | import matplotlib.pyplot as pyplot 407 | labels = labels.astype(int) 408 | N = points.shape[0] 409 | if num_classes is None: 410 | num_classes = np.max(labels)+1 411 | else: 412 | assert(num_classes>np.max(labels)) 413 | fout = open(out_filename, 'w') 414 | colors = [pyplot.cm.hsv(i/float(num_classes)) for i in range(num_classes)] 415 | for i in range(N): 416 | c = colors[labels[i]] 417 | c = [int(x*255) for x in c] 418 | fout.write('v %f %f %f %d %d %d\n' % (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2])) 419 | fout.close() 420 | -------------------------------------------------------------------------------- /scannet/scannet_dataset_rgb.py: -------------------------------------------------------------------------------- 1 | """ ScanNet Class From Charles R. Qi, Hao Su. 2 | Modiyied to support rgb in ScanNet v2. 
3 | Author: Wenxuan Wu 4 | Date: July 2018 5 | """ 6 | 7 | import pickle 8 | import os 9 | import sys 10 | import numpy as np 11 | import pc_util 12 | 13 | class ScannetDataset(): 14 | def __init__(self, root, block_points=8192, split='train', with_rgb = False): 15 | self.npoints = block_points 16 | self.root = root 17 | self.with_rgb = with_rgb 18 | self.split = split 19 | self.data_filename = os.path.join(self.root, 'scannet_%s_rgb21c_pointid.pickle'%(split)) 20 | with open(self.data_filename,'rb') as fp: 21 | self.scene_points_list = pickle.load(fp) 22 | self.semantic_labels_list = pickle.load(fp) 23 | self.scene_points_id = pickle.load(fp) 24 | self.scene_points_num = pickle.load(fp) 25 | if split=='train': 26 | labelweights = np.zeros(21) 27 | for seg in self.semantic_labels_list: 28 | tmp,_ = np.histogram(seg,range(22)) 29 | labelweights += tmp 30 | labelweights = labelweights.astype(np.float32) 31 | labelweights = labelweights/np.sum(labelweights) 32 | self.labelweights = np.power(np.amax(labelweights[1:]) / labelweights, 1/3.0) 33 | print(self.labelweights) 34 | elif split=='val': 35 | self.labelweights = np.ones(21) 36 | 37 | def __getitem__(self, index): 38 | if self.with_rgb: 39 | point_set = self.scene_points_list[index] 40 | else: 41 | point_set = self.scene_points_list[index][:, 0:3] 42 | semantic_seg = self.semantic_labels_list[index].astype(np.int32) 43 | coordmax = np.max(point_set[:, 0:3],axis=0) 44 | coordmin = np.min(point_set[:, 0:3],axis=0) 45 | isvalid = False 46 | for i in range(10): 47 | curcenter = point_set[np.random.choice(len(semantic_seg),1)[0],0:3] 48 | curmin = curcenter-[0.75,0.75,1.5] 49 | curmax = curcenter+[0.75,0.75,1.5] 50 | curmin[2] = coordmin[2] 51 | curmax[2] = coordmax[2] 52 | curchoice = np.sum((point_set[:, 0:3]>=(curmin-0.2))*(point_set[:, 0:3]<=(curmax+0.2)),axis=1)==3 53 | cur_point_set = point_set[curchoice,0:3] 54 | cur_point_full = point_set[curchoice,:] 55 | cur_semantic_seg = semantic_seg[curchoice] 56 | if len(cur_semantic_seg)==0: 57 | continue 58 | mask = np.sum((cur_point_set>=(curmin-0.01))*(cur_point_set<=(curmax+0.01)),axis=1)==3 59 | vidx = np.ceil((cur_point_set[mask,:]-curmin)/(curmax-curmin)*[31.0,31.0,62.0]) 60 | vidx = np.unique(vidx[:,0]*31.0*62.0+vidx[:,1]*62.0+vidx[:,2]) 61 | isvalid = np.sum(cur_semantic_seg>0)/len(cur_semantic_seg)>=0.7 and len(vidx)/31.0/31.0/62.0>=0.02 62 | if isvalid: 63 | break 64 | choice = np.random.choice(len(cur_semantic_seg), self.npoints, replace=True) 65 | point_set = cur_point_full[choice,:] 66 | semantic_seg = cur_semantic_seg[choice] 67 | mask = mask[choice] 68 | sample_weight = self.labelweights[semantic_seg] 69 | sample_weight *= mask 70 | return point_set, semantic_seg, sample_weight 71 | 72 | def __len__(self): 73 | return len(self.scene_points_list) 74 | 75 | class ScannetDatasetWholeScene(): 76 | def __init__(self, root, block_points=8192, split='val', with_rgb = False): 77 | self.npoints = block_points 78 | self.root = root 79 | self.with_rgb = with_rgb 80 | self.split = split 81 | self.data_filename = os.path.join(self.root, 'scannet_%s_rgb21c_pointid.pickle'%(split)) 82 | with open(self.data_filename,'rb') as fp: 83 | self.scene_points_list = pickle.load(fp) 84 | self.semantic_labels_list = pickle.load(fp) 85 | self.scene_points_id = pickle.load(fp) 86 | self.scene_points_num = pickle.load(fp) 87 | if split=='train': 88 | labelweights = np.zeros(21) 89 | for seg in self.semantic_labels_list: 90 | tmp,_ = np.histogram(seg,range(22)) 91 | labelweights += tmp 92 | labelweights = 
labelweights.astype(np.float32)
93 |             labelweights = labelweights/np.sum(labelweights)
94 |             self.labelweights = 1/np.log(1.2+labelweights)
95 |         elif split=='val':
96 |             self.labelweights = np.ones(21)
97 | 
98 |     def __getitem__(self, index):
99 |         if self.with_rgb:
100 |             point_set_ini = self.scene_points_list[index]
101 |         else:
102 |             point_set_ini = self.scene_points_list[index][:, 0:3]
103 |         semantic_seg_ini = self.semantic_labels_list[index].astype(np.int32)
104 |         coordmax = np.max(point_set_ini[:, 0:3],axis=0)
105 |         coordmin = np.min(point_set_ini[:, 0:3],axis=0)
106 |         nsubvolume_x = np.ceil((coordmax[0]-coordmin[0])/1.5).astype(np.int32)
107 |         nsubvolume_y = np.ceil((coordmax[1]-coordmin[1])/1.5).astype(np.int32)
108 |         point_sets = list()
109 |         semantic_segs = list()
110 |         sample_weights = list()
111 |         for i in range(nsubvolume_x):
112 |             for j in range(nsubvolume_y):
113 |                 curmin = coordmin+[i*1.5,j*1.5,0]
114 |                 curmax = coordmin+[(i+1)*1.5,(j+1)*1.5,coordmax[2]-coordmin[2]]
115 |                 curchoice = np.sum((point_set_ini[:, 0:3]>=(curmin-0.2))*(point_set_ini[:, 0:3]<=(curmax+0.2)),axis=1)==3
116 |                 cur_point_set = point_set_ini[curchoice,0:3]
117 |                 cur_point_full = point_set_ini[curchoice,:]
118 |                 cur_semantic_seg = semantic_seg_ini[curchoice]
119 |                 if len(cur_semantic_seg)==0:
120 |                     continue
121 |                 mask = np.sum((cur_point_set>=(curmin-0.001))*(cur_point_set<=(curmax+0.001)),axis=1)==3
122 |                 choice = np.random.choice(len(cur_semantic_seg), self.npoints, replace=True)
123 |                 point_set = cur_point_full[choice,:] # Nx3/6
124 |                 semantic_seg = cur_semantic_seg[choice] # N
125 |                 mask = mask[choice]
126 |                 if sum(mask)/float(len(mask))<0.01:
127 |                     continue
128 |                 sample_weight = self.labelweights[semantic_seg]
129 |                 sample_weight *= mask # N
130 |                 point_sets.append(np.expand_dims(point_set,0)) # 1xNx3
131 |                 semantic_segs.append(np.expand_dims(semantic_seg,0)) # 1xN
132 |                 sample_weights.append(np.expand_dims(sample_weight,0)) # 1xN
133 |         point_sets = np.concatenate(tuple(point_sets),axis=0)
134 |         semantic_segs = np.concatenate(tuple(semantic_segs),axis=0)
135 |         sample_weights = np.concatenate(tuple(sample_weights),axis=0)
136 |         return point_sets, semantic_segs, sample_weights
137 | 
138 |     def __len__(self):
139 |         return len(self.scene_points_list)
140 | 
141 | 
142 | if __name__=='__main__':
143 |     import pdb
144 |     pdb.set_trace()
145 |     d = ScannetDatasetWholeScene(root = './', split='val', block_points=8192)
146 |     labelweights_vox = np.zeros(21)
147 |     for ii in range(len(d)):
148 |         print(ii)
149 |         #ps,seg,smpw = d[ii]
150 |         ps,seg,smpw = d[ii]
151 |         for b in range(ps.shape[0]):
152 |             _, uvlabel, _ = pc_util.point_cloud_label_to_surface_voxel_label_fast(ps[b,smpw[b,:]>0,:], seg[b,smpw[b,:]>0], res=0.02)
153 |             tmp,_ = np.histogram(uvlabel,range(22))
154 |             labelweights_vox += tmp
155 |     print(labelweights_vox[1:].astype(np.float32)/np.sum(labelweights_vox[1:].astype(np.float32)))
156 |     exit()
157 | 
158 | 
159 | 
--------------------------------------------------------------------------------
/scannet/scannet_dataset_sw_rgb.py:
--------------------------------------------------------------------------------
 1 | """ ScanNet Class From Charles R. Qi, Hao Su.
 2 | Modified to support point-wise evaluation in ScanNet v2. 
3 | Author: Wenxuan Wu 4 | Date: July 2018 5 | """ 6 | 7 | 8 | import pickle 9 | import os 10 | import sys 11 | import numpy as np 12 | 13 | class ScannetDatasetWholeScene_evaluation(): 14 | #prepare to give prediction on each points 15 | def __init__(self, root, split='test', num_class = 21, block_points = 8192, with_rgb = True): 16 | self.root = root 17 | self.split = split 18 | self.with_rgb = with_rgb 19 | self.block_points = block_points 20 | self.point_num = [] 21 | self.data_filename = os.path.join(self.root, 'scannet_%s_rgb21c_pointid.pickle'%(split)) 22 | with open(self.data_filename,'rb') as fp: 23 | self.scene_points_list = pickle.load(fp) 24 | self.semantic_labels_list = pickle.load(fp) 25 | self.scene_points_id = pickle.load(fp) 26 | self.scene_points_num = pickle.load(fp) 27 | if split=='train': 28 | labelweights = np.zeros(num_class) 29 | for seg in self.semantic_labels_list: 30 | self.point_num.append(seg.shape[0]) 31 | tmp,_ = np.histogram(seg,range(num_class+1)) 32 | labelweights += tmp 33 | labelweights = labelweights.astype(np.float32) 34 | labelweights = labelweights/np.sum(labelweights) 35 | #self.labelweights = 1/np.log(1.2+labelweights) 36 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1/3.0) 37 | else: 38 | self.labelweights = np.ones(num_class) 39 | for seg in self.semantic_labels_list: 40 | self.point_num.append(seg.shape[0]) 41 | 42 | def chunks(self, l, n): 43 | """Yield successive n-sized chunks from l.""" 44 | for i in range(0, len(l), n): 45 | yield l[i:i + n] 46 | 47 | def split_data(self, data, idx): 48 | new_data = [] 49 | for i in range(len(idx)): 50 | new_data += [np.expand_dims(data[idx[i]], axis = 0)] 51 | return new_data 52 | 53 | def nearest_dist(self, block_center, block_center_list): 54 | num_blocks = len(block_center_list) 55 | dist = np.zeros(num_blocks) 56 | for i in range(num_blocks): 57 | dist[i] = np.linalg.norm(block_center_list[i] - block_center, ord = 2) #i->j 58 | return np.argsort(dist)[0] 59 | 60 | def __getitem__(self, index): 61 | delta = 0.5 62 | if self.with_rgb: 63 | point_set_ini = self.scene_points_list[index] 64 | else: 65 | point_set_ini = self.scene_points_list[index][:, 0:3] 66 | semantic_seg_ini = self.semantic_labels_list[index].astype(np.int32) 67 | coordmax = np.max(point_set_ini[:, 0:3],axis=0) 68 | coordmin = np.min(point_set_ini[:, 0:3],axis=0) 69 | nsubvolume_x = np.ceil((coordmax[0]-coordmin[0])/delta).astype(np.int32) 70 | nsubvolume_y = np.ceil((coordmax[1]-coordmin[1])/delta).astype(np.int32) 71 | point_sets = [] 72 | semantic_segs = [] 73 | sample_weights = [] 74 | point_idxs = [] 75 | block_center = [] 76 | for i in range(nsubvolume_x): 77 | for j in range(nsubvolume_y): 78 | curmin = coordmin+[i*delta,j*delta,0] 79 | curmax = curmin+[1.5,1.5,coordmax[2]-coordmin[2]] 80 | curchoice = np.sum((point_set_ini[:,0:3]>=(curmin-0.2))*(point_set_ini[:,0:3]<=(curmax+0.2)),axis=1)==3 81 | curchoice_idx = np.where(curchoice)[0] 82 | cur_point_set = point_set_ini[curchoice,:] 83 | cur_semantic_seg = semantic_seg_ini[curchoice] 84 | if len(cur_semantic_seg)==0: 85 | continue 86 | mask = np.sum((cur_point_set[:,0:3]>=(curmin-0.001))*(cur_point_set[:,0:3]<=(curmax+0.001)),axis=1)==3 87 | sample_weight = self.labelweights[cur_semantic_seg] 88 | sample_weight *= mask # N 89 | point_sets.append(cur_point_set) # 1xNx3/6 90 | semantic_segs.append(cur_semantic_seg) # 1xN 91 | sample_weights.append(sample_weight) # 1xN 92 | point_idxs.append(curchoice_idx) #1xN 93 | block_center.append((curmin[0:2] + 
curmax[0:2]) / 2.0) 94 | 95 | # merge small blocks 96 | num_blocks = len(point_sets) 97 | block_idx = 0 98 | while block_idx < num_blocks: 99 | if point_sets[block_idx].shape[0] > 4096: 100 | block_idx += 1 101 | continue 102 | 103 | small_block_data = point_sets[block_idx].copy() 104 | small_block_seg = semantic_segs[block_idx].copy() 105 | small_block_smpw = sample_weights[block_idx].copy() 106 | small_block_idxs = point_idxs[block_idx].copy() 107 | small_block_center = block_center[block_idx].copy() 108 | point_sets.pop(block_idx) 109 | semantic_segs.pop(block_idx) 110 | sample_weights.pop(block_idx) 111 | point_idxs.pop(block_idx) 112 | block_center.pop(block_idx) 113 | nearest_block_idx = self.nearest_dist(small_block_center, block_center) 114 | point_sets[nearest_block_idx] = np.concatenate((point_sets[nearest_block_idx], small_block_data), axis = 0) 115 | semantic_segs[nearest_block_idx] = np.concatenate((semantic_segs[nearest_block_idx], small_block_seg), axis = 0) 116 | sample_weights[nearest_block_idx] = np.concatenate((sample_weights[nearest_block_idx], small_block_smpw), axis = 0) 117 | point_idxs[nearest_block_idx] = np.concatenate((point_idxs[nearest_block_idx], small_block_idxs), axis = 0) 118 | num_blocks = len(point_sets) 119 | 120 | #divide large blocks 121 | num_blocks = len(point_sets) 122 | div_blocks = [] 123 | div_blocks_seg = [] 124 | div_blocks_smpw = [] 125 | div_blocks_idxs = [] 126 | div_blocks_center = [] 127 | for block_idx in range(num_blocks): 128 | cur_num_pts = point_sets[block_idx].shape[0] 129 | 130 | point_idx_block = np.array([x for x in range(cur_num_pts)]) 131 | if point_idx_block.shape[0]%self.block_points != 0: 132 | makeup_num = self.block_points - point_idx_block.shape[0]%self.block_points 133 | np.random.shuffle(point_idx_block) 134 | point_idx_block = np.concatenate((point_idx_block,point_idx_block[0:makeup_num].copy())) 135 | 136 | np.random.shuffle(point_idx_block) 137 | 138 | sub_blocks = list(self.chunks(point_idx_block, self.block_points)) 139 | 140 | div_blocks += self.split_data(point_sets[block_idx], sub_blocks) 141 | div_blocks_seg += self.split_data(semantic_segs[block_idx], sub_blocks) 142 | div_blocks_smpw += self.split_data(sample_weights[block_idx], sub_blocks) 143 | div_blocks_idxs += self.split_data(point_idxs[block_idx], sub_blocks) 144 | div_blocks_center += [block_center[block_idx].copy() for i in range(len(sub_blocks))] 145 | div_blocks = np.concatenate(tuple(div_blocks),axis=0) 146 | div_blocks_seg = np.concatenate(tuple(div_blocks_seg),axis=0) 147 | div_blocks_smpw = np.concatenate(tuple(div_blocks_smpw),axis=0) 148 | div_blocks_idxs = np.concatenate(tuple(div_blocks_idxs),axis=0) 149 | return div_blocks, div_blocks_seg, div_blocks_smpw, div_blocks_idxs 150 | def __len__(self): 151 | return len(self.scene_points_list) 152 | 153 | if __name__=='__main__': 154 | import pdb 155 | pdb.set_trace() 156 | #d = ScannetDataset(root = '../data/scannet/scannet_v2', split='test', npoints=8192) 157 | d = ScannetDatasetWholeScene_evaluation(root = './data_v2') 158 | labelweights_vox = np.zeros(21) 159 | for ii in range(len(d)): 160 | print(ii) 161 | ps,seg,smpw, idxs = d[ii] 162 | print(labelweights_vox[1:].astype(np.float32)/np.sum(labelweights_vox[1:].astype(np.float32))) 163 | exit() 164 | 165 | -------------------------------------------------------------------------------- /scannet/scannetv2_seg_dataset_rgb21c_pointid.py: -------------------------------------------------------------------------------- 1 | """ 2 | ScanNet v2 data 
preprocessing.
 3 | Extract point cloud data from .ply files to generate .pickle files for training and testing.
 4 | Author: Wenxuan Wu
 5 | Date: July 2018
 6 | """
 7 | 
 8 | import os
 9 | import sys
10 | import numpy as np
11 | import util
12 | import h5py
13 | import pickle
14 | from plyfile import PlyData, PlyElement
15 | 
16 | def remove_unano(scene_data, scene_label, scene_data_id):
17 |     keep_idx = np.where((scene_label > 0) & (scene_label < 41)) # 0: unannotated
18 |     scene_data_clean = scene_data[keep_idx]
19 |     scene_label_clean = scene_label[keep_idx]
20 |     scene_data_id_clean = scene_data_id[keep_idx]
21 |     return scene_data_clean, scene_label_clean, scene_data_id_clean
22 | 
23 | test_class = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
24 | def gen_label_map():
25 |     label_map = np.zeros(41)
26 |     for i in range(41):
27 |         if i in test_class:
28 |             label_map[i] = test_class.index(i)
29 |         else:
30 |             label_map[i] = 0
31 |     print(label_map)
32 |     return label_map
33 | 
34 | def gen_pickle(split = "val", root = "DataSet/Scannet_v2"):
35 |     if split == 'test':
36 |         root = root + "/scans_test"
37 |     else:
38 |         root = root + "/scans"
39 |     file_list = "scannetv2_%s.txt"%(split)
40 |     with open(file_list) as fl:
41 |         scene_id = fl.read().splitlines()
42 | 
43 |     scene_data = []
44 |     scene_data_labels = []
45 |     scene_data_id = []
46 |     scene_data_num = []
47 |     label_map = gen_label_map()
48 |     for i in range(len(scene_id)): #len(scene_id)
49 |         print('process...', i)
50 |         scene_namergb = os.path.join(root, scene_id[i], scene_id[i]+'_vh_clean_2.ply')
51 |         scene_xyzlabelrgb = PlyData.read(scene_namergb)
52 |         scene_vertex_rgb = scene_xyzlabelrgb['vertex']
53 |         scene_data_tmp = np.stack((scene_vertex_rgb['x'], scene_vertex_rgb['y'],
54 |                                    scene_vertex_rgb['z'], scene_vertex_rgb['red'],
55 |                                    scene_vertex_rgb['green'], scene_vertex_rgb['blue']), axis = -1).astype(np.float32)
56 |         scene_points_num = scene_data_tmp.shape[0]
57 |         scene_point_id = np.array([c for c in range(scene_points_num)])
58 |         if split != 'test':
59 |             scene_name = os.path.join(root, scene_id[i], scene_id[i]+'_vh_clean_2.labels.ply')
60 |             scene_xyzlabel = PlyData.read(scene_name)
61 |             scene_vertex = scene_xyzlabel['vertex']
62 |             scene_data_label_tmp = scene_vertex['label']
63 |             scene_data_tmp, scene_data_label_tmp, scene_point_id_tmp = remove_unano(scene_data_tmp, scene_data_label_tmp, scene_point_id)
64 |         else:
65 |             scene_data_label_tmp = np.zeros((scene_data_tmp.shape[0])).astype(np.int32)
66 |             scene_point_id_tmp = scene_point_id
67 |         scene_data_label_tmp = label_map[scene_data_label_tmp]
68 |         scene_data.append(scene_data_tmp)
69 |         scene_data_labels.append(scene_data_label_tmp)
70 |         scene_data_id.append(scene_point_id_tmp)
71 |         scene_data_num.append(scene_points_num)
72 | 
73 |     pickle_out = open("scannet_%s_rgb21c_pointid.pickle"%(split),"wb")
74 |     pickle.dump(scene_data, pickle_out, protocol=0)
75 |     pickle.dump(scene_data_labels, pickle_out, protocol=0)
76 |     pickle.dump(scene_data_id, pickle_out, protocol=0)
77 |     pickle.dump(scene_data_num, pickle_out, protocol=0)
78 |     pickle_out.close()
79 | 
80 | if __name__ =='__main__':
81 | 
82 |     root = "/media/wenxuan/Large/DataSet/Scannet_v2" # modify this path to your ScanNet v2 dataset path
83 |     gen_pickle(split = 'train', root = root)
84 |     gen_pickle(split = 'val', root = root)
85 |     gen_pickle(split = 'test', root = root)
86 | 
87 |     print('Done!!!')
88 | 
89 | 
90 | 
91 | 
92 | 
--------------------------------------------------------------------------------
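A short worked example (an editor's addition) of what gen_label_map() produces: NYU40 ids present in test_class keep their position in that list as the new 21-class id, everything else collapses to 0:

    from scannetv2_seg_dataset_rgb21c_pointid import gen_label_map

    label_map = gen_label_map()
    assert label_map[4] == 4    # nyu40 id 4 (bed) is at index 4 of test_class
    assert label_map[14] == 13  # nyu40 id 14 (desk) compresses to train id 13
    assert label_map[13] == 0   # ids missing from test_class collapse to 0 (unannotated)

/scannet/scannetv2_test.txt: 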
-------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /scannet/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 
84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | 
scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /scannet/util.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import csv 3 | try: 4 | import numpy as np 5 | except: 6 | print("Failed to import numpy package.") 7 | sys.exit(-1) 8 | try: 9 | import imageio 10 | except: 11 | print("Please install the module 'imageio' for image processing, e.g.") 12 | print("pip install imageio") 13 | sys.exit(-1) 14 | 15 | # print an error message and quit 16 | def print_error(message, user_fault=False): 17 | sys.stderr.write('ERROR: ' + str(message) + '\n') 18 | if user_fault: 19 | sys.exit(2) 20 | sys.exit(-1) 21 | 22 | 23 | # if string s represents an int 24 | def represents_int(s): 25 | try: 26 | int(s) 27 | return True 28 | except ValueError: 29 | return False 30 | 31 | 32 | def read_label_mapping(filename, label_from='raw_category', label_to='nyu40id'): 33 | assert os.path.isfile(filename) 34 | mapping = dict() 35 | with open(filename) as csvfile: 36 | reader = csv.DictReader(csvfile, delimiter='\t') 37 | for row in reader: 38 | mapping[row[label_from]] = int(row[label_to]) 39 | # if ints convert 40 | if represents_int(mapping.keys()[0]): 41 | mapping = {int(k):v for k,v in mapping.items()} 42 | return mapping 43 | 44 | 45 | # input: scene_types.txt or scene_types_all.txt 46 | def read_scene_types_mapping(filename, remove_spaces=True): 47 | assert os.path.isfile(filename) 48 | mapping = dict() 49 | lines = open(filename).read().splitlines() 50 | lines = [line.split('\t') for line in lines] 51 | if remove_spaces: 52 | mapping = { x[1].strip():int(x[0]) for x in lines } 53 | else: 54 | mapping = { x[1]:int(x[0]) for x in lines } 55 | return mapping 56 | 57 | 58 | # color by label 59 | def visualize_label_image(filename, image): 60 | height = image.shape[0] 61 | width = image.shape[1] 62 | vis_image = np.zeros([height, width, 3], dtype=np.uint8) 63 | color_palette = create_color_palette() 64 | for idx, color in enumerate(color_palette): 65 | vis_image[image==idx] = color 66 | imageio.imwrite(filename, vis_image) 67 | 68 | 69 | # color by different instances (mod length of color palette) 70 | def visualize_instance_image(filename, image): 71 | height = image.shape[0] 72 | width = image.shape[1] 73 | vis_image = np.zeros([height, width, 3], dtype=np.uint8) 74 | color_palette = create_color_palette() 75 | instances = np.unique(image) 76 | for idx, inst in enumerate(instances): 77 | vis_image[image==inst] = color_palette[inst%len(color_palette)] 78 | imageio.imwrite(filename, vis_image) 79 | 80 | 81 | # color palette for nyu40 labels 82 | def create_color_palette(): 83 | return [ 84 | (0, 0, 0), 85 | (174, 199, 232), # wall 86 | (152, 
223, 138), # floor 87 | (31, 119, 180), # cabinet 88 | (255, 187, 120), # bed 89 | (188, 189, 34), # chair 90 | (140, 86, 75), # sofa 91 | (255, 152, 150), # table 92 | (214, 39, 40), # door 93 | (197, 176, 213), # window 94 | (148, 103, 189), # bookshelf 95 | (196, 156, 148), # picture 96 | (23, 190, 207), # counter 97 | (178, 76, 76), 98 | (247, 182, 210), # desk 99 | (66, 188, 102), 100 | (219, 219, 141), # curtain 101 | (140, 57, 197), 102 | (202, 185, 52), 103 | (51, 176, 203), 104 | (200, 54, 131), 105 | (92, 193, 61), 106 | (78, 71, 183), 107 | (172, 114, 82), 108 | (255, 127, 14), # refrigerator 109 | (91, 163, 138), 110 | (153, 98, 156), 111 | (140, 153, 101), 112 | (158, 218, 229), # shower curtain 113 | (100, 125, 154), 114 | (178, 127, 135), 115 | (120, 185, 128), 116 | (146, 111, 194), 117 | (44, 160, 44), # toilet 118 | (112, 128, 144), # sink 119 | (96, 207, 209), 120 | (227, 119, 194), # bathtub 121 | (213, 92, 176), 122 | (94, 106, 211), 123 | (82, 84, 163), # otherfurn 124 | (100, 85, 144) 125 | ] 126 | -------------------------------------------------------------------------------- /scannet/visualize/util.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import csv 3 | try: 4 | import numpy as np 5 | except: 6 | print("Failed to import numpy package.") 7 | sys.exit(-1) 8 | try: 9 | import imageio 10 | except: 11 | print("Please install the module 'imageio' for image processing, e.g.") 12 | print("pip install imageio") 13 | sys.exit(-1) 14 | 15 | # print an error message and quit 16 | def print_error(message, user_fault=False): 17 | sys.stderr.write('ERROR: ' + str(message) + '\n') 18 | if user_fault: 19 | sys.exit(2) 20 | sys.exit(-1) 21 | 22 | 23 | # if string s represents an int 24 | def represents_int(s): 25 | try: 26 | int(s) 27 | return True 28 | except ValueError: 29 | return False 30 | 31 | 32 | def read_label_mapping(filename, label_from='raw_category', label_to='nyu40id'): 33 | assert os.path.isfile(filename) 34 | mapping = dict() 35 | with open(filename) as csvfile: 36 | reader = csv.DictReader(csvfile, delimiter='\t') 37 | for row in reader: 38 | mapping[row[label_from]] = int(row[label_to]) 39 | # if ints convert 40 | if represents_int(mapping.keys()[0]): 41 | mapping = {int(k):v for k,v in mapping.items()} 42 | return mapping 43 | 44 | 45 | # input: scene_types.txt or scene_types_all.txt 46 | def read_scene_types_mapping(filename, remove_spaces=True): 47 | assert os.path.isfile(filename) 48 | mapping = dict() 49 | lines = open(filename).read().splitlines() 50 | lines = [line.split('\t') for line in lines] 51 | if remove_spaces: 52 | mapping = { x[1].strip():int(x[0]) for x in lines } 53 | else: 54 | mapping = { x[1]:int(x[0]) for x in lines } 55 | return mapping 56 | 57 | 58 | # color by label 59 | def visualize_label_image(filename, image): 60 | height = image.shape[0] 61 | width = image.shape[1] 62 | vis_image = np.zeros([height, width, 3], dtype=np.uint8) 63 | color_palette = create_color_palette() 64 | for idx, color in enumerate(color_palette): 65 | vis_image[image==idx] = color 66 | imageio.imwrite(filename, vis_image) 67 | 68 | 69 | # color by different instances (mod length of color palette) 70 | def visualize_instance_image(filename, image): 71 | height = image.shape[0] 72 | width = image.shape[1] 73 | vis_image = np.zeros([height, width, 3], dtype=np.uint8) 74 | color_palette = create_color_palette() 75 | instances = np.unique(image) 76 | for idx, inst in enumerate(instances): 77 | 
vis_image[image==inst] = color_palette[inst%len(color_palette)] 78 | imageio.imwrite(filename, vis_image) 79 | 80 | 81 | # color palette for nyu40 labels 82 | def create_color_palette(): 83 | return [ 84 | (0, 0, 0), 85 | (174, 199, 232), # wall 86 | (152, 223, 138), # floor 87 | (31, 119, 180), # cabinet 88 | (255, 187, 120), # bed 89 | (188, 189, 34), # chair 90 | (140, 86, 75), # sofa 91 | (255, 152, 150), # table 92 | (214, 39, 40), # door 93 | (197, 176, 213), # window 94 | (148, 103, 189), # bookshelf 95 | (196, 156, 148), # picture 96 | (23, 190, 207), # counter 97 | (178, 76, 76), 98 | (247, 182, 210), # desk 99 | (66, 188, 102), 100 | (219, 219, 141), # curtain 101 | (140, 57, 197), 102 | (202, 185, 52), 103 | (51, 176, 203), 104 | (200, 54, 131), 105 | (92, 193, 61), 106 | (78, 71, 183), 107 | (172, 114, 82), 108 | (255, 127, 14), # refrigerator 109 | (91, 163, 138), 110 | (153, 98, 156), 111 | (140, 153, 101), 112 | (158, 218, 229), # shower curtain 113 | (100, 125, 154), 114 | (178, 127, 135), 115 | (120, 185, 128), 116 | (146, 111, 194), 117 | (44, 160, 44), # toilet 118 | (112, 128, 144), # sink 119 | (96, 207, 209), 120 | (227, 119, 194), # bathtub 121 | (213, 92, 176), 122 | (94, 106, 211), 123 | (82, 84, 163), # otherfurn 124 | (100, 85, 144) 125 | ] 126 | -------------------------------------------------------------------------------- /scannet/visualize/util_3d.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | 4 | try: 5 | import numpy as np 6 | except: 7 | print("Failed to import numpy package.") 8 | sys.exit(-1) 9 | 10 | try: 11 | from plyfile import PlyData, PlyElement 12 | except: 13 | print("Please install the module 'plyfile' for PLY i/o, e.g.") 14 | print("pip install plyfile") 15 | sys.exit(-1) 16 | 17 | import util 18 | 19 | 20 | # matrix: 4x4 np array 21 | # points Nx3 np array 22 | def transform_points(matrix, points): 23 | assert len(points.shape) == 2 and points.shape[1] == 3 24 | num_points = points.shape[0] 25 | p = np.concatenate([points, np.ones((num_points, 1))], axis=1) 26 | p = np.matmul(matrix, np.transpose(p)) 27 | p = np.transpose(p) 28 | p[:,:3] /= p[:,3,None] 29 | return p[:,:3] 30 | 31 | 32 | def export_ids(filename, ids): 33 | with open(filename, 'w') as f: 34 | for id in ids: 35 | f.write('%d\n' % id) 36 | 37 | 38 | def load_ids(filename): 39 | ids = open(filename).read().splitlines() 40 | ids = np.array(ids, dtype=np.int64) 41 | return ids 42 | 43 | 44 | def read_mesh_vertices(filename): 45 | assert os.path.isfile(filename) 46 | with open(filename, 'rb') as f: 47 | plydata = PlyData.read(f) 48 | num_verts = plydata['vertex'].count 49 | vertices = np.zeros(shape=[num_verts, 3], dtype=np.float32) 50 | vertices[:,0] = plydata['vertex'].data['x'] 51 | vertices[:,1] = plydata['vertex'].data['y'] 52 | vertices[:,2] = plydata['vertex'].data['z'] 53 | return vertices 54 | 55 | 56 | # export 3d instance labels for instance evaluation 57 | def export_instance_ids_for_eval(filename, label_ids, instance_ids): 58 | assert label_ids.shape[0] == instance_ids.shape[0] 59 | output_mask_path_relative = 'pred_mask' 60 | name = os.path.splitext(os.path.basename(filename))[0] 61 | output_mask_path = os.path.join(os.path.dirname(filename), output_mask_path_relative) 62 | if not os.path.isdir(output_mask_path): 63 | os.mkdir(output_mask_path) 64 | insts = np.unique(instance_ids) 65 | zero_mask = np.zeros(shape=(instance_ids.shape[0]), dtype=np.int32) 66 | with open(filename, 'w') as f: 67 | 
for idx, inst_id in enumerate(insts): 68 | if inst_id == 0: # 0 -> no instance for this vertex 69 | continue 70 | output_mask_file = os.path.join(output_mask_path_relative, name + '_' + str(idx) + '.txt') 71 | loc = np.where(instance_ids == inst_id) 72 | label_id = label_ids[loc[0][0]] 73 | f.write('%s %d %f\n' % (output_mask_file, label_id, 1.0)) 74 | # write mask 75 | mask = np.copy(zero_mask) 76 | mask[loc[0]] = 1 77 | export_ids(output_mask_file, mask) 78 | 79 | 80 | # ------------ Instance Utils ------------ # 81 | 82 | class Instance(object): 83 | instance_id = 0 84 | label_id = 0 85 | vert_count = 0 86 | med_dist = -1 87 | dist_conf = 0.0 88 | 89 | def __init__(self, mesh_vert_instances, instance_id): 90 | if (instance_id == -1): 91 | return 92 | self.instance_id = int(instance_id) 93 | self.label_id = int(self.get_label_id(instance_id)) 94 | self.vert_count = int(self.get_instance_verts(mesh_vert_instances, instance_id)) 95 | 96 | def get_label_id(self, instance_id): 97 | return int(instance_id // 1000) 98 | 99 | def get_instance_verts(self, mesh_vert_instances, instance_id): 100 | return (mesh_vert_instances == instance_id).sum() 101 | 102 | def to_json(self): 103 | return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) 104 | 105 | def to_dict(self): 106 | dict = {} 107 | dict["instance_id"] = self.instance_id 108 | dict["label_id"] = self.label_id 109 | dict["vert_count"] = self.vert_count 110 | dict["med_dist"] = self.med_dist 111 | dict["dist_conf"] = self.dist_conf 112 | return dict 113 | 114 | def from_json(self, data): 115 | self.instance_id = int(data["instance_id"]) 116 | self.label_id = int(data["label_id"]) 117 | self.vert_count = int(data["vert_count"]) 118 | if ("med_dist" in data): 119 | self.med_dist = float(data["med_dist"]) 120 | self.dist_conf = float(data["dist_conf"]) 121 | 122 | def __str__(self): 123 | return "("+str(self.instance_id)+")" 124 | 125 | def read_instance_prediction_file(filename, pred_path): 126 | lines = open(filename).read().splitlines() 127 | instance_info = {} 128 | abs_pred_path = os.path.abspath(pred_path) 129 | for line in lines: 130 | parts = line.split(' ') 131 | if len(parts) != 3: 132 | util.print_error('invalid instance prediction file. Expected (per line): [rel path prediction] [label id prediction] [confidence prediction]') 133 | if os.path.isabs(parts[0]): 134 | util.print_error('invalid instance prediction file. 
First entry in line must be a relative path')
135 |         mask_file = os.path.join(os.path.dirname(filename), parts[0])
136 |         mask_file = os.path.abspath(mask_file)
137 |         # check that mask_file lives inside prediction path
138 |         if os.path.commonprefix([mask_file, abs_pred_path]) != abs_pred_path:
139 |             util.print_error('predicted mask {} in prediction text file {} points outside of prediction path.'.format(mask_file,filename))
140 | 
141 |         info = {}
142 |         info["label_id"] = int(float(parts[1]))
143 |         info["conf"] = float(parts[2])
144 |         instance_info[mask_file] = info
145 |     return instance_info
146 | 
147 | 
148 | def get_instances(ids, class_ids, class_labels, id2label):
149 |     instances = {}
150 |     for label in class_labels:
151 |         instances[label] = []
152 |     instance_ids = np.unique(ids)
153 |     for id in instance_ids:
154 |         if id == 0:
155 |             continue
156 |         inst = Instance(ids, id)
157 |         if inst.label_id in class_ids:
158 |             instances[id2label[inst.label_id]].append(inst.to_dict())
159 |     return instances
160 | 
161 | 
162 | 
--------------------------------------------------------------------------------
/scannet/visualize/visualize_labels_on_mesh.py:
--------------------------------------------------------------------------------
 1 | # Example script to visualize labels in the evaluation format on the corresponding mesh.
 2 | # Inputs:
 3 | #   - predicted labels as a .txt file with one line per vertex
 4 | #   - the corresponding *_vh_clean_2.ply mesh
 5 | # Outputs a .ply with vertex colors, a different color per value in the predicted .txt file
 6 | #
 7 | # example usage: visualize_labels_on_mesh.py --pred_file [path to predicted labels file] --mesh_file [path to the *_vh_clean_2.ply mesh] --output_file [output file]
 8 | 
 9 | # python imports
10 | import math
11 | import os, sys, argparse
12 | import inspect
13 | import json
14 | 
15 | try:
16 |     import numpy as np
17 | except:
18 |     print("Failed to import numpy package.")
19 |     sys.exit(-1)
20 | try:
21 |     from plyfile import PlyData, PlyElement
22 | except:
23 |     print("Please install the module 'plyfile' for PLY i/o, e.g.")
24 |     print("pip install plyfile")
25 |     sys.exit(-1)
26 | 
27 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
28 | parentdir = os.path.dirname(currentdir)
29 | sys.path.insert(0,parentdir)
30 | import util
31 | import util_3d
32 | 
33 | def visualize(pred_file, mesh_file, output_file):
34 |     if not output_file.endswith('.ply'):
35 |         util.print_error('output file must be a .ply file')
36 |     colors = util.create_color_palette()
37 |     num_colors = len(colors)
38 |     ids = util_3d.load_ids(pred_file)
39 |     with open(mesh_file, 'rb') as f:
40 |         plydata = PlyData.read(f)
41 |         num_verts = plydata['vertex'].count
42 |         if num_verts != len(ids):
43 |             util.print_error('#predicted labels = ' + str(len(ids)) + ' vs #mesh vertices = ' + str(num_verts))
44 |     # *_vh_clean_2.ply has colors already
45 |     for i in range(num_verts):
46 |         if ids[i] >= num_colors:
47 |             util.print_error('found predicted label ' + str(ids[i]) + ' not in nyu40 label set')
48 |         color = colors[ids[i]]
49 |         plydata['vertex']['red'][i] = color[0]
50 |         plydata['vertex']['green'][i] = color[1]
51 |         plydata['vertex']['blue'][i] = color[2]
52 |     plydata.write(output_file)
53 | 
--------------------------------------------------------------------------------
/tf_ops/3d_interpolation/interpolate.cpp:
--------------------------------------------------------------------------------
 1 | #include <cstdio>
 2 | #include <ctime>
 3 | #include <cstring> // memset
 4 | #include <cstdlib> // rand, RAND_MAX
 5 | #include <cmath> // sqrtf
 6 | #include <vector>
 7 | #include <algorithm>
 8 | using namespace std;
 9 | float randomf(){
10 |     return (rand()+0.5)/(RAND_MAX+1.0);
11 | }
12 | static double get_time(){
13 |     timespec tp;
14 |     clock_gettime(CLOCK_MONOTONIC,&tp);
15 |     return tp.tv_sec+tp.tv_nsec*1e-9;
16 | }
17 | 
18 | // Find three nearest neighbors with square distance
19 | // input: xyz1 (b,n,3), xyz2(b,m,3)
20 | // output: dist (b,n,3), idx (b,n,3)
21 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) {
22 |     for (int i=0;i<b;++i) {
23 |         for (int j=0;j<n;++j) {
24 |             float x1=xyz1[j*3+0];
25 |             float y1=xyz1[j*3+1];
26 |             float z1=xyz1[j*3+2];
27 |             double best1=1e40; double best2=1e40; double best3=1e40;
28 |             int besti1=0; int besti2=0; int besti3=0;
29 |             for (int k=0;k<m;++k) {
30 |                 float x2=xyz2[k*3+0];
31 |                 float y2=xyz2[k*3+1];
32 |                 float z2=xyz2[k*3+2];
33 |                 double d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
34 |                 if (d<best1) {
35 |                     best3=best2; besti3=besti2;
36 |                     best2=best1; besti2=besti1;
37 |                     best1=d; besti1=k;
38 |                 } else if (d<best2) {
39 |                     best3=best2; besti3=besti2;
40 |                     best2=d; besti2=k;
41 |                 } else if (d<best3) {
42 |                     best3=d; besti3=k;
43 |                 }
44 |             }
45 |             dist[j*3]=best1; idx[j*3]=besti1;
46 |             dist[j*3+1]=best2; idx[j*3+1]=besti2;
47 |             dist[j*3+2]=best3; idx[j*3+2]=besti3;
48 |         }
49 |         xyz1+=n*3; xyz2+=m*3; dist+=n*3; idx+=n*3;
50 |     }
51 | }
52 | 
53 | // simple CPU benchmark driver
54 | int main()
55 | {
56 |     int b=32,n=512,m=128;
57 |     float *xyz1=new float[b*n*3];
58 |     float *xyz2=new float[b*m*3];
59 |     float *dist=new float[b*n*3];
60 |     int *idx=new int[b*n*3];
61 |     for (int i=0;i<b*n*3;i++) xyz1[i]=randomf();
62 |     for (int i=0;i<b*m*3;i++) xyz2[i]=randomf();
63 |     double t0=get_time();
64 |     threenn_cpu(b,n,m,xyz1,xyz2,dist,idx);
65 |     printf("threenn_cpu time: %f\n",get_time()-t0);
66 |     return 0;
67 | }
--------------------------------------------------------------------------------
/tf_ops/3d_interpolation/tf_interpolate.cpp:
--------------------------------------------------------------------------------
 1 | #include <cstdio>
 2 | #include <ctime>
 3 | #include <cstring> // memset
 4 | #include <cstdlib> // rand, RAND_MAX
 5 | #include <cmath> // sqrtf
 6 | #include "tensorflow/core/framework/op.h"
 7 | #include "tensorflow/core/framework/op_kernel.h"
 8 | #include "tensorflow/core/framework/shape_inference.h"
 9 | #include "tensorflow/core/framework/common_shape_fns.h"
10 | using namespace tensorflow;
11 | 
12 | REGISTER_OP("ThreeNN")
13 |     .Input("xyz1: float32")
14 |     .Input("xyz2: float32")
15 |     .Output("dist: float32")
16 |     .Output("idx: int32")
17 |     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
18 |         c->set_output(0, c->input(0));
19 |         c->set_output(1, c->input(0));
20 |         return Status::OK();
21 |     });
22 | REGISTER_OP("ThreeInterpolate")
23 |     .Input("points: float32")
24 |     .Input("idx: int32")
25 |     .Input("weight: float32")
26 |     .Output("out: float32")
27 |     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
28 |         ::tensorflow::shape_inference::ShapeHandle dims1; // (b,m,c)
29 |         c->WithRank(c->input(0), 3, &dims1);
30 |         ::tensorflow::shape_inference::ShapeHandle dims2; // (b,n,3)
31 |         c->WithRank(c->input(1), 3, &dims2);
32 |         // (b,n,c)
33 |         ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)});
34 |         c->set_output(0, output);
35 |         return Status::OK();
36 |     });
37 | REGISTER_OP("ThreeInterpolateGrad")
38 |     .Input("points: float32")
39 |     .Input("idx: int32")
40 |     .Input("weight: float32")
41 |     .Input("grad_out: float32")
42 |     .Output("grad_points: float32")
43 |     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
44 |         c->set_output(0, c->input(0));
45 |         return Status::OK();
46 |     });
47 | 
48 | float randomf(){
49 |     return (rand()+0.5)/(RAND_MAX+1.0);
50 | }
51 | static double get_time(){
52 |     timespec tp;
53 |     clock_gettime(CLOCK_MONOTONIC,&tp);
54 |     return tp.tv_sec+tp.tv_nsec*1e-9;
55 | }
56 | 
57 | // Find three nearest neighbors with square distance
58 | // input: xyz1 (b,n,3), xyz2(b,m,3)
59 | // output: dist (b,n,3), idx (b,n,3)
60 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) {
61 |     for (int i=0;i<b;++i) {
62 |         for (int j=0;j<n;++j) {
63 |             float x1=xyz1[j*3+0];
64 |             float y1=xyz1[j*3+1];
65 |             float z1=xyz1[j*3+2];
66 |             double best1=1e40; double best2=1e40; double best3=1e40;
67 |             int besti1=0; int besti2=0; int besti3=0;
68 |             for (int k=0;k<m;++k) {
69 |                 float x2=xyz2[k*3+0];
70 |                 float y2=xyz2[k*3+1];
71 |                 float z2=xyz2[k*3+2];
72 |                 double d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
73 |                 if (d<best1) {
74 |                     best3=best2; besti3=besti2;
75 |                     best2=best1; besti2=besti1;
76 |                     best1=d; besti1=k;
77 |                 } else if (d<best2) {
78 |                     best3=best2; besti3=besti2;
79 |                     best2=d; besti2=k;
80 |                 } else if (d<best3) {
81 |                     best3=d; besti3=k;
82 |                 }
83 |             }
84 |             dist[j*3]=best1; idx[j*3]=besti1;
85 |             dist[j*3+1]=best2; idx[j*3+1]=besti2;
86 |             dist[j*3+2]=best3; idx[j*3+2]=besti3;
87 |         }
88 |         xyz1+=n*3; xyz2+=m*3; dist+=n*3; idx+=n*3;
89 |     }
90 | }
91 | 
92 | // input: points (b,m,c), idx (b,n,3), weight (b,n,3)
93 | // output: out (b,n,c)
94 | void threeinterpolate_cpu(int b, int m, int c, int n, const float *points, const int *idx, const float *weight, float *out) {
95 |     for (int i=0;i<b;++i) {
96 |         for (int j=0;j<n;++j) {
97 |             float w1=weight[j*3], w2=weight[j*3+1], w3=weight[j*3+2];
98 |             int i1=idx[j*3], i2=idx[j*3+1], i3=idx[j*3+2];
99 |             for (int l=0;l<c;++l) {
100 |                 out[j*c+l] = points[i1*c+l]*w1 + points[i2*c+l]*w2 + points[i3*c+l]*w3;
101 |             }
102 |         }
103 |         points+=m*c; idx+=n*3; weight+=n*3; out+=n*c;
104 |     }
105 | }
106 | 
107 | // input: grad_out (b,n,c), idx (b,n,3), weight (b,n,3)
108 | // output: grad_points (b,m,c)
109 | void threeinterpolate_grad_cpu(int b, int n, int c, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points) {
110 |     for (int i=0;i<b;++i) {
111 |         for (int j=0;j<n;++j) {
112 |             float w1=weight[j*3], w2=weight[j*3+1], w3=weight[j*3+2];
113 |             int i1=idx[j*3], i2=idx[j*3+1], i3=idx[j*3+2];
114 |             for (int l=0;l<c;++l) {
115 |                 grad_points[i1*c+l] += grad_out[j*c+l]*w1;
116 |                 grad_points[i2*c+l] += grad_out[j*c+l]*w2;
117 |                 grad_points[i3*c+l] += grad_out[j*c+l]*w3;
118 |             }
119 |         }
120 |         grad_out+=n*c; idx+=n*3; weight+=n*3; grad_points+=m*c;
121 |     }
122 | }
123 | 
124 | 
125 | class ThreeNNOp : public OpKernel {
126 |     public:
127 |         explicit ThreeNNOp(OpKernelConstruction * context) : OpKernel(context) {}
128 | 
129 |         void Compute(OpKernelContext * context) override {
130 |             const Tensor& xyz1_tensor = context->input(0);
131 |             OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,n,3) xyz1 shape."));
132 |             int b = xyz1_tensor.shape().dim_size(0);
133 |             int n = xyz1_tensor.shape().dim_size(1);
134 | 
135 |             const Tensor& xyz2_tensor = context->input(1);
136 |             OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeNN expects (b,m,3) xyz2 shape."));
137 |             int m = xyz2_tensor.shape().dim_size(1);
138 | 
139 |             Tensor *dist_tensor = nullptr;
140 |             OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,n,3}, &dist_tensor));
141 |             Tensor *idx_tensor = nullptr;
142 |             OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,n,3}, &idx_tensor));
143 | 
144 |             auto xyz1_flat = xyz1_tensor.flat<float>();
145 |             const float *xyz1 = 
&(xyz1_flat(0));
146 |             auto xyz2_flat = xyz2_tensor.flat<float>();
147 |             const float *xyz2 = &(xyz2_flat(0));
148 |             auto dist_flat = dist_tensor->flat<float>();
149 |             float *dist = &(dist_flat(0));
150 |             auto idx_flat = idx_tensor->flat<int>();
151 |             int *idx = &(idx_flat(0));
152 |             threenn_cpu(b,n,m,xyz1,xyz2,dist,idx);
153 |         }
154 | };
155 | REGISTER_KERNEL_BUILDER(Name("ThreeNN").Device(DEVICE_CPU), ThreeNNOp);
156 | 
157 | 
158 | 
159 | class ThreeInterpolateOp: public OpKernel{
160 |     public:
161 |         explicit ThreeInterpolateOp(OpKernelConstruction * context):OpKernel(context){}
162 | 
163 |         void Compute(OpKernelContext * context) override {
164 |             const Tensor& points_tensor=context->input(0);
165 |             OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolate expects (b,m,c) points shape"));
166 |             int b = points_tensor.shape().dim_size(0);
167 |             int m = points_tensor.shape().dim_size(1);
168 |             int c = points_tensor.shape().dim_size(2);
169 | 
170 |             const Tensor& idx_tensor=context->input(1);
171 |             OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b && idx_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) idx shape"));
172 |             int n = idx_tensor.shape().dim_size(1);
173 |             const Tensor& weight_tensor=context->input(2);
174 |             OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolate expects (b,n,3) weight shape"));
175 | 
176 |             Tensor * out_tensor = nullptr;
177 |             OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,n,c}, &out_tensor));
178 | 
179 |             auto points_flat = points_tensor.flat<float>();
180 |             const float *points = &(points_flat(0));
181 |             auto idx_flat = idx_tensor.flat<int>();
182 |             const int *idx = &(idx_flat(0));
183 |             auto weight_flat = weight_tensor.flat<float>();
184 |             const float *weight = &(weight_flat(0));
185 |             auto out_flat = out_tensor->flat<float>();
186 |             float *out = &(out_flat(0));
187 |             threeinterpolate_cpu(b,m,c,n,points,idx,weight,out);
188 |         }
189 | };
190 | REGISTER_KERNEL_BUILDER(Name("ThreeInterpolate").Device(DEVICE_CPU),ThreeInterpolateOp);
191 | 
192 | 
193 | class ThreeInterpolateGradOp: public OpKernel{
194 |     public:
195 |         explicit ThreeInterpolateGradOp(OpKernelConstruction * context):OpKernel(context){}
196 | 
197 |         void Compute(OpKernelContext * context) override {
198 |             const Tensor& points_tensor=context->input(0);
199 |             OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,m,c) points shape"));
200 |             int b = points_tensor.shape().dim_size(0);
201 |             int m = points_tensor.shape().dim_size(1);
202 |             int c = points_tensor.shape().dim_size(2);
203 | 
204 |             const Tensor& idx_tensor=context->input(1);
205 |             OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) idx shape"));
206 |             int n = idx_tensor.shape().dim_size(1);
207 |             const Tensor& weight_tensor=context->input(2);
208 |             OP_REQUIRES(context,weight_tensor.dims()==3 && weight_tensor.shape().dim_size(0)==b && weight_tensor.shape().dim_size(1)==n && weight_tensor.shape().dim_size(2)==3, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,3) weight shape"));
209 | 
210 |             const Tensor& grad_out_tensor=context->input(3);
211 |             OP_REQUIRES(context,grad_out_tensor.dims()==3 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==n && 
grad_out_tensor.shape().dim_size(2)==c, errors::InvalidArgument("ThreeInterpolateGrad expects (b,n,c) grad_out shape"));
212 | 
213 |             Tensor * grad_points_tensor = nullptr;
214 |             OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,c}, &grad_points_tensor));
215 | 
216 |             auto points_flat = points_tensor.flat<float>();
217 |             const float *points = &(points_flat(0));
218 |             auto idx_flat = idx_tensor.flat<int>();
219 |             const int *idx = &(idx_flat(0));
220 |             auto weight_flat = weight_tensor.flat<float>();
221 |             const float *weight = &(weight_flat(0));
222 |             auto grad_out_flat = grad_out_tensor.flat<float>();
223 |             const float *grad_out = &(grad_out_flat(0));
224 |             auto grad_points_flat = grad_points_tensor->flat<float>();
225 |             float *grad_points = &(grad_points_flat(0));
226 |             memset(grad_points, 0, sizeof(float)*b*m*c);
227 |             threeinterpolate_grad_cpu(b,n,c,m,grad_out,idx,weight,grad_points);
228 |         }
229 | };
230 | REGISTER_KERNEL_BUILDER(Name("ThreeInterpolateGrad").Device(DEVICE_CPU),ThreeInterpolateGradOp);
231 | 
232 | 
--------------------------------------------------------------------------------
/tf_ops/3d_interpolation/tf_interpolate.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | from tensorflow.python.framework import ops
 3 | import sys
 4 | import os
 5 | BASE_DIR = os.path.dirname(__file__)
 6 | sys.path.append(BASE_DIR)
 7 | interpolate_module=tf.load_op_library(os.path.join(BASE_DIR, 'tf_interpolate_so.so'))
 8 | def three_nn(xyz1, xyz2):
 9 |     '''
10 |     Input:
11 |         xyz1: (b,n,3) float32 array, unknown points
12 |         xyz2: (b,m,3) float32 array, known points
13 |     Output:
14 |         dist: (b,n,3) float32 array, distances to known points
15 |         idx: (b,n,3) int32 array, indices to known points
16 |     '''
17 |     return interpolate_module.three_nn(xyz1, xyz2)
18 | ops.NoGradient('ThreeNN')
19 | def three_interpolate(points, idx, weight):
20 |     '''
21 |     Input:
22 |         points: (b,m,c) float32 array, known points
23 |         idx: (b,n,3) int32 array, indices to known points
24 |         weight: (b,n,3) float32 array, weights on known points
25 |     Output:
26 |         out: (b,n,c) float32 array, interpolated point values
27 |     '''
28 |     return interpolate_module.three_interpolate(points, idx, weight)
29 | @tf.RegisterGradient('ThreeInterpolate')
30 | def _three_interpolate_grad(op, grad_out):
31 |     points = op.inputs[0]
32 |     idx = op.inputs[1]
33 |     weight = op.inputs[2]
34 |     return [interpolate_module.three_interpolate_grad(points, idx, weight, grad_out), None, None]
35 | 
36 | if __name__=='__main__':
37 |     import numpy as np
38 |     import time
39 |     np.random.seed(100)
40 |     pts = np.random.random((32,128,64)).astype('float32')
41 |     tmp1 = np.random.random((32,512,3)).astype('float32')
42 |     tmp2 = np.random.random((32,128,3)).astype('float32')
43 |     with tf.device('/cpu:0'):
44 |         points = tf.constant(pts)
45 |         xyz1 = tf.constant(tmp1)
46 |         xyz2 = tf.constant(tmp2)
47 |         dist, idx = three_nn(xyz1, xyz2)
48 |         weight = tf.ones_like(dist)/3.0
49 |         interpolated_points = three_interpolate(points, idx, weight)
50 |     with tf.Session('') as sess:
51 |         now = time.time()
52 |         for _ in range(100):
53 |             ret = sess.run(interpolated_points)
54 |         print(time.time() - now)
55 |         print(ret.shape, ret.dtype)
56 |         #print ret
57 | 
58 | 
59 | 
--------------------------------------------------------------------------------
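A minimal sketch (an editor's addition) of the usual feature-propagation pattern with these ops: instead of the uniform 1/3 weights in the self-test above, normalize inverse distances so the three neighbor weights sum to 1 (the same scheme used in visu_interpolation.py below). Shapes and values are illustrative:

    import numpy as np
    import tensorflow as tf
    from tf_interpolate import three_nn, three_interpolate

    xyz1 = tf.constant(np.random.random((4, 1024, 3)).astype('float32'))   # unknown points
    xyz2 = tf.constant(np.random.random((4, 256, 3)).astype('float32'))    # known points
    feats = tf.constant(np.random.random((4, 256, 64)).astype('float32'))  # features at known points

    dist, idx = three_nn(xyz1, xyz2)
    dist = tf.maximum(dist, 1e-10)                         # avoid division by zero
    norm = tf.reduce_sum(1.0/dist, axis=2, keep_dims=True)
    weight = (1.0/dist) / norm                             # (b,n,3), rows sum to 1
    upsampled = three_interpolate(feats, idx, weight)      # (4, 1024, 64)

/tf_ops/3d_interpolation/tf_interpolate_compile.sh:
--------------------------------------------------------------------------------
 1 | # TF1.2
 2 | #g++ -std=c++11 tf_interpolate.cpp -o 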
tf_interpolate_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -lcudart -L /usr/local/cuda-8.0/lib64/ -O2 -D_GLIBCXX_USE_CXX11_ABI=0 3 | 4 | # TF1.4 5 | CUDA_PATH=/usr/local/cuda-9.0 6 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 7 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 8 | g++ -std=c++11 tf_interpolate.cpp -o tf_interpolate_so.so -shared -fPIC -fPIC -I $TF_INC -I $CUDA_PATH/include -lcudart -L $CUDA_PATH/lib64/ -L$TF_LIB -I$TF_INC/external/nsync/public -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 9 | -------------------------------------------------------------------------------- /tf_ops/3d_interpolation/tf_interpolate_op_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tf_interpolate import three_nn, three_interpolate 4 | 5 | class GroupPointTest(tf.test.TestCase): 6 | def test(self): 7 | pass 8 | 9 | def test_grad(self): 10 | with self.test_session(): 11 | points = tf.constant(np.random.random((1,8,16)).astype('float32')) 12 | print(points) 13 | xyz1 = tf.constant(np.random.random((1,128,3)).astype('float32')) 14 | xyz2 = tf.constant(np.random.random((1,8,3)).astype('float32')) 15 | dist, idx = three_nn(xyz1, xyz2) 16 | weight = tf.ones_like(dist)/3.0 17 | interpolated_points = three_interpolate(points, idx, weight) 18 | print(interpolated_points) 19 | err = tf.test.compute_gradient_error(points, (1,8,16), interpolated_points, (1,128,16)) 20 | print(err) 21 | self.assertLess(err, 1e-4) 22 | 23 | if __name__=='__main__': 24 | tf.test.main() 25 | -------------------------------------------------------------------------------- /tf_ops/3d_interpolation/tf_interpolate_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanWusee/pointconv/f39dc3e101af2f52544181ee20c14f73279b48ae/tf_ops/3d_interpolation/tf_interpolate_so.so -------------------------------------------------------------------------------- /tf_ops/3d_interpolation/visu_interpolation.py: -------------------------------------------------------------------------------- 1 | ''' Visualize part segmentation ''' 2 | import os 3 | import sys 4 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | sys.path.append('/home/rqi/Projects/toolkits/visualization') 6 | from show3d_balls import showpoints 7 | import numpy as np 8 | from tf_interpolate import three_nn, three_interpolate 9 | import tensorflow as tf 10 | 11 | 12 | pts2 = np.array([[0,0,1],[1,0,0],[0,1,0],[1,1,0]]).astype('float32') 13 | xyz1 = np.random.random((100,3)).astype('float32') 14 | xyz2 = np.array([[0,0,0],[1,0,0],[0,1,0],[1,1,1]]).astype('float32') 15 | 16 | def fun(xyz1,xyz2,pts2): 17 | with tf.device('/cpu:0'): 18 | points = tf.constant(np.expand_dims(pts2,0)) 19 | xyz1 = tf.constant(np.expand_dims(xyz1,0)) 20 | xyz2 = tf.constant(np.expand_dims(xyz2,0)) 21 | dist, idx = three_nn(xyz1, xyz2) 22 | #weight = tf.ones_like(dist)/3.0 23 | dist = tf.maximum(dist, 1e-10) 24 | norm = tf.reduce_sum((1.0/dist),axis=2,keep_dims=True) 25 | norm = tf.tile(norm, [1,1,3]) 26 | print(norm) 27 | weight = (1.0/dist) / norm 28 | interpolated_points = three_interpolate(points, idx, weight) 29 | with tf.Session('') as sess: 30 | tmp,pts1,d,w = sess.run([xyz1, interpolated_points, dist, weight]) 31 | #print w 32 | pts1 = pts1.squeeze() 33 | return pts1 
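# Note on the weighting in `fun` above: each of the 100 query points is
# interpolated from its three nearest known points with normalized
# inverse-distance weights, w_i = (1/d_i) / sum_j (1/d_j), so the weights are
# non-negative and sum to one; the tf.maximum(dist, 1e-10) guard keeps the
# division finite when a query point coincides with a known point.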
34 | 35 | pts1 = fun(xyz1,xyz2,pts2) 36 | all_pts = np.zeros((104,3)) 37 | all_pts[0:100,:] = pts1 38 | all_pts[100:,:] = pts2 39 | all_xyz = np.zeros((104,3)) 40 | all_xyz[0:100,:]=xyz1 41 | all_xyz[100:,:]=xyz2 42 | showpoints(xyz2, pts2, ballradius=8) 43 | showpoints(xyz1, pts1, ballradius=8) 44 | showpoints(all_xyz, all_pts, ballradius=8) 45 | -------------------------------------------------------------------------------- /tf_ops/grouping/.gitignore: -------------------------------------------------------------------------------- 1 | a.out 2 | query_ball_point 3 | query_ball_point_block 4 | query_ball_point_cuda 5 | query_ball_point_grid 6 | tf_grouping_g.cu.o 7 | tf_grouping_so.so 8 | selection_sort 9 | selection_sort_cuda 10 | selection_sort_const_cuda 11 | -------------------------------------------------------------------------------- /tf_ops/grouping/test/compile.sh: -------------------------------------------------------------------------------- 1 | g++ query_ball_point.cpp -o query_ball_point 2 | nvcc query_ball_point.cu -o query_ball_point_cuda 3 | nvcc query_ball_point_block.cu -o query_ball_point_block 4 | nvcc query_ball_point_grid.cu -o query_ball_point_grid 5 | g++ -Wall selection_sort.cpp -o selection_sort 6 | nvcc selection_sort.cu -o selection_sort_cuda 7 | -------------------------------------------------------------------------------- /tf_ops/grouping/test/query_ball_point.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | void query_ball_point_cpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | for (int i=0;i>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 113 | cudaDeviceSynchronize(); 114 | printf("query_ball_point gpu time %f\n",get_time()-t0); 115 | 116 | t0=get_time(); 117 | group_point_gpu<<<1,1>>>(b,n,c,m,nsample,points,idx,out); 118 | cudaDeviceSynchronize(); 119 | printf("grou_point gpu time %f\n",get_time()-t0); 120 | 121 | t0=get_time(); 122 | group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 123 | cudaDeviceSynchronize(); 124 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 125 | 126 | cudaFree(xyz1); 127 | cudaFree(xyz2); 128 | cudaFree(points); 129 | cudaFree(idx); 130 | cudaFree(out); 131 | cudaFree(grad_out); 132 | cudaFree(grad_points); 133 | return 0; 134 | } 135 | -------------------------------------------------------------------------------- 
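The CPU reference above and the CUDA benchmark variants that follow (a single-thread launch, a one-thread-per-batch version in query_ball_point_block.cu, and a one-block-per-batch grid-stride version in query_ball_point_grid.cu) all implement the same QueryBallPoint contract documented in tf_grouping.py. A minimal NumPy sketch of that contract, assuming the reference kernels' padding behavior (repeat the first in-radius index when fewer than nsample neighbors are found); `query_ball_point_np` is an illustrative name, not a function in this repo:

import numpy as np

def query_ball_point_np(radius, nsample, xyz1, xyz2):
    # xyz1: (b,n,3) dataset points; xyz2: (b,m,3) query points
    b = xyz1.shape[0]
    m = xyz2.shape[1]
    idx = np.zeros((b, m, nsample), dtype=np.int32)
    pts_cnt = np.zeros((b, m), dtype=np.int32)
    for bi in range(b):
        for j in range(m):
            # Euclidean distance from query point j to every dataset point
            d = np.linalg.norm(xyz1[bi] - xyz2[bi, j], axis=1)
            hits = np.where(d < radius)[0]
            cnt = min(len(hits), nsample)
            pts_cnt[bi, j] = cnt
            if cnt > 0:
                idx[bi, j, :cnt] = hits[:cnt]
                idx[bi, j, cnt:] = hits[0]  # pad by repeating the first hit
    return idx, pts_cnt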
/tf_ops/grouping/test/query_ball_point_block.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | int index = threadIdx.x; 21 | xyz1 += n*3*index; 22 | xyz2 += m*3*index; 23 | idx += m*nsample*index; 24 | 25 | for (int j=0;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 113 | cudaDeviceSynchronize(); 114 | printf("query_ball_point gpu time %f\n",get_time()-t0); 115 | 116 | t0=get_time(); 117 | group_point_gpu<<<1,b>>>(b,n,c,m,nsample,points,idx,out); 118 | cudaDeviceSynchronize(); 119 | printf("grou_point gpu time %f\n",get_time()-t0); 120 | 121 | t0=get_time(); 122 | group_point_grad_gpu<<<1,b>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 123 | cudaDeviceSynchronize(); 124 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 125 | 126 | cudaFree(xyz1); 127 | cudaFree(xyz2); 128 | cudaFree(points); 129 | cudaFree(idx); 130 | cudaFree(out); 131 | cudaFree(grad_out); 132 | cudaFree(grad_points); 133 | return 0; 134 | } 135 | -------------------------------------------------------------------------------- /tf_ops/grouping/test/query_ball_point_grid.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | int batch_index = blockIdx.x; 21 | xyz1 += n*3*batch_index; 22 | xyz2 += m*3*batch_index; 23 | idx += m*nsample*batch_index; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | 28 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 123 | cudaDeviceSynchronize(); 124 | printf("query_ball_point gpu time %f\n",get_time()-t0); 125 | 126 | t0=get_time(); 127 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 128 | cudaDeviceSynchronize(); 129 | printf("grou_point gpu time %f\n",get_time()-t0); 130 | 131 | t0=get_time(); 132 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 133 | cudaDeviceSynchronize(); 134 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 135 | 136 | cudaFree(xyz1); 137 | cudaFree(xyz2); 138 | cudaFree(points); 139 | cudaFree(idx); 140 | cudaFree(out); 141 | cudaFree(grad_out); 142 | cudaFree(grad_points); 143 | return 0; 144 | } 145 | -------------------------------------------------------------------------------- /tf_ops/grouping/test/selection_sort.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 
4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,n), val (b,m,n) 20 | void selection_sort_cpu(int b, int n, int m, int k, const float *dist, int *idx, float *val) { 21 | float *p_dist; 22 | float tmp; 23 | int tmpi; 24 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,k), val (b,m,k) 20 | __global__ void selection_sort_gpu(int b, int n, int m, int k, float *dist, int *idx, float *val) { 21 | int batch_index = blockIdx.x; 22 | dist+=m*n*batch_index; 23 | idx+=m*k*batch_index; 24 | val+=m*k*batch_index; 25 | 26 | int index = threadIdx.x; 27 | int stride = blockDim.x; 28 | 29 | float *p_dist; 30 | for (int j=index;j>>(b,n,m,k,dist,idx,val); 68 | cudaDeviceSynchronize(); 69 | printf("selection sort cpu time %f\n",get_time()-t0); 70 | 71 | return 0; 72 | } 73 | -------------------------------------------------------------------------------- /tf_ops/grouping/test/selection_sort_const.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,n), dist_out (b,m,n) 20 | __global__ void selection_sort_gpu(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 21 | int batch_index = blockIdx.x; 22 | dist+=m*n*batch_index; 23 | outi+=m*n*batch_index; 24 | out+=m*n*batch_index; 25 | 26 | int index = threadIdx.x; 27 | int stride = blockDim.x; 28 | 29 | // copy from dist to dist_out 30 | for (int j=index;j>>(b,n,m,k,dist,idx,dist_out); 84 | cudaDeviceSynchronize(); 85 | printf("selection sort cpu time %f\n",get_time()-t0); 86 | 87 | //for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | using namespace tensorflow; 12 | 13 | REGISTER_OP("QueryBallPoint") 14 | .Attr("radius: float") 15 | .Attr("nsample: int") 16 | .Input("xyz1: float32") 17 | .Input("xyz2: float32") 18 | .Output("idx: int32") 19 | .Output("pts_cnt: int32") 20 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoint * 3 22 | c->WithRank(c->input(1), 3, &dims2); 23 | int nsample; 24 | TF_RETURN_IF_ERROR(c->GetAttr("nsample", &nsample)); 25 | 
::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), nsample}); 26 | c->set_output(0, output1); 27 | ::tensorflow::shape_inference::ShapeHandle output2 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 28 | c->set_output(1, output2); 29 | return Status::OK(); 30 | }); 31 | REGISTER_OP("SelectionSort") 32 | .Attr("k: int") 33 | .Input("dist: float32") 34 | .Output("outi: int32") 35 | .Output("out: float32") 36 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 37 | c->set_output(0, c->input(0)); 38 | c->set_output(1, c->input(0)); 39 | return Status::OK(); 40 | }); 41 | REGISTER_OP("GroupPoint") 42 | .Input("points: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * channels 47 | c->WithRank(c->input(0), 3, &dims1); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints * nsample 49 | c->WithRank(c->input(1), 3, &dims2); 50 | // batch_size * npoints * nsample * channels 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), c->Dim(dims2, 2), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GroupPointGrad") 56 | .Input("points: float32") 57 | .Input("idx: int32") 58 | .Input("grad_out: float32") 59 | .Output("grad_points: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | 66 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 67 | class QueryBallPointGpuOp : public OpKernel { 68 | public: 69 | explicit QueryBallPointGpuOp(OpKernelConstruction* context) : OpKernel(context) { 70 | OP_REQUIRES_OK(context, context->GetAttr("radius", &radius_)); 71 | OP_REQUIRES(context, radius_ > 0, errors::InvalidArgument("QueryBallPoint expects positive radius")); 72 | 73 | OP_REQUIRES_OK(context, context->GetAttr("nsample", &nsample_)); 74 | OP_REQUIRES(context, nsample_ > 0, errors::InvalidArgument("QueryBallPoint expects positive nsample")); 75 | } 76 | 77 | void Compute(OpKernelContext* context) override { 78 | const Tensor& xyz1_tensor = context->input(0); 79 | OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint expects (batch_size, ndataset, 3) xyz1 shape.")); 80 | int b = xyz1_tensor.shape().dim_size(0); 81 | int n = xyz1_tensor.shape().dim_size(1); 82 | 83 | const Tensor& xyz2_tensor = context->input(1); 84 | OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("QueryBallPoint expects (batch_size, npoint, 3) xyz2 shape.")); 85 | int m = xyz2_tensor.shape().dim_size(1); 86 | 87 | Tensor *idx_tensor = nullptr; 88 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,nsample_}, &idx_tensor)); 89 | Tensor *pts_cnt_tensor = nullptr; 90 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m}, &pts_cnt_tensor)); 91 | 92 | auto xyz1_flat = xyz1_tensor.flat(); 93 | const float *xyz1 = &(xyz1_flat(0)); 94 | auto xyz2_flat = xyz2_tensor.flat(); 95 | const float *xyz2 = &(xyz2_flat(0)); 96 | auto idx_flat = idx_tensor->flat(); 97 | int *idx = &(idx_flat(0)); 98 | auto pts_cnt_flat = 
pts_cnt_tensor->flat(); 99 | int *pts_cnt = &(pts_cnt_flat(0)); 100 | queryBallPointLauncher(b,n,m,radius_,nsample_,xyz1,xyz2,idx,pts_cnt); 101 | } 102 | private: 103 | float radius_; 104 | int nsample_; 105 | }; 106 | REGISTER_KERNEL_BUILDER(Name("QueryBallPoint").Device(DEVICE_GPU), QueryBallPointGpuOp); 107 | 108 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out); 109 | class SelectionSortGpuOp : public OpKernel { 110 | public: 111 | explicit SelectionSortGpuOp(OpKernelConstruction* context) : OpKernel(context) { 112 | OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); 113 | OP_REQUIRES(context, k_ > 0, errors::InvalidArgument("SelectionSort expects positive k")); 114 | } 115 | 116 | void Compute(OpKernelContext* context) override { 117 | const Tensor& dist_tensor = context->input(0); 118 | OP_REQUIRES(context, dist_tensor.dims()==3, errors::InvalidArgument("SelectionSort expects (b,m,n) dist shape.")); 119 | int b = dist_tensor.shape().dim_size(0); 120 | int m = dist_tensor.shape().dim_size(1); 121 | int n = dist_tensor.shape().dim_size(2); 122 | 123 | Tensor *outi_tensor = nullptr; 124 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,n}, &outi_tensor)); 125 | Tensor *out_tensor = nullptr; 126 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m,n}, &out_tensor)); 127 | 128 | auto dist_flat = dist_tensor.flat(); 129 | const float *dist = &(dist_flat(0)); 130 | auto outi_flat = outi_tensor->flat(); 131 | int *outi = &(outi_flat(0)); 132 | auto out_flat = out_tensor->flat(); 133 | float *out = &(out_flat(0)); 134 | selectionSortLauncher(b,n,m,k_,dist,outi,out); 135 | } 136 | private: 137 | int k_; 138 | }; 139 | REGISTER_KERNEL_BUILDER(Name("SelectionSort").Device(DEVICE_GPU), SelectionSortGpuOp); 140 | 141 | 142 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out); 143 | class GroupPointGpuOp: public OpKernel{ 144 | public: 145 | explicit GroupPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 146 | 147 | void Compute(OpKernelContext * context) override { 148 | const Tensor& points_tensor=context->input(0); 149 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPoint expects (batch_size, num_points, channel) points shape")); 150 | int b = points_tensor.shape().dim_size(0); 151 | int n = points_tensor.shape().dim_size(1); 152 | int c = points_tensor.shape().dim_size(2); 153 | 154 | const Tensor& idx_tensor=context->input(1); 155 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPoint expects (batch_size, npoints, nsample) idx shape")); 156 | int m = idx_tensor.shape().dim_size(1); 157 | int nsample = idx_tensor.shape().dim_size(2); 158 | 159 | Tensor * out_tensor = nullptr; 160 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,m,nsample,c}, &out_tensor)); 161 | 162 | auto points_flat = points_tensor.flat(); 163 | const float *points = &(points_flat(0)); 164 | auto idx_flat = idx_tensor.flat(); 165 | const int *idx = &(idx_flat(0)); 166 | auto out_flat = out_tensor->flat(); 167 | float *out = &(out_flat(0)); 168 | groupPointLauncher(b,n,c,m,nsample,points,idx,out); 169 | } 170 | }; 171 | REGISTER_KERNEL_BUILDER(Name("GroupPoint").Device(DEVICE_GPU),GroupPointGpuOp); 172 | 173 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points); 174 | class 
GroupPointGradGpuOp: public OpKernel{ 175 | public: 176 | explicit GroupPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){} 177 | 178 | void Compute(OpKernelContext * context) override { 179 | const Tensor& points_tensor=context->input(0); 180 | OP_REQUIRES(context, points_tensor.dims()==3, errors::InvalidArgument("GroupPointGrad expects (batch_size, num_points, channel) points shape")); 181 | int b = points_tensor.shape().dim_size(0); 182 | int n = points_tensor.shape().dim_size(1); 183 | int c = points_tensor.shape().dim_size(2); 184 | 185 | const Tensor& idx_tensor=context->input(1); 186 | OP_REQUIRES(context,idx_tensor.dims()==3 && idx_tensor.shape().dim_size(0)==b, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample) idx shape")); 187 | int m = idx_tensor.shape().dim_size(1); 188 | int nsample = idx_tensor.shape().dim_size(2); 189 | 190 | const Tensor& grad_out_tensor=context->input(2); 191 | OP_REQUIRES(context,grad_out_tensor.dims()==4 && grad_out_tensor.shape().dim_size(0)==b && grad_out_tensor.shape().dim_size(1)==m && grad_out_tensor.shape().dim_size(2)==nsample && grad_out_tensor.shape().dim_size(3)==c, errors::InvalidArgument("GroupPointGrad expects (batch_size, npoints, nsample, channel) grad_out shape")); 192 | 193 | Tensor * grad_points_tensor = nullptr; 194 | OP_REQUIRES_OK(context, context->allocate_output(0,TensorShape{b,n,c}, &grad_points_tensor)); 195 | 196 | auto points_flat = points_tensor.flat(); 197 | const float *points = &(points_flat(0)); 198 | auto idx_flat = idx_tensor.flat(); 199 | const int *idx = &(idx_flat(0)); 200 | auto grad_out_flat = grad_out_tensor.flat(); 201 | const float *grad_out = &(grad_out_flat(0)); 202 | auto grad_points_flat = grad_points_tensor->flat(); 203 | float *grad_points = &(grad_points_flat(0)); 204 | cudaMemset(grad_points, 0, sizeof(float)*b*n*c); 205 | groupPointGradLauncher(b,n,c,m,nsample,grad_out,idx,grad_points); 206 | } 207 | }; 208 | REGISTER_KERNEL_BUILDER(Name("GroupPointGrad").Device(DEVICE_GPU),GroupPointGradGpuOp); 209 | 210 | 211 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import sys 4 | import os 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | grouping_module=tf.load_op_library(os.path.join(BASE_DIR, 'tf_grouping_so.so')) 8 | def query_ball_point(radius, nsample, xyz1, xyz2): 9 | ''' 10 | Input: 11 | radius: float32, ball search radius 12 | nsample: int32, number of points selected in each ball region 13 | xyz1: (batch_size, ndataset, 3) float32 array, input points 14 | xyz2: (batch_size, npoint, 3) float32 array, query points 15 | Output: 16 | idx: (batch_size, npoint, nsample) int32 array, indices to input points 17 | pts_cnt: (batch_size, npoint) int32 array, number of unique points in each local region 18 | ''' 19 | #return grouping_module.query_ball_point(radius, nsample, xyz1, xyz2) 20 | return grouping_module.query_ball_point(xyz1, xyz2, radius, nsample) 21 | ops.NoGradient('QueryBallPoint') 22 | def select_top_k(k, dist): 23 | ''' 24 | Input: 25 | k: int32, number of k SMALLEST elements selected 26 | dist: (b,m,n) float32 array, distance matrix, m query points, n dataset points 27 | Output: 28 | idx: (b,m,n) int32 array, first k in n are indices to the top k 29 | dist_out: (b,m,n) float32 array, 
first k in n are the top k 30 | ''' 31 | return grouping_module.selection_sort(dist, k) 32 | ops.NoGradient('SelectionSort') 33 | def group_point(points, idx): 34 | ''' 35 | Input: 36 | points: (batch_size, ndataset, channel) float32 array, points to sample from 37 | idx: (batch_size, npoint, nsample) int32 array, indices to points 38 | Output: 39 | out: (batch_size, npoint, nsample, channel) float32 array, values sampled from points 40 | ''' 41 | return grouping_module.group_point(points, idx) 42 | @tf.RegisterGradient('GroupPoint') 43 | def _group_point_grad(op, grad_out): 44 | points = op.inputs[0] 45 | idx = op.inputs[1] 46 | return [grouping_module.group_point_grad(points, idx, grad_out), None] 47 | 48 | def knn_point(k, xyz1, xyz2): 49 | ''' 50 | Input: 51 | k: int32, number of k in k-nn search 52 | xyz1: (batch_size, ndataset, c) float32 array, input points 53 | xyz2: (batch_size, npoint, c) float32 array, query points 54 | Output: 55 | val: (batch_size, npoint, k) float32 array, L2 distances 56 | idx: (batch_size, npoint, k) int32 array, indices to input points 57 | ''' 58 | b = xyz1.get_shape()[0].value 59 | n = xyz1.get_shape()[1].value 60 | c = xyz1.get_shape()[2].value 61 | m = xyz2.get_shape()[1].value 62 | print(b, n, c, m) 63 | print(xyz1, (b,1,n,c)) 64 | xyz1 = tf.tile(tf.reshape(xyz1, (b,1,n,c)), [1,m,1,1]) 65 | xyz2 = tf.tile(tf.reshape(xyz2, (b,m,1,c)), [1,1,n,1]) 66 | dist = tf.reduce_sum((xyz1-xyz2)**2, -1) 67 | print(dist, k) 68 | outi, out = select_top_k(k, dist) 69 | idx = tf.slice(outi, [0,0,0], [-1,-1,k]) 70 | val = tf.slice(out, [0,0,0], [-1,-1,k]) 71 | print(idx, val) 72 | #val, idx = tf.nn.top_k(-dist, k=k) # ONLY SUPPORT CPU 73 | return val, idx 74 | 75 | if __name__=='__main__': 76 | knn=True 77 | import numpy as np 78 | import time 79 | np.random.seed(100) 80 | pts = np.random.random((32,512,64)).astype('float32') 81 | tmp1 = np.random.random((32,512,3)).astype('float32') 82 | tmp2 = np.random.random((32,128,3)).astype('float32') 83 | with tf.device('/gpu:1'): 84 | points = tf.constant(pts) 85 | xyz1 = tf.constant(tmp1) 86 | xyz2 = tf.constant(tmp2) 87 | radius = 0.1 88 | nsample = 64 89 | if knn: 90 | _, idx = knn_point(nsample, xyz1, xyz2) 91 | grouped_points = group_point(points, idx) 92 | else: 93 | idx, _ = query_ball_point(radius, nsample, xyz1, xyz2) 94 | grouped_points = group_point(points, idx) 95 | #grouped_points_grad = tf.ones_like(grouped_points) 96 | #points_grad = tf.gradients(grouped_points, points, grouped_points_grad) 97 | with tf.Session('') as sess: 98 | now = time.time() 99 | for _ in range(100): 100 | ret = sess.run(grouped_points) 101 | print(time.time() - now) 102 | print(ret.shape, ret.dtype) 103 | print(ret) 104 | 105 | 106 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanWusee/pointconv/f39dc3e101af2f52544181ee20c14f73279b48ae/tf_ops/grouping/tf_grouping.pyc -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_compile.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | CUDA_PATH=/usr/local/cuda-9.0 3 | $CUDA_PATH/bin/nvcc tf_grouping_g.cu -o tf_grouping_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 4 | 5 | # TF1.2 6 | # g++ -std=c++11 tf_grouping.cpp tf_grouping_g.cu.o -o tf_grouping_so.so -shared -fPIC -I 
/usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -lcudart -L /usr/local/cuda-8.0/lib64/ -O2 -D_GLIBCXX_USE_CXX11_ABI=0 7 | 8 | # TF1.4 9 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 10 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 11 | g++ -std=c++11 tf_grouping.cpp tf_grouping_g.cu.o -o tf_grouping_so.so -shared -fPIC -I $TF_INC -I $CUDA_PATH/include -L$TF_LIB -I$TF_INC/external/nsync/public -lcudart -L $CUDA_PATH/lib64/ -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 12 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_g.cu: -------------------------------------------------------------------------------- 1 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 2 | // output: idx (b,m,nsample), pts_cnt (b,m) 3 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 4 | int batch_index = blockIdx.x; 5 | xyz1 += n*3*batch_index; 6 | xyz2 += m*3*batch_index; 7 | idx += m*nsample*batch_index; 8 | pts_cnt += m*batch_index; // counting how many unique points selected in local region 9 | 10 | int index = threadIdx.x; 11 | int stride = blockDim.x; 12 | 13 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 127 | //cudaDeviceSynchronize(); 128 | } 129 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 130 | selection_sort_gpu<<>>(b,n,m,k,dist,outi,out); 131 | //cudaDeviceSynchronize(); 132 | } 133 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){ 134 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 135 | //cudaDeviceSynchronize(); 136 | } 137 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points){ 138 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 139 | //group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 140 | //cudaDeviceSynchronize(); 141 | } 142 | -------------------------------------------------------------------------------- /tf_ops/grouping/tf_grouping_op_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tf_grouping import query_ball_point, group_point 4 | 5 | class GroupPointTest(tf.test.TestCase): 6 | def test(self): 7 | pass 8 | 9 | def test_grad(self): 10 | with tf.device('/gpu:0'): 11 | points = tf.constant(np.random.random((1,128,16)).astype('float32')) 12 | print(points) 13 | xyz1 = tf.constant(np.random.random((1,128,3)).astype('float32')) 14 | xyz2 = tf.constant(np.random.random((1,8,3)).astype('float32')) 15 | radius = 0.3 16 | nsample = 32 17 | idx, pts_cnt = query_ball_point(radius, nsample, xyz1, xyz2) 18 | grouped_points = group_point(points, idx) 19 | print(grouped_points) 20 | 21 | with self.test_session(): 22 | print("---- Going to compute gradient error") 23 | err = tf.test.compute_gradient_error(points, (1,128,16), grouped_points, (1,8,32,16)) 24 | print(err) 25 | self.assertLess(err, 1e-4) 26 | 27 | if __name__=='__main__': 28 | tf.test.main() 29 | -------------------------------------------------------------------------------- /tf_ops/sampling/.gitignore: -------------------------------------------------------------------------------- 1 | 
*.o 2 | *.so 3 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling.cpp: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 5 | */ 6 | #include "tensorflow/core/framework/op.h" 7 | #include "tensorflow/core/framework/op_kernel.h" 8 | #include "tensorflow/core/framework/shape_inference.h" 9 | #include "tensorflow/core/framework/common_shape_fns.h" 10 | #include 11 | 12 | using namespace tensorflow; 13 | 14 | REGISTER_OP("ProbSample") 15 | .Input("inp: float32") 16 | .Input("inpr: float32") 17 | .Output("out: int32") 18 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 19 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ncategory 20 | c->WithRank(c->input(0), 2, &dims1); 21 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 22 | c->WithRank(c->input(1), 2, &dims2); 23 | // batch_size * npoints 24 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)}); 25 | c->set_output(0, output); 26 | return Status::OK(); 27 | }); 28 | REGISTER_OP("FarthestPointSample") 29 | .Attr("npoint: int") 30 | .Input("inp: float32") 31 | .Output("out: int32") 32 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 33 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * npoint * 3 34 | c->WithRank(c->input(0), 3, &dims1); 35 | int npoint; 36 | TF_RETURN_IF_ERROR(c->GetAttr("npoint", &npoint)); 37 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), npoint}); 38 | c->set_output(0, output); 39 | return Status::OK(); 40 | }); 41 | REGISTER_OP("GatherPoint") 42 | .Input("inp: float32") 43 | .Input("idx: int32") 44 | .Output("out: float32") 45 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 46 | ::tensorflow::shape_inference::ShapeHandle dims1; // batch_size * ndataset * 3 47 | c->WithRank(c->input(0), 3, &dims1); 48 | ::tensorflow::shape_inference::ShapeHandle dims2; // batch_size * npoints 49 | c->WithRank(c->input(1), 2, &dims2); 50 | // batch_size * npoints * 3 51 | ::tensorflow::shape_inference::ShapeHandle output = c->MakeShape({c->Dim(dims1, 0), c->Dim(dims2, 1), c->Dim(dims1, 2)}); 52 | c->set_output(0, output); 53 | return Status::OK(); 54 | }); 55 | REGISTER_OP("GatherPointGrad") 56 | .Input("inp: float32") 57 | .Input("idx: int32") 58 | .Input("out_g: float32") 59 | .Output("inp_g: float32") 60 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 61 | c->set_output(0, c->input(0)); 62 | return Status::OK(); 63 | }); 64 | 65 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out); 66 | class ProbSampleGpuOp: public OpKernel{ 67 | public: 68 | explicit ProbSampleGpuOp(OpKernelConstruction* context):OpKernel(context){} 69 | void Compute(OpKernelContext * context)override{ 70 | const Tensor& inp_tensor=context->input(0); 71 | const Tensor& inpr_tensor=context->input(1); 72 | auto inp_flat=inp_tensor.flat(); 73 | auto inpr_flat=inpr_tensor.flat(); 74 | const float * inp=&(inp_flat(0)); 75 | const float * inpr=&(inpr_flat(0)); 76 | OP_REQUIRES(context,inp_tensor.dims()==2,errors::InvalidArgument("ProbSample expects (batch_size,num_choices) inp shape")); 77 | int b=inp_tensor.shape().dim_size(0); 78 | int 
n=inp_tensor.shape().dim_size(1); 79 | OP_REQUIRES(context,inpr_tensor.dims()==2 && inpr_tensor.shape().dim_size(0)==b,errors::InvalidArgument("ProbSample expects (batch_size,num_points) inpr shape")); 80 | int m=inpr_tensor.shape().dim_size(1); 81 | Tensor * out_tensor=NULL; 82 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 83 | auto out_flat=out_tensor->flat(); 84 | int * out=&(out_flat(0)); 85 | Tensor temp_tensor; 86 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{b,n},&temp_tensor)); 87 | auto temp_flat=temp_tensor.flat(); 88 | float * temp=&(temp_flat(0)); 89 | probsampleLauncher(b,n,m,inp,inpr,temp,out); 90 | } 91 | }; 92 | REGISTER_KERNEL_BUILDER(Name("ProbSample").Device(DEVICE_GPU), ProbSampleGpuOp); 93 | 94 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out); 95 | class FarthestPointSampleGpuOp: public OpKernel{ 96 | public: 97 | explicit FarthestPointSampleGpuOp(OpKernelConstruction* context):OpKernel(context) { 98 | OP_REQUIRES_OK(context, context->GetAttr("npoint", &npoint_)); 99 | OP_REQUIRES(context, npoint_ > 0, errors::InvalidArgument("FarthestPointSample expects positive npoint")); 100 | } 101 | void Compute(OpKernelContext * context)override{ 102 | int m = npoint_; 103 | 104 | const Tensor& inp_tensor=context->input(0); 105 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("FarthestPointSample expects (batch_size,num_points,3) inp shape")); 106 | int b=inp_tensor.shape().dim_size(0); 107 | int n=inp_tensor.shape().dim_size(1); 108 | auto inp_flat=inp_tensor.flat(); 109 | const float * inp=&(inp_flat(0)); 110 | Tensor * out_tensor; 111 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m},&out_tensor)); 112 | auto out_flat=out_tensor->flat(); 113 | int * out=&(out_flat(0)); 114 | Tensor temp_tensor; 115 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{32,n},&temp_tensor)); 116 | auto temp_flat=temp_tensor.flat(); 117 | float * temp=&(temp_flat(0)); 118 | farthestpointsamplingLauncher(b,n,m,inp,temp,out); 119 | } 120 | private: 121 | int npoint_; 122 | }; 123 | REGISTER_KERNEL_BUILDER(Name("FarthestPointSample").Device(DEVICE_GPU),FarthestPointSampleGpuOp); 124 | 125 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out); 126 | class GatherPointGpuOp: public OpKernel{ 127 | public: 128 | explicit GatherPointGpuOp(OpKernelConstruction * context):OpKernel(context){} 129 | void Compute(OpKernelContext * context)override{ 130 | const Tensor& inp_tensor=context->input(0); 131 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPoint expects (batch_size,num_points,3) inp shape")); 132 | int b=inp_tensor.shape().dim_size(0); 133 | int n=inp_tensor.shape().dim_size(1); 134 | const Tensor& idx_tensor=context->input(1); 135 | OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPoint expects (batch_size,num_result) idx shape")); 136 | int m=idx_tensor.shape().dim_size(1); 137 | auto inp_flat=inp_tensor.flat(); 138 | const float * inp=&(inp_flat(0)); 139 | auto idx_flat=idx_tensor.flat(); 140 | const int * idx=&(idx_flat(0)); 141 | Tensor * out_tensor=NULL; 142 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,m,3},&out_tensor)); 143 | auto out_flat=out_tensor->flat(); 144 | float * out=&(out_flat(0)); 
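// Note (annotation, not original source): GatherPoint is a plain gather over
// the point dimension, out[b][j][k] = inp[b][idx[b][j]][k], yielding
// (batch_size, num_result, 3) points; idx is typically produced by the
// FarthestPointSample op registered above.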
145 | gatherpointLauncher(b,n,m,inp,idx,out); 146 | } 147 | }; 148 | REGISTER_KERNEL_BUILDER(Name("GatherPoint").Device(DEVICE_GPU),GatherPointGpuOp); 149 | 150 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g); 151 | class GatherPointGradGpuOp: public OpKernel{ 152 | public: 153 | explicit GatherPointGradGpuOp(OpKernelConstruction * context):OpKernel(context){} 154 | void Compute(OpKernelContext * context)override{ 155 | const Tensor& inp_tensor=context->input(0); 156 | OP_REQUIRES(context,inp_tensor.dims()==3 && inp_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_points,3) inp")); 157 | int b=inp_tensor.shape().dim_size(0); 158 | int n=inp_tensor.shape().dim_size(1); 159 | const Tensor& idx_tensor=context->input(1); 160 | OP_REQUIRES(context,idx_tensor.dims()==2 && idx_tensor.shape().dim_size(0)==b,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result) idx shape")); 161 | int m=idx_tensor.shape().dim_size(1); 162 | auto inp_flat=inp_tensor.flat(); 163 | const float * inp=&(inp_flat(0)); 164 | auto idx_flat=idx_tensor.flat(); 165 | const int * idx=&(idx_flat(0)); 166 | const Tensor& out_g_tensor=context->input(2); 167 | OP_REQUIRES(context,out_g_tensor.dims()==3 && out_g_tensor.shape().dim_size(0)==b && out_g_tensor.shape().dim_size(1)==m && out_g_tensor.shape().dim_size(2)==3,errors::InvalidArgument("GatherPointGradGpuOp expects (batch_size,num_result,3) out_g shape")); 168 | auto out_g_flat=out_g_tensor.flat(); 169 | const float * out_g=&(out_g_flat(0)); 170 | Tensor * inp_g_tensor=NULL; 171 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&inp_g_tensor)); 172 | auto inp_g_flat=inp_g_tensor->flat(); 173 | float * inp_g=&(inp_g_flat(0)); 174 | cudaMemset(inp_g,0,b*n*3*4); 175 | scatteraddpointLauncher(b,n,m,out_g,idx,inp_g); 176 | } 177 | }; 178 | REGISTER_KERNEL_BUILDER(Name("GatherPointGrad").Device(DEVICE_GPU),GatherPointGradGpuOp); 179 | 180 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling.py: -------------------------------------------------------------------------------- 1 | ''' Furthest point sampling 2 | Original author: Haoqiang Fan 3 | Modified by Charles R. Qi 4 | All Rights Reserved. 2017. 
5 | ''' 6 | import tensorflow as tf 7 | from tensorflow.python.framework import ops 8 | import sys 9 | import os 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(BASE_DIR) 12 | sampling_module=tf.load_op_library(os.path.join(BASE_DIR, 'tf_sampling_so.so')) 13 | def prob_sample(inp,inpr): 14 | ''' 15 | input: 16 | batch_size * ncategory float32 17 | batch_size * npoints float32 18 | returns: 19 | batch_size * npoints int32 20 | ''' 21 | return sampling_module.prob_sample(inp,inpr) 22 | ops.NoGradient('ProbSample') 23 | # TF1.0 API requires set shape in C++ 24 | #@tf.RegisterShape('ProbSample') 25 | #def _prob_sample_shape(op): 26 | # shape1=op.inputs[0].get_shape().with_rank(2) 27 | # shape2=op.inputs[1].get_shape().with_rank(2) 28 | # return [tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 29 | def gather_point(inp,idx): 30 | ''' 31 | input: 32 | batch_size * ndataset * 3 float32 33 | batch_size * npoints int32 34 | returns: 35 | batch_size * npoints * 3 float32 36 | ''' 37 | return sampling_module.gather_point(inp,idx) 38 | #@tf.RegisterShape('GatherPoint') 39 | #def _gather_point_shape(op): 40 | # shape1=op.inputs[0].get_shape().with_rank(3) 41 | # shape2=op.inputs[1].get_shape().with_rank(2) 42 | # return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[2]])] 43 | @tf.RegisterGradient('GatherPoint') 44 | def _gather_point_grad(op,out_g): 45 | inp=op.inputs[0] 46 | idx=op.inputs[1] 47 | return [sampling_module.gather_point_grad(inp,idx,out_g),None] 48 | def farthest_point_sample(npoint,inp): 49 | ''' 50 | input: 51 | int32 52 | batch_size * ndataset * 3 float32 53 | returns: 54 | batch_size * npoint int32 55 | ''' 56 | return sampling_module.farthest_point_sample(inp, npoint) 57 | ops.NoGradient('FarthestPointSample') 58 | 59 | 60 | if __name__=='__main__': 61 | import numpy as np 62 | np.random.seed(100) 63 | triangles=np.random.rand(1,5,3,3).astype('float32') 64 | with tf.device('/gpu:1'): 65 | inp=tf.constant(triangles) 66 | tria=inp[:,:,0,:] 67 | trib=inp[:,:,1,:] 68 | tric=inp[:,:,2,:] 69 | areas=tf.sqrt(tf.reduce_sum(tf.cross(trib-tria,tric-tria)**2,2)+1e-9) 70 | randomnumbers=tf.random_uniform((1,8192)) 71 | triids=prob_sample(areas,randomnumbers) 72 | tria_sample=gather_point(tria,triids) 73 | trib_sample=gather_point(trib,triids) 74 | tric_sample=gather_point(tric,triids) 75 | us=tf.random_uniform((1,8192)) 76 | vs=tf.random_uniform((1,8192)) 77 | uplusv=1-tf.abs(us+vs-1) 78 | uminusv=us-vs 79 | us=(uplusv+uminusv)*0.5 80 | vs=(uplusv-uminusv)*0.5 81 | pt_sample=tria_sample+(trib_sample-tria_sample)*tf.expand_dims(us,-1)+(tric_sample-tria_sample)*tf.expand_dims(vs,-1) 82 | print('pt_sample: ', pt_sample) 83 | reduced_sample=gather_point(pt_sample,farthest_point_sample(1024,pt_sample)) 84 | print(reduced_sample) 85 | with tf.Session('') as sess: 86 | ret=sess.run(reduced_sample) 87 | print(ret.shape,ret.dtype) 88 | import cPickle as pickle 89 | pickle.dump(ret,open('1.pkl','wb'),-1) 90 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanWusee/pointconv/f39dc3e101af2f52544181ee20c14f73279b48ae/tf_ops/sampling/tf_sampling.pyc -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling_compile.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | 
CUDA_PATH=/usr/local/cuda-9.0 3 | $CUDA_PATH/bin/nvcc tf_sampling_g.cu -o tf_sampling_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 4 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 5 | TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 6 | g++ -std=c++11 tf_sampling.cpp tf_sampling_g.cu.o -o tf_sampling_so.so -shared -fPIC -I $TF_INC -I $CUDA_PATH/include -lcudart -L $CUDA_PATH/lib64/ -L$TF_LIB -I$TF_INC/external/nsync/public -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 7 | -------------------------------------------------------------------------------- /tf_ops/sampling/tf_sampling_g.cu: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling GPU implementation 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 5 | */ 6 | 7 | __global__ void cumsumKernel(int b,int n,const float * __restrict__ inp,float * __restrict__ out){ 8 | const int BlockSize=2048; 9 | const int paddingLevel=5; 10 | __shared__ float buffer4[BlockSize*4]; 11 | __shared__ float buffer[BlockSize+(BlockSize>>paddingLevel)]; 12 | for (int i=blockIdx.x;i>2; 18 | for (int k=threadIdx.x*4;k>2)+(k>>(2+paddingLevel))]=v4; 33 | }else{ 34 | float v=0; 35 | for (int k2=k;k2>2)+(k>>(2+paddingLevel))]=v; 43 | } 44 | } 45 | int u=0; 46 | for (;(2<>(u+1));k+=blockDim.x){ 49 | int i1=(((k<<1)+2)<>paddingLevel; 52 | i2+=i2>>paddingLevel; 53 | buffer[i1]+=buffer[i2]; 54 | } 55 | } 56 | u--; 57 | for (;u>=0;u--){ 58 | __syncthreads(); 59 | for (int k=threadIdx.x;k>(u+1));k+=blockDim.x){ 60 | int i1=(((k<<1)+3)<>paddingLevel; 63 | i2+=i2>>paddingLevel; 64 | buffer[i1]+=buffer[i2]; 65 | } 66 | } 67 | __syncthreads(); 68 | for (int k=threadIdx.x*4;k>2)-1)+(((k>>2)-1)>>paddingLevel); 71 | buffer4[k]+=buffer[k2]; 72 | buffer4[k+1]+=buffer[k2]; 73 | buffer4[k+2]+=buffer[k2]; 74 | buffer4[k+3]+=buffer[k2]; 75 | } 76 | } 77 | __syncthreads(); 78 | for (int k=threadIdx.x;k>paddingLevel)]+runningsum2; 82 | float r2=runningsum+t; 83 | runningsum2=t-(r2-runningsum); 84 | runningsum=r2; 85 | __syncthreads(); 86 | } 87 | } 88 | } 89 | 90 | __global__ void binarysearchKernel(int b,int n,int m,const float * __restrict__ dataset,const float * __restrict__ query, int * __restrict__ result){ 91 | int base=1; 92 | while (base=1;k>>=1) 99 | if (r>=k && dataset[i*n+r-k]>=q) 100 | r-=k; 101 | result[i*m+j]=r; 102 | } 103 | } 104 | } 105 | __global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){ 106 | if (m<=0) 107 | return; 108 | const int BlockSize=512; 109 | __shared__ float dists[BlockSize]; 110 | __shared__ int dists_i[BlockSize]; 111 | const int BufferSize=3072; 112 | __shared__ float buf[BufferSize*3]; 113 | for (int i=blockIdx.x;ibest){ 147 | best=d2; 148 | besti=k; 149 | } 150 | } 151 | dists[threadIdx.x]=best; 152 | dists_i[threadIdx.x]=besti; 153 | for (int u=0;(1<>(u+1))){ 156 | int i1=(threadIdx.x*2)<>>(b,n,inp,out); 196 | } 197 | //require b*n working space 198 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out){ 199 | cumsumKernel<<<32,512>>>(b,n,inp_p,temp); 200 | binarysearchKernel<<>>(b,n,m,temp,inp_r,out); 201 | } 202 | //require 32*n working space 203 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){ 204 | farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out); 205 
| } 206 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out){ 207 | gatherpointKernel<<>>(b,n,m,inp,idx,out); 208 | } 209 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){ 210 | scatteraddpointKernel<<>>(b,n,m,out_g,idx,inp_g); 211 | } 212 | 213 | -------------------------------------------------------------------------------- /utils/pointconv_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper Function for PointConv 3 | Author: Wenxuan Wu 4 | Date: July 2018 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import math 11 | import random 12 | import numpy as np 13 | import tensorflow as tf 14 | from transforms3d.euler import euler2mat 15 | import os 16 | import sys 17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 18 | sys.path.append(os.path.join(BASE_DIR, '../tf_ops/sampling')) 19 | sys.path.append(os.path.join(BASE_DIR, '../tf_ops/grouping')) 20 | import tf_sampling 21 | import tf_grouping 22 | from sklearn.neighbors import KDTree 23 | 24 | def knn_kdtree(nsample, xyz, new_xyz): 25 | batch_size = xyz.shape[0] 26 | n_points = new_xyz.shape[1] 27 | 28 | indices = np.zeros((batch_size, n_points, nsample), dtype=np.int32) 29 | for batch_idx in range(batch_size): 30 | X = xyz[batch_idx, ...] 31 | q_X = new_xyz[batch_idx, ...] 32 | kdt = KDTree(X, leaf_size=30) 33 | _, indices[batch_idx] = kdt.query(q_X, k = nsample) 34 | 35 | return indices 36 | 37 | def kernel_density_estimation_ball(pts, radius, sigma, N_points = 128, is_norm = False): 38 | with tf.variable_scope("ComputeDensity") as sc: 39 | idx, pts_cnt = tf_grouping.query_ball_point(radius, N_points, pts, pts) 40 | g_pts = tf_grouping.group_point(pts, idx) 41 | g_pts -= tf.tile(tf.expand_dims(pts, 2), [1, 1, N_points, 1]) 42 | 43 | R = tf.sqrt(sigma) 44 | xRinv = tf.div(g_pts, R) 45 | quadform = tf.reduce_sum(tf.square(xRinv), axis = -1) 46 | logsqrtdetSigma = tf.log(R) * 3 47 | mvnpdf = tf.exp(-0.5 * quadform - logsqrtdetSigma - 3 * tf.log(2 * 3.1415926) / 2) 48 | 49 | first_val, _ = tf.split(mvnpdf, [1, N_points - 1], axis = 2) 50 | 51 | mvnpdf = tf.reduce_sum(mvnpdf, axis = 2, keepdims = True) 52 | 53 | num_val_to_sub = tf.expand_dims(tf.cast(tf.subtract(N_points, pts_cnt), dtype = tf.float32), axis = -1) 54 | 55 | val_to_sub = tf.multiply(first_val, num_val_to_sub) 56 | 57 | mvnpdf = tf.subtract(mvnpdf, val_to_sub) 58 | 59 | scale = tf.div(1.0, tf.expand_dims(tf.cast(pts_cnt, dtype = tf.float32), axis = -1)) 60 | density = tf.multiply(mvnpdf, scale) 61 | 62 | if is_norm: 63 | #grouped_xyz_sum = tf.reduce_sum(grouped_xyz, axis = 1, keepdims = True) 64 | density_max = tf.reduce_max(density, axis = 1, keepdims = True) 65 | density = tf.div(density, density_max) 66 | 67 | return density 68 | 69 | def kernel_density_estimation(pts, sigma, kpoint = 32, is_norm = False): 70 | with tf.variable_scope("ComputeDensity") as sc: 71 | batch_size = pts.get_shape()[0] 72 | num_points = pts.get_shape()[1] 73 | if num_points < kpoint: 74 | kpoint = num_points.value - 1 75 | with tf.device('/cpu:0'): 76 | point_indices = tf.py_func(knn_kdtree, [kpoint, pts, pts], tf.int32) 77 | batch_indices = tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1, 1)), (1, num_points, kpoint, 1)) 78 | idx = tf.concat([batch_indices, tf.expand_dims(point_indices, axis = 3)], axis = 3) 79 | idx.set_shape([batch_size, num_points, 
kpoint, 2]) 80 | 81 | grouped_pts = tf.gather_nd(pts, idx) 82 | grouped_pts -= tf.tile(tf.expand_dims(pts, 2), [1,1,kpoint,1]) # translation normalization 83 | 84 | R = tf.sqrt(sigma) 85 | xRinv = tf.div(grouped_pts, R) 86 | quadform = tf.reduce_sum(tf.square(xRinv), axis = -1) 87 | logsqrtdetSigma = tf.log(R) * 3 88 | mvnpdf = tf.exp(-0.5 * quadform - logsqrtdetSigma - 3 * tf.log(2 * 3.1415926) / 2) 89 | mvnpdf = tf.reduce_sum(mvnpdf, axis = 2, keepdims = True) 90 | 91 | scale = 1.0 / kpoint 92 | density = tf.multiply(mvnpdf, scale) 93 | 94 | if is_norm: 95 | #grouped_xyz_sum = tf.reduce_sum(grouped_xyz, axis = 1, keepdims = True) 96 | density_max = tf.reduce_max(density, axis = 1, keepdims = True) 97 | density = tf.div(density, density_max) 98 | 99 | return density 100 | 101 | def sampling(npoint, pts): 102 | ''' 103 | inputs: 104 | npoint: scalar, number of points to sample 105 | pointcloud: B * N * 3, input point cloud 106 | output: 107 | sub_pts: B * npoint * 3, sub-sampled point cloud 108 | ''' 109 | 110 | sub_pts = tf_sampling.gather_point(pts, tf_sampling.farthest_point_sample(npoint, pts)) 111 | return sub_pts 112 | 113 | def grouping(feature, K, src_xyz, q_xyz, use_xyz = True): 114 | ''' 115 | K: neighbor size 116 | src_xyz: original point xyz (batch_size, ndataset, 3) 117 | q_xyz: query point xyz (batch_size, npoint, 3) 118 | ''' 119 | 120 | batch_size = src_xyz.get_shape()[0] 121 | npoint = q_xyz.get_shape()[1] 122 | 123 | point_indices = tf.py_func(knn_kdtree, [K, src_xyz, q_xyz], tf.int32) 124 | batch_indices = tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1, 1)), (1, npoint, K, 1)) 125 | idx = tf.concat([batch_indices, tf.expand_dims(point_indices, axis = 3)], axis = 3) 126 | idx.set_shape([batch_size, npoint, K, 2]) 127 | 128 | grouped_xyz = tf.gather_nd(src_xyz, idx) 129 | grouped_xyz -= tf.tile(tf.expand_dims(q_xyz, 2), [1,1,K,1]) # translation normalization 130 | 131 | grouped_feature = tf.gather_nd(feature, idx) 132 | if use_xyz: 133 | new_points = tf.concat([grouped_xyz, grouped_feature], axis = -1) 134 | else: 135 | new_points = grouped_feature 136 | 137 | return grouped_xyz, new_points, idx 138 | 139 | if __name__=='__main__': 140 | #test KDE 141 | import time 142 | batch_size = 8 143 | num_point = 8192 144 | pts = np.random.randn(batch_size, num_point, 3).astype('float32') 145 | 146 | import pdb 147 | pdb.set_trace() 148 | 149 | with tf.device('/gpu:1'): 150 | points = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 151 | density = kernel_density_estimation_ball(points, 1.0) 152 | #density = kernel_density_estimation(points, 1.0) 153 | init = tf.global_variables_initializer() 154 | with tf.Session('') as sess: 155 | 156 | sess.run(init) 157 | t1 = time.time() 158 | den = sess.run(density, feed_dict = {points:pts}) 159 | 160 | print(time.time() - t1) 161 | 162 | #import scipy.io as sio 163 | 164 | #sio.savemat('density.mat', dict([('pts', pts), ('density', den)])) 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /utils/provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import h5py 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | 8 | def shuffle_points(batch_data): 9 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 10 | Use the same shuffling idx for the entire batch. 
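(The farthest-point sampling kernel in tf_sampling_g.cu always seeds from the first point in the array, so shuffling changes which point starts the sampling.)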
11 | Input: 12 | BxNxC array 13 | Output: 14 | BxNxC array 15 | """ 16 | idx = np.arange(batch_data.shape[1]) 17 | np.random.shuffle(idx) 18 | return batch_data[:,idx,:] 19 | 20 | def shuffle_data(data, labels): 21 | """ Shuffle data and labels. 22 | Input: 23 | data: B,N,... numpy array 24 | label: B,... numpy array 25 | Return: 26 | shuffled data, label and shuffle indices 27 | """ 28 | idx = np.arange(len(labels)) 29 | np.random.shuffle(idx) 30 | return data[idx, ...], labels[idx], idx 31 | 32 | 33 | def rotate_point_cloud(batch_data): 34 | """ Randomly rotate the point clouds to augment the dataset; 35 | rotation is per shape, about the up direction 36 | Input: 37 | BxNx3 array, original batch of point clouds 38 | Return: 39 | BxNx3 array, rotated batch of point clouds 40 | """ 41 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 42 | for k in range(batch_data.shape[0]): 43 | rotation_angle = np.random.uniform() * 2 * np.pi 44 | cosval = np.cos(rotation_angle) 45 | sinval = np.sin(rotation_angle) 46 | rotation_matrix = np.array([[cosval, 0, sinval], 47 | [0, 1, 0], 48 | [-sinval, 0, cosval]]) 49 | shape_pc = batch_data[k, ...] 50 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 51 | return rotated_data 52 | 53 | def rotate_point_cloud_z(batch_data): 54 | """ Randomly rotate the point clouds to augment the dataset; 55 | rotation is per shape, about the z direction 56 | Input: 57 | BxNx3 array, original batch of point clouds 58 | Return: 59 | BxNx3 array, rotated batch of point clouds 60 | """ 61 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 62 | for k in range(batch_data.shape[0]): 63 | rotation_angle = np.random.uniform() * 2 * np.pi 64 | cosval = np.cos(rotation_angle) 65 | sinval = np.sin(rotation_angle) 66 | rotation_matrix = np.array([[cosval, -sinval, 0], 67 | [sinval, cosval, 0], 68 | [0, 0, 1]]) 69 | shape_pc = batch_data[k, ...] 70 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 71 | return rotated_data 72 | 73 | 74 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 75 | """ Rotate the point cloud about the up direction by a given angle. 76 | Input: 77 | BxNx3 array, original batch of point clouds 78 | Return: 79 | BxNx3 array, rotated batch of point clouds 80 | """ 81 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 82 | for k in range(batch_data.shape[0]): 83 | #rotation_angle = np.random.uniform() * 2 * np.pi 84 | cosval = np.cos(rotation_angle) 85 | sinval = np.sin(rotation_angle) 86 | rotation_matrix = np.array([[cosval, 0, sinval], 87 | [0, 1, 0], 88 | [-sinval, 0, cosval]]) 89 | shape_pc = batch_data[k, ...] 90 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 91 | return rotated_data 92 |
93 | 94 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 95 | """ Randomly perturb the point clouds by small rotations 96 | Input: 97 | BxNx3 array, original batch of point clouds 98 | Return: 99 | BxNx3 array, rotated batch of point clouds 100 | """ 101 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 102 | for k in range(batch_data.shape[0]): 103 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 104 | Rx = np.array([[1,0,0], 105 | [0,np.cos(angles[0]),-np.sin(angles[0])], 106 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 107 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 108 | [0,1,0], 109 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 110 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 111 | [np.sin(angles[2]),np.cos(angles[2]),0], 112 | [0,0,1]]) 113 | R = np.dot(Rz, np.dot(Ry,Rx)) 114 | shape_pc = batch_data[k, ...] 115 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 116 | return rotated_data 117 | 118 | def random_jitter_rgb(batch_data, r = 2.5): 119 | B, N, C = batch_data.shape 120 | assert(r >= 0) 121 | r = r / 255 # scale jitter radius from 0-255 RGB units to normalized colors 122 | jittered_data = 2 * r * (np.random.uniform(size = (B, N, C)) - 0.5) 123 | jittered_data += batch_data 124 | return jittered_data 125 | 126 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 127 | """ Randomly jitter points. Jittering is per point. 128 | Input: 129 | BxNx3 array, original batch of point clouds 130 | Return: 131 | BxNx3 array, jittered batch of point clouds 132 | """ 133 | B, N, C = batch_data.shape 134 | assert(clip > 0) 135 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 136 | jittered_data += batch_data 137 | return jittered_data 138 | 139 | def shift_point_cloud(batch_data, shift_range=0.1): 140 | """ Randomly shift point cloud. Shift is per point cloud. 141 | Input: 142 | BxNx3 array, original batch of point clouds 143 | Return: 144 | BxNx3 array, shifted batch of point clouds 145 | """ 146 | B, N, C = batch_data.shape 147 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 148 | for batch_index in range(B): 149 | batch_data[batch_index,:,:] += shifts[batch_index,:] 150 | return batch_data 151 | 152 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 153 | """ Randomly scale the point cloud. Scale is per point cloud. 154 | Input: 155 | BxNx3 array, original batch of point clouds 156 | Return: 157 | BxNx3 array, scaled batch of point clouds 158 | """ 159 | B, N, C = batch_data.shape 160 | scales = np.random.uniform(scale_low, scale_high, B) 161 | for batch_index in range(B): 162 | batch_data[batch_index,:,:] *= scales[batch_index] 163 | return batch_data 164 | 165 | def getDataFiles(list_filename): 166 | return [line.rstrip() for line in open(list_filename)] 167 | 168 | def load_h5(h5_filename): 169 | f = h5py.File(h5_filename, 'r') 170 | data = f['data'][:] 171 | label = f['label'][:] 172 | return (data, label) 173 | 174 | def loadDataFile(filename): 175 | return load_h5(filename) 176 | --------------------------------------------------------------------------------
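A short sketch of how the utils/provider.py augmentations above are typically chained for a training batch; the ordering, parameter values, and the helper name `augment_batch` are illustrative assumptions, not code from this repo (assumes utils/ is on sys.path):

import numpy as np
import provider

def augment_batch(batch_xyz):
    """batch_xyz: (B, N, 3) float32 point clouds."""
    batch_xyz = provider.rotate_point_cloud_z(batch_xyz)             # random yaw
    batch_xyz = provider.rotate_perturbation_point_cloud(batch_xyz)  # small tilt
    batch_xyz = provider.random_scale_point_cloud(batch_xyz, scale_low=0.8, scale_high=1.25)
    batch_xyz = provider.shift_point_cloud(batch_xyz, shift_range=0.1)
    batch_xyz = provider.jitter_point_cloud(batch_xyz, sigma=0.01, clip=0.05)
    # shuffle last so farthest-point sampling sees a random point order
    return provider.shuffle_points(batch_xyz).astype(np.float32)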